Make sure to install the following packages before running the experiments:
- Python 3.x
- pytorch
- torchmetrics
- traj-dist
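For convenience, the dependencies above could be captured in a `requirements.txt` along these lines (a sketch only; the exact PyPI names and versions are assumptions, and `pytorch` is published on PyPI as `torch`):

```text
torch
torchmetrics
traj-dist
```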
We directly use the dataset processed by SIMformer:
Download the desired data and place it in the ../data/dataset/ directory. You can also modify the dataset path in the config.yaml file.
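With the default path, the resulting layout would look roughly like this (the entries under `dataset/` are hypothetical placeholders for whichever files you download):

```text
../data/
└── dataset/
    ├── <downloaded data files>
    └── ...
```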
As an example, take the experiments on the Porto dataset under the DTW measure:
- Train the expert trajectories:

  ```shell
  python buffer.py --gpu_id=0 --seed=24 --dataset=porto --target_measure=dtw --model=SIMformer --mom=0.9 --buffer_path="buffer" --train_epochs=50 --num_experts=1 --track_aug=mask --nce_factor=0.01 --cl_T=0.05
  ```
- Distill the dataset:

  ```shell
  python distill.py --gpu_id=0 --seed=24 --dataset=porto --target_measure=dtw --model=SIMformer --mom=0.9 --buffer_path="./buffers" --train_epochs=50 --num_experts=100 --track_aug=mask --cl_T=0.05 --Iteration=600 --max_files=None --use_syn_minibatch=25 --eval_it=50 --num_eval=1 --max_start_epoch=45 --syn_steps=10 --lr_lr=1e-5 --distill_output_path="./logged_files" --syn_len_max=200 --epoch_eval_train=100 --syn_len=learnable --ste=E_len_loc_glo --syn_len_para_ini="{'method': 'dis_laplace', 'scale': 0.03}" --lr_len=0.5 --syn_ini="randomly_select" --syn_num=300 --len_norm_factor=10000 --nce_factor=0.01 --distill_aug=mask --syn_eval_aug=null --cl_T_syn_eval=1.0 --cl_T_distill=0.1
  ```
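Note that `--syn_len_para_ini` takes a Python-dict literal as a string. A minimal sketch of how such a flag value can be safely turned into a dict (the key names `method` and `scale` come from the command above; the parsing approach itself is an assumption, not necessarily what `distill.py` does):

```python
import ast

# Value as it arrives from the command line.
raw = "{'method': 'dis_laplace', 'scale': 0.03}"

# ast.literal_eval only evaluates Python literals (dicts, lists,
# strings, numbers), so it is safe for untrusted CLI input,
# unlike eval().
params = ast.literal_eval(raw)

print(params["method"])  # dis_laplace
print(params["scale"])   # 0.03
```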
This project is built upon and modified from several existing open-source repositories:
- https://github.com/GeorgeCazenavette/mtt-distillation
- https://github.com/SUSTC-ChuangYANG/SIMformer/
- https://github.com/changyanchuan/TrajCL
- https://github.com/facebookresearch/moco
Many thanks to the authors for their contributions.