Given a set of leaf nodes, find the most cost-effective (1) tree topology and (2) ancestor nodes using gradient descent.
The tree search is guided by two objectives: tree_forcing_loss and tree_traversal_cost.
- Therefore, this approach does not need training data.
- Convergence guarantees remain to be explored.
Example experiment: https://wandb.ai/ramithx1/Differentiable-Trees/reports/Differentiable-Trees--VmlldzozMzg4MTE2?accessToken=cyle7be7zv37oxb05bdbsbqu1n3lyenrm69ly9y8swyf2s3emlt7vcy18z9zpgei
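The core idea can be sketched as follows. This is a hypothetical toy example, not the repo's actual code: leaf embeddings are fixed, and we jointly learn (1) candidate ancestor positions and (2) a relaxed (soft) parent assignment by gradient descent. The names cost, traversal, and forcing below are illustrative stand-ins for tree_traversal_cost and tree_forcing_loss.

```python
import jax
import jax.numpy as jnp

# Fixed leaf embeddings: two clusters on a line.
leaves = jnp.array([[0.0, 0.0], [2.0, 0.0], [8.0, 0.0], [10.0, 0.0]])

def cost(params):
    ancestors, logits = params
    # Soft assignment of each leaf to a candidate ancestor
    # (a differentiable relaxation of the tree topology).
    weights = jax.nn.softmax(logits, axis=1)          # (n_leaves, n_ancestors)
    # Squared edge length from each leaf to each candidate ancestor.
    d2 = jnp.sum((leaves[:, None, :] - ancestors[None, :, :]) ** 2, axis=-1)
    traversal = jnp.sum(weights * d2)                 # soft traversal cost
    # Entropy penalty nudges the soft assignment toward a hard tree
    # (a stand-in for the repo's tree-forcing term).
    forcing = -jnp.sum(weights * jnp.log(weights + 1e-9))
    return traversal + 0.1 * forcing

grad_fn = jax.grad(cost)
params = (jnp.array([[1.0, 1.0], [9.0, 1.0]]),        # two learnable ancestors
          jnp.zeros((4, 2)))                          # uniform soft topology

# Plain SGD; the repo installs optax, which would replace this loop.
for _ in range(300):
    g = grad_fn(params)
    params = jax.tree_util.tree_map(lambda p, gi: p - 0.05 * gi, params, g)
```

After optimization, each leaf's assignment weights concentrate on the nearer ancestor, so the relaxed topology approaches a discrete tree while the ancestor positions settle near their assigned leaves.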
conda create -n jax_env python=3.8
source activate jax_env
conda install -c anaconda ipykernel -y
python -m ipykernel install --user --name=jax_env
# install remaining packages through pip
pip install -r requirements.txt
Installs the JAX wheel compatible with CUDA >= 11.4 and cuDNN >= 8.2:
pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
[[ ! -e /content/sample_data ]] && exit ## run this cell only in colab
pip install optax -qqq
pip install networkx==2.5 -qqq
pip install netgraph
git clone https://github.com/ramithuh/differentiable-trees --quiet
