[NeurIPS'24] PyTorch implementation of "MGF: Mixed Gaussian Flow for Diverse Trajectory Prediction".

mulplue/MGF

MGF: Mixed Gaussian Flow for Diverse Trajectory Prediction

[NeurIPS'24] PyTorch implementation of "MGF: Mixed Gaussian Flow for Diverse Trajectory Prediction" (https://arxiv.org/abs/2402.12238).

[Figure: mgf_arch]

Installation

Environment

conda create -n mgf python=3.9
conda activate mgf
pip install -r requirements.txt

Data

python src/data/TP/process_data.py

Evaluation

Run Inference

python src/test.py --scene {scene_name}
  • scene_name is one of: eth, hotel, univ, zara1, zara2, sdd
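
To evaluate every scene in one go, the command above can be wrapped in a small driver. This is a minimal sketch, assuming it is run from the repository root and that `src/test.py` needs only the `--scene` flag shown above. The names `SCENES`, `build_test_command`, and `run_all_scenes` are illustrative and not part of the repository.

```python
import subprocess
import sys

# Scene names as listed in this README.
SCENES = ["eth", "hotel", "univ", "zara1", "zara2", "sdd"]


def build_test_command(scene):
    """Return the argv list for evaluating one scene (hypothetical helper)."""
    if scene not in SCENES:
        raise ValueError(f"unknown scene: {scene!r}")
    # sys.executable keeps the run inside the currently active interpreter,
    # e.g. the one from the `mgf` conda environment.
    return [sys.executable, "src/test.py", "--scene", scene]


def run_all_scenes(scenes=SCENES):
    """Run inference on each scene in turn, stopping at the first failure."""
    for scene in scenes:
        subprocess.run(build_test_command(scene), check=True)
```

Calling `run_all_scenes()` launches one evaluation per scene. Because of `check=True`, a non-zero exit code from any run raises `subprocess.CalledProcessError` instead of silently continuing.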

Expected Results

Metric  ETH   HOTEL  UNIV  ZARA1  ZARA2  SDD
ADE     0.40  0.13   0.21  0.17   0.14   7.74
FDE     0.59  0.20   0.39  0.29   0.24   12.17
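
For readers new to these metrics: ADE (average displacement error) is the mean Euclidean distance between predicted and ground-truth positions over all predicted time steps, and FDE (final displacement error) is that distance at the last time step. The sketch below gives the standard definitions in plain Python for a single trajectory, plus the commonly used best-of-K variant. Whether the numbers above use best-of-K, and with which K, is not stated here, so treat that part as an assumption. The function names are illustrative and do not come from the repository.

```python
import math


def ade(pred, gt):
    """Mean Euclidean distance over all time steps of one trajectory."""
    return sum(math.dist(p, g) for p, g in zip(pred, gt)) / len(gt)


def fde(pred, gt):
    """Euclidean distance between the final predicted and true positions."""
    return math.dist(pred[-1], gt[-1])


def best_of_k(samples, gt):
    """Assumed protocol: lowest ADE and lowest FDE among K sampled futures."""
    return min(ade(s, gt) for s in samples), min(fde(s, gt) for s in samples)


# Toy example with 2D positions (not real data).
gt = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
pred = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
print(ade(pred, gt), fde(pred, gt))
```

In the toy example the per-step errors are 0, 1, and 2, so the average is 1 and the final error is 2. Note that the minimum ADE and minimum FDE in `best_of_k` are taken independently and may come from different samples.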

Training

python src/train.py --model_name {model_name} --scene {scene_name} --gpu {gpu_id}
  • scene_name is one of: eth, hotel, univ, zara1, zara2, sdd
  • Note: The actual parameters are read from the config file, not from the command-line arguments.
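
The note above suggests that the command-line arguments mainly select what to run, while the training parameters themselves live in the config file. A minimal sketch of that pattern follows, assuming a dict-like config. The keys (`batch_size`, `lr`), the `--gpu` default, the `merge_settings` helper, and the precedence logic are all hypothetical illustrations, not the repository's actual code or config schema.

```python
import argparse


def parse_args(argv=None):
    """Parse only the three flags shown in the training command above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", required=True)
    parser.add_argument("--scene", required=True)
    parser.add_argument("--gpu", default="0")  # default value is an assumption
    return parser.parse_args(argv)


def merge_settings(args, config):
    """Combine CLI args with config values; the config wins on shared keys."""
    settings = vars(args).copy()
    settings.update(config)
    return settings


# Hypothetical config contents, for illustration only.
config = {"batch_size": 128, "lr": 1e-3}
args = parse_args(["--model_name", "demo", "--scene", "eth"])
settings = merge_settings(args, config)
```

Under this pattern, changing a hyperparameter means editing the config file; passing a different value on the command line would have no effect on keys the config defines.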

Acknowledgement

  • The code framework is based on FlowChain.
