Most Important Person-guided Dual-branch Cross-Patch Attention for Group Affect Recognition (ICCV 2023)
This repository is the PyTorch implementation of "Most Important Person-guided Dual-branch Cross-Patch Attention for Group Affect Recognition." Please feel free to reach out for any questions or discussions.
If you use the code and models from this repo, please cite our work. Thanks!
@inproceedings{xie2023most,
title={Most Important Person-Guided Dual-Branch Cross-Patch Attention for Group Affect Recognition},
author={Xie, Hongxia and Lee, Ming-Xian and Chen, Tzu-Jui and Chen, Hung-Jen and Liu, Hou-I and Shuai, Hong-Han and Cheng, Wen-Huang},
booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision},
pages={20598--20608},
year={2023}
}
To install requirements:
pip install -r requirements.txt
With conda:
conda create -n DCAT python=3.8
conda activate DCAT
conda install pytorch=1.7.1 torchvision cudatoolkit=11.0 -c pytorch -c nvidia
pip install -r requirements.txt
pip install mtcnn
pip install tensorflow
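Before moving on, a quick check can catch a wrong interpreter or a missing package in the DCAT environment. The sketch below is not part of the repository; it assumes the distribution names torch, torchvision, mtcnn and tensorflow installed above, and it only enforces the one version pin the conda command states (torch 1.7.1). It uses importlib.metadata, which needs Python 3.8, matching the environment created above.

```python
"""Sanity-check the DCAT environment (a sketch; not a script from this repository)."""
from importlib import metadata  # available from Python 3.8, as used for the DCAT env

# Distributions installed by the steps above. Only torch is pinned in this README;
# None means "must be present, any version".
EXPECTED = {
    "torch": "1.7.1",
    "torchvision": None,
    "mtcnn": None,
    "tensorflow": None,
}


def installed_version(dist):
    """Return the installed version string of `dist`, or None if it is missing."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_environment(expected=EXPECTED):
    """Return a list of human-readable problems (empty if everything looks fine)."""
    problems = []
    for dist, pinned in expected.items():
        found = installed_version(dist)
        if found is None:
            problems.append(f"{dist} is not installed")
        # pip wheels may report e.g. "1.7.1+cu110", so drop the local suffix first.
        elif pinned is not None and found.split("+")[0] != pinned:
            problems.append(f"{dist} {found} found, expected {pinned}")
    return problems


if __name__ == "__main__":
    for line in check_environment() or ["environment looks OK"]:
        print(line)
```

Run it inside the activated DCAT environment; an empty problem list means every listed distribution was found.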
cd GetMIP_POINT/
Edit lines 27-29 of GAF_Face_filter.py to fill in the original image path, the directory in which to save the preprocessed images, and the path at which to save the index of the preprocessed images.
To get the preprocessed images, run:
python GAF_Face_filter.py
Create a new conda environment with the requirements below:
(1) PyTorch 1.0.0
(2) Python 3.6+
(3) Python packages: numpy, scipy, pyyaml/yaml, h5py, opencv, PIL
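Because this second environment uses a much older PyTorch, it can help to confirm the listed packages are importable there before running POINT. A minimal sketch (not from this repository or from POINT) that maps each package above to the module name it is imported under: pyyaml is imported as yaml, opencv as cv2, and PIL is provided by Pillow. It checks presence only, not versions.

```python
"""Check the POINT environment requirements listed above (a sketch, not a repo script)."""
import importlib.util
import sys

# Package name as listed above -> module name used at import time.
REQUIRED_MODULES = {
    "Pytorch": "torch",
    "numpy": "numpy",
    "scipy": "scipy",
    "pyyaml": "yaml",
    "h5py": "h5py",
    "opencv": "cv2",
    "PIL": "PIL",
}


def missing_modules(required=REQUIRED_MODULES):
    """Return the listed package names whose top-level module cannot be found."""
    # find_spec locates a module without importing it; None means "not installed".
    return [name for name, module in required.items()
            if importlib.util.find_spec(module) is None]


def python_ok(minimum=(3, 6)):
    """True if the running interpreter meets the Python 3.6+ requirement."""
    return sys.version_info[:2] >= minimum


if __name__ == "__main__":
    print("Python 3.6+:", python_ok())
    print("missing:", missing_modules() or "none")
```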
Download the pretrained model of POINT from: https://github.com/harlanhong/POINT
To get the MIP results, run:
python POINT_new_dataset_test.py \
--index_name [index path set on line 29 of GAF_Face_filter.py] \
--dataset_path [preprocessed image directory set on line 28 of GAF_Face_filter.py] \
--result_dir [...] \
--model [path to pretrained model of POINT] --h 4 --N 2
e.g.
python POINT_new_dataset_test.py \
--index_name ./data/GAF3_process/Validation/Neg_index.npy \
--dataset_path ./data/GAF3_process/Validation/Negative \
--result_dir ./resultFile/GAF3/Validation/Negative/ \
--model ./models/MS_h4_N2.pkl --h 4 --N 2
We provide models trained on GAF 3.0 and GroupEmoW. You can find models here.
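The POINT command above handles a single split and class folder, so it has to be repeated for every one of them. The sketch below is a hypothetical helper, not part of the repository, that generalizes the Validation/Negative example. The Positive/Neutral class names and the Pos_/Neu_ index prefixes are assumptions extrapolated from Neg_index.npy; adjust them to whatever GAF_Face_filter.py actually writes. It prints the commands by default and only executes them with dry_run=False.

```python
"""Run POINT once per split and class (a sketch; layout extrapolated from the example)."""
import shlex
import subprocess

# ASSUMPTION: every split/class follows the Validation/Negative example, i.e.
# <data_root>/<Split>/<Class> holds the images and <data_root>/<Split>/<Abb>_index.npy
# the index, where <Abb> is the first three letters of the class name (Neg, Pos, Neu).
SPLITS = ["Train", "Validation"]
CLASSES = ["Positive", "Neutral", "Negative"]


def point_command(split, cls, data_root="./data/GAF3_process",
                  result_root="./resultFile/GAF3", model="./models/MS_h4_N2.pkl"):
    """Build the argument list for one POINT_new_dataset_test.py run."""
    return [
        "python", "POINT_new_dataset_test.py",
        "--index_name", f"{data_root}/{split}/{cls[:3]}_index.npy",
        "--dataset_path", f"{data_root}/{split}/{cls}",
        "--result_dir", f"{result_root}/{split}/{cls}/",
        "--model", model, "--h", "4", "--N", "2",
    ]


def run_all(dry_run=True):
    """Print (dry run) or execute the POINT command for every split/class pair."""
    commands = [point_command(s, c) for s in SPLITS for c in CLASSES]
    for cmd in commands:
        # shlex.quote keeps this compatible with the Python 3.6 POINT environment.
        print(" ".join(shlex.quote(arg) for arg in cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)
    return commands


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing arguments as a list to subprocess.run avoids shell quoting issues with paths that contain spaces.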
To train DCAT on GAF 3.0 on a single node with 1 GPU for 300 epochs, run:
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12346 --use_env main.py --batch-size 64 \
--data-path /path/to/GAF_3.0/ \
--mip_root_train /path/to/GAF_3.0/train/ \
--mip_cropped_root_train /path/to/GAF3_process/Train/ \
--mip_txt_root_train /path/to/resultFile/GAF3/Train/ \
--mip_root_val /path/to/GAF_3.0/val/ \
--mip_cropped_root_val /path/to/GAF3_process/Validation/ \
--mip_txt_root_val /path/to/resultFile/GAF3/Validation/ \
--output_dir ./checkpoint --data-set GAF --mip_select --mip_keep_ratio 0.5 --full_select --full_keep_ratio 0.5
To evaluate a pretrained DCAT model, run:
python -m torch.distributed.launch --nproc_per_node=1 --master_port 12346 --use_env main.py --batch-size 64 \
--data-path /path/to/GAF_3.0/ \
--mip_root_train /path/to/GAF_3.0/train/ \
--mip_cropped_root_train /path/to/GAF3_process/Train/ \
--mip_txt_root_train /path/to/resultFile/GAF3/Train/ \
--mip_root_val /path/to/GAF_3.0/val/ \
--mip_cropped_root_val /path/to/GAF3_process/Validation/ \
--mip_txt_root_val /path/to/resultFile/GAF3/Validation/ \
--output_dir ./checkpoint --eval --data-set GAF --mip_select --mip_keep_ratio 0.5 --full_select --full_keep_ratio 0.5 \
--load_pretrained /path/to/pretrained_models/GAF3/model_best.pth
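The training and evaluation commands differ only in --eval and --load_pretrained. A sketch of a hypothetical helper (not a script in this repository) that builds both from one set of paths, reusing the placeholder values and flags from the commands above. Keeping each flag and its value as separate argv tokens makes slips such as a missing space between a flag and its path impossible.

```python
"""Assemble the DCAT training/evaluation launch commands (a sketch, not a repo script)."""
import shlex

# Placeholder paths copied from the commands above; replace them with your own.
PATHS = {
    "--data-path": "/path/to/GAF_3.0/",
    "--mip_root_train": "/path/to/GAF_3.0/train/",
    "--mip_cropped_root_train": "/path/to/GAF3_process/Train/",
    "--mip_txt_root_train": "/path/to/resultFile/GAF3/Train/",
    "--mip_root_val": "/path/to/GAF_3.0/val/",
    "--mip_cropped_root_val": "/path/to/GAF3_process/Validation/",
    "--mip_txt_root_val": "/path/to/resultFile/GAF3/Validation/",
}


def dcat_command(paths=PATHS, pretrained=None, nproc=1, port=12346, batch_size=64):
    """Return argv for training, or for evaluation when `pretrained` is given."""
    cmd = ["python", "-m", "torch.distributed.launch",
           f"--nproc_per_node={nproc}", "--master_port", str(port), "--use_env",
           "main.py", "--batch-size", str(batch_size)]
    for flag, value in paths.items():
        cmd += [flag, value]  # flag and value stay separate tokens
    cmd += ["--output_dir", "./checkpoint"]
    if pretrained is not None:
        cmd.append("--eval")
    cmd += ["--data-set", "GAF", "--mip_select", "--mip_keep_ratio", "0.5",
            "--full_select", "--full_keep_ratio", "0.5"]
    if pretrained is not None:
        cmd += ["--load_pretrained", pretrained]
    return cmd


if __name__ == "__main__":
    print(" ".join(shlex.quote(a) for a in dcat_command()))
    print(" ".join(shlex.quote(a) for a in dcat_command(
        pretrained="/path/to/pretrained_models/GAF3/model_best.pth")))
```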