132 changes: 132 additions & 0 deletions .gitignore
@@ -0,0 +1,132 @@
# Mac files
.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other info into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
46 changes: 25 additions & 21 deletions fcit/fcit.py
@@ -33,16 +33,14 @@ def interleave(x, z, seed=None):
         An array of shape (n_samples, x_dim + z_dim) in which
         the columns of x and z are interleaved at random.
     """
-    state = np.random.get_state()
-    np.random.seed(seed or int(time.time()))
-    total_ids = np.random.permutation(x.shape[1]+z.shape[1])
-    np.random.set_state(state)
+    rnd = np.random.RandomState(seed)
+    total_ids = rnd.permutation(x.shape[1]+z.shape[1])
     out = np.zeros([x.shape[0], x.shape[1] + z.shape[1]])
     out[:, total_ids[:x.shape[1]]] = x
     out[:, total_ids[x.shape[1]:]] = z
     return out
 
-def cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test):
+def cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test, random_state, n_jobs):
     """ Choose the best decision tree hyperparameters by
     cross-validation. The hyperparameter to optimize is min_samples_split
     (see sklearn's DecisionTreeRegressor).
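Replacing the save-seed-restore dance with a local `np.random.RandomState` is the core pattern of this PR: it leaves the global RNG untouched, and it also fixes a subtle bug in the old code, where `seed or int(time.time())` treated an explicit `seed=0` as falsy and silently fell back to time-based seeding. A minimal standalone sketch of the difference (illustrative, not part of the diff):

```python
import numpy as np

# Old pattern: mutate global state, then restore it. seed=0 is falsy,
# so `seed or int(time.time())` would ignore an explicit zero seed:
seed = 0
print(seed or 12345)  # prints 12345, not 0

# New pattern: a local, self-contained stream. RandomState(None) still
# draws fresh OS entropy, so the unseeded behavior is preserved.
rnd = np.random.RandomState(0)
print(rnd.permutation(5))  # identical on every run
```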
@@ -55,6 +53,8 @@ def cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test):
         logdim (bool): If True, set max_features to 'log2'.
         verbose (bool): If True, print out extra info.
         prop_test (float): Proportion of validation data to use.
+        random_state (int): Random seed.
+        n_jobs (int): Number of threads to use for parallel computation.
 
     Returns:
         DecisionTreeRegressor with the best hyperparameter setting.
@@ -66,16 +66,16 @@ def cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test):
     elif len(cv_grid) == 1:
         min_samples_split = cv_grid[0]
     else:
-        clf = DecisionTreeRegressor(max_features=max_features)
-        splitter = ShuffleSplit(n_splits=3, test_size=prop_test)
+        clf = DecisionTreeRegressor(max_features=max_features, random_state=random_state)
+        splitter = ShuffleSplit(n_splits=3, test_size=prop_test, random_state=random_state)
         cv = GridSearchCV(estimator=clf, cv=splitter,
-                          param_grid={'min_samples_split': cv_grid}, n_jobs=-1)
-        cv.fit(interleave(x, z), y)
+                          param_grid={'min_samples_split': cv_grid}, n_jobs=n_jobs)
+        cv.fit(interleave(x, z, seed=random_state), y)
         min_samples_split = cv.best_params_['min_samples_split']
     if verbose:
         print('min_samples_split: {}.'.format(min_samples_split))
     clf = DecisionTreeRegressor(max_features=max_features,
-                                min_samples_split=min_samples_split)
+                                min_samples_split=min_samples_split, random_state=random_state)
     return clf
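With this hunk, all three stochastic pieces of the cross-validation (the tree's feature subsampling, the train/validation splits, and the fit itself) share one seed. A small self-contained sketch of why that makes model selection repeatable (synthetic data, not from the PR):

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import GridSearchCV, ShuffleSplit

rng = np.random.RandomState(0)
X, y = rng.randn(200, 3), rng.randn(200)

# Seed both the estimator and the splitter; the grid search is then
# deterministic end to end.
clf = DecisionTreeRegressor(random_state=0)
splitter = ShuffleSplit(n_splits=3, test_size=0.1, random_state=0)
cv = GridSearchCV(estimator=clf, cv=splitter,
                  param_grid={'min_samples_split': [2, 8, 64]})
cv.fit(X, y)
print(cv.best_params_)  # same answer on every run
```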

def obtain_error(data_and_i):
Expand All @@ -89,12 +89,12 @@ def obtain_error(data_and_i):
data['n_test']: Number of test points.
data['clf']: Decision tree regressor.
"""
data, i = data_and_i
data, i, random = data_and_i
x = data['x']
y = data['y']
z = data['z']
if data['reshuffle']:
perm_ids = np.random.permutation(x.shape[0])
perm_ids = random.permutation(x.shape[0])
else:
perm_ids = np.arange(x.shape[0])
data_permutation = data['data_permutation'][i]
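Since `obtain_error` runs inside `joblib.Parallel`, the `RandomState` travels to the workers through the argument tuple. A hypothetical mini-version of the pattern (the names here are illustrative, not from fcit):

```python
import joblib
import numpy as np

def work(args):
    data, i, random = args  # mirrors `data, i, random = data_and_i`
    return random.permutation(len(data))[i]

random = np.random.RandomState(7)
data = list(range(10))
# Each delayed call receives a pickled copy of `random`, so the output is
# the same on every run of the script (the copies start from one state).
out = joblib.Parallel(n_jobs=2)(
    joblib.delayed(work)((data, i, random)) for i in range(4))
print(out)
```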
@@ -110,7 +110,7 @@
 
 def test(x, y, z=None, num_perm=8, prop_test=.1,
          discrete=(False, False), plot_return=False, verbose=False,
-         logdim=False, cv_grid=[2, 8, 64, 512, 1e-2, .2, .4], **kwargs):
+         logdim=False, cv_grid=[2, 8, 64, 512, 1e-2, .2, .4], random_state=None, n_jobs=-1, **kwargs):
     """ Fast conditional independence test, based on decision-tree regression.
 
     See Chalupka, Perona, Eberhardt 2017 [arXiv link coming].
@@ -129,11 +129,15 @@ def test(x, y, z=None, num_perm=8, prop_test=.1,
         logdim (bool): If True, set max_features='log2' in the decision tree.
         cv_grid (list): min_impurity_splits to cross-validate when training
             the decision tree regressor.
+        random_state (int): Seed for random number generator.
+        n_jobs (int): Number of threads to use for parallel computation.
 
     Returns:
         p (float): The p-value for the null hypothesis
             that x is independent of y.
     """
+    random = np.random.RandomState(random_state)
+
     # Compute test set size.
     n_samples = x.shape[0]
     n_test = int(n_samples * prop_test)
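One detail worth making explicit: `np.random.RandomState(None)` seeds itself from fresh OS entropy, so the default `random_state=None` keeps the previous nondeterministic behavior, while any integer pins the whole test. A one-line check:

```python
import numpy as np

# The same integer seed always reproduces the same draws.
assert np.random.RandomState(7).randint(100) == np.random.RandomState(7).randint(100)
```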
@@ -155,10 +159,10 @@ def test(x, y, z=None, num_perm=8, prop_test=.1,
     d0_stats = np.zeros(num_perm)
     d1_stats = np.zeros(num_perm)
     data_permutations = [
-        np.random.permutation(n_samples) for i in range(num_perm)]
+        random.permutation(n_samples) for i in range(num_perm)]
 
     # Compute mses for y = f(x, z), varying train-test splits.
-    clf = cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test=prop_test)
+    clf = cv_besttree(x, y, z, cv_grid, logdim, verbose, prop_test=prop_test, random_state=random_state, n_jobs=n_jobs)
     datadict = {
         'x': x,
         'y': y,
@@ -168,20 +172,20 @@ def test(x, y, z=None, num_perm=8, prop_test=.1,
         'reshuffle': False,
         'clf': clf,
     }
-    d1_stats = np.array(joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
-        joblib.delayed(obtain_error)((datadict, i)) for i in range(num_perm)))
+    d1_stats = np.array(joblib.Parallel(n_jobs=n_jobs, max_nbytes=100e6)(
+        joblib.delayed(obtain_error)((datadict, i, random)) for i in range(num_perm)))
 
     # Compute mses for y = f(x, reshuffle(z)), varying train-test splits.
     if z.shape[1] == 0:
-        x_indep_y = x[np.random.permutation(n_samples)]
+        x_indep_y = x[random.permutation(n_samples)]
     else:
         x_indep_y = np.empty([x.shape[0], 0])
     clf = cv_besttree(x_indep_y, y, z, cv_grid, logdim,
-                      verbose, prop_test=prop_test)
+                      verbose, prop_test=prop_test, random_state=random_state, n_jobs=n_jobs)
     datadict['reshuffle'] = True
     datadict['x'] = x_indep_y
-    d0_stats = np.array(joblib.Parallel(n_jobs=-1, max_nbytes=100e6)(
-        joblib.delayed(obtain_error)((datadict, i)) for i in range(num_perm)))
+    d0_stats = np.array(joblib.Parallel(n_jobs=n_jobs, max_nbytes=100e6)(
+        joblib.delayed(obtain_error)((datadict, i, random)) for i in range(num_perm)))
 
     if verbose:
         np.set_printoptions(precision=3)
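Taken together, the changes make a seeded call fully reproducible. A sketch of the intended usage (assuming the published package layout, i.e. `from fcit import fcit`; the data here is synthetic):

```python
import numpy as np
from fcit import fcit

rng = np.random.RandomState(0)
x = rng.randn(500, 1)
z = rng.randn(500, 1)
y = z + 0.1 * rng.randn(500, 1)  # y depends on z, not on x

# With random_state set, repeated calls should agree exactly; with the
# default random_state=None, each call still uses fresh entropy as before.
p1 = fcit.test(x, y, z, random_state=12, n_jobs=2)
p2 = fcit.test(x, y, z, random_state=12, n_jobs=2)
assert p1 == p2
```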