
differentiable-trees

Given leaf nodes, generate the most cost-effective (1) tree topology and (2) ancestors using gradient descent.

The tree search is guided by two objectives enforced during optimization: `tree_forcing_loss` and `tree_traversal_cost`.

  • Therefore, this approach does not need training data.
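The core idea can be sketched with a toy example. This is an illustrative sketch, not the repository's actual code: the parent choice for each leaf is relaxed into a softmax over candidate ancestors, so a traversal-style cost becomes differentiable and plain gradient descent can jointly update the ancestor embeddings and the topology logits. All names here (`traversal_cost`, `logits`, the loss form) are assumptions for illustration.

```python
# Toy sketch: differentiable tree fitting via a soft parent assignment.
# Not the repository's implementation -- a minimal NumPy analogue of
# minimizing a traversal-cost-style objective with gradient descent.
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def traversal_cost(leaves, ancestors, logits):
    # Soft assignment of each leaf to a candidate ancestor (its "parent").
    p = softmax(logits)                                    # (n_leaves, n_anc)
    d = ((leaves[:, None, :] - ancestors[None, :, :]) ** 2).sum(-1)
    return (p * d).sum(), p, d

def step(leaves, ancestors, logits, lr=0.05):
    cost, p, d = traversal_cost(leaves, ancestors, logits)
    # Gradient w.r.t. ancestors: sum_i p_ij * 2 * (a_j - x_i)
    diff = ancestors[None, :, :] - leaves[:, None, :]      # (n, m, dim)
    g_anc = (2 * p[:, :, None] * diff).sum(axis=0)
    # Gradient w.r.t. logits via the softmax Jacobian:
    # d cost_i / d logit_ik = p_ik * (d_ik - sum_j p_ij * d_ij)
    g_log = p * (d - (p * d).sum(axis=1, keepdims=True))
    return ancestors - lr * g_anc, logits - lr * g_log, cost

rng = np.random.default_rng(0)
leaves = rng.normal(size=(6, 2))        # fixed leaf embeddings
ancestors = rng.normal(size=(2, 2))     # learnable ancestor embeddings
logits = np.zeros((6, 2))               # learnable soft topology

costs = []
for _ in range(300):
    ancestors, logits, c = step(leaves, ancestors, logits)
    costs.append(c)
```

The repository itself works in JAX with optax optimizers; this NumPy version only makes explicit how gradients can flow through both the ancestor positions and the (relaxed) topology at once.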

Current limitations

  • Convergence guarantees remain to be explored.

Example experiment : https://wandb.ai/ramithx1/Differentiable-Trees/reports/Differentiable-Trees--VmlldzozMzg4MTE2?accessToken=cyle7be7zv37oxb05bdbsbqu1n3lyenrm69ly9y8swyf2s3emlt7vcy18z9zpgei


Environment Setup

conda create -n jax_env python=3.8
source activate jax_env
conda install -c anaconda ipykernel -y
python -m ipykernel install --user --name=jax_env


# install remaining packages through pip
pip install -r requirements.txt

Install JAX (fasrc cluster)

Installs the wheel compatible with CUDA >= 11.4 and cuDNN >= 8.2:

pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Quick setup on Colab

[[ ! -e /content/sample_data ]] && exit  # run this cell only in Colab
 
pip install optax -qqq
pip install networkx==2.5 -qqq
pip install netgraph
git clone https://github.com/ramithuh/differentiable-trees --quiet
