This repository contains the implementation of the paper "Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models".
APTP: We prune a text-to-image diffusion model like Stable Diffusion (left) into a mixture of efficient experts (right) in a prompt-based manner. Our prompt router routes distinct types of prompts to different experts, allowing experts' architectures to be separately specialized by removing layers or channels.
APTP pruning scheme. We train the prompt router and the set of architecture codes to prune a T2I diffusion model into a mixture of experts. The prompt router consists of three modules. We use a Sentence Transformer as the prompt encoder to encode the input prompt into a representation z. The architecture predictor then transforms z into an architecture embedding e with the same dimensionality as the architecture codes. Finally, the router maps the embedding e to an architecture code a(i), using optimal transport to evenly distribute the prompts in a training batch among the architecture codes. The architecture code a(i) = (u(i), v(i)) determines how the model's width and depth are pruned. We train the prompt router's parameters and the architecture codes end-to-end using the denoising objective of the pruned model LDDPM, a distillation loss between the pruned and original models Ldistill, the average resource usage of the samples in the batch R, and a contrastive objective Lcont that encourages the embeddings e to preserve the semantic similarity of the representations z.
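For intuition, here is a minimal PyTorch-style sketch of the routing path described above; the module sizes, the Sinkhorn-style balancing loop, and all names are illustrative assumptions rather than the repository's actual implementation:

import torch
import torch.nn as nn

class PromptRouter(nn.Module):
    # Sketch: architecture predictor + router over learnable architecture codes.
    def __init__(self, enc_dim=768, arch_dim=128, num_experts=8):
        super().__init__()
        # z -> e: map the frozen sentence embedding into the architecture-code space
        self.arch_predictor = nn.Sequential(
            nn.Linear(enc_dim, 256), nn.ReLU(), nn.Linear(256, arch_dim)
        )
        # learnable architecture codes a(i); each one encodes width/depth pruning decisions
        self.arch_codes = nn.Parameter(torch.randn(num_experts, arch_dim))

    def forward(self, z, n_iters=3, eps=0.05):
        e = self.arch_predictor(z)                 # (B, arch_dim) architecture embeddings
        sim = e @ self.arch_codes.t()              # (B, num_experts) similarities
        q = torch.exp(sim / eps)
        for _ in range(n_iters):                   # Sinkhorn-style balancing so the batch
            q = q / q.sum(dim=0, keepdim=True)     # spreads evenly over the codes
            q = q / q.sum(dim=1, keepdim=True)
        return e, q.argmax(dim=1)                  # embedding + assigned expert per prompt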
Follow these steps to set up the project:
Use the provided env.yaml file:
conda env create -f env.yaml
Activate the environment:
conda activate pdm
From the project root directory, install the dependencies:
pip install -e .
Prepare the data for training as described in the paper. You can also adapt APTP to your own dataset with minor code modifications.
Follow the instructions here to download Conceptual Captions. Place the data in a directory of your choice, maintaining the structure:
conceptual_captions
├── Train_GCC-training.tsv
├── Val_GCC-1.1.0-Validation.tsv
├── training
│ ├── 10007_560483514
│ └── ...
└── validation
├── 1852290_2006010568
└── ...
Some URLs in the Conceptual Captions dataset are no longer valid, so the download will produce corrupt files that cannot be opened. Removing these files beforehand makes training more efficient.
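As a rough illustration, a cleanup script along the following lines can drop unreadable files (a sketch, not part of the repository; the directory path and the assumption that PIL can identify the extension-less files are mine):

import os
from PIL import Image

data_root = "conceptual_captions/training"  # adjust to your layout
for name in os.listdir(data_root):
    path = os.path.join(data_root, name)
    if not os.path.isfile(path):
        continue
    try:
        with Image.open(path) as img:
            img.verify()  # cheap integrity check; raises on truncated/corrupt files
    except Exception:
        print(f"removing corrupt file: {path}")
        os.remove(path)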
Download 2014 train and 2014 val images from the COCO website. Place them in your chosen directory.
Download the 2014 train/val annotations and place them in the same directory as the images. Your directory should look like this:
coco
├── annotations
│ ├── captions_train2014.json
│ ├── captions_val2014.json
│ └── ...
└── images
├── train2014
│ ├── COCO_train2014_000000000009.jpg
│ └── ...
└── val2014
├── COCO_val2014_000000000042.jpg
└── ...
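As a quick sanity check that the annotations are in place (an illustrative snippet assuming the layout above):

import json

with open("coco/annotations/captions_val2014.json") as f:
    ann = json.load(f)
print(len(ann["images"]), "val2014 images,", len(ann["annotations"]), "captions")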
Training is done in two stages: pruning the pretrained T2I model (Stable Diffusion 2.1 in this case) and fine-tuning each expert on the prompts assigned to it. Configuration files for both Conceptual Captions and MS-COCO are provided in the configs directory. You can use these configuration files to run the pruning process. Sample multi-node SLURM and PBS scripts can be found in cluster scripts.
You can use the following command to run pruning:
accelerate launch scripts/aptp/prune.py \
--base_config_path path/to/configs/pruning/file.yaml \
--cache_dir /path/to/.cache/huggingface/ \
--wandb_run_name WANDB_PRUNING_RUN_NAME
This creates a checkpoint directory named "wandb_run_name" in the logging directory specified in the config file.
The pruning stage produces a checkpoint directory, pruning_checkpoint_dir. Before fine-tuning, each prompt in the dataset needs to be assigned to an expert; you can use the following command to run this filtering process:
accelerate launch scripts/aptp/filter_dataset.py \
--pruning_ckpt_dir path/to/pruning_checkpoint_dir \
--base_config_path path/to/configs/filtering/dataset.yaml \
--cache_dir /path/to/.cache/huggingface/
This creates two files named {DATASET}_train_mapped_indices.pt and {DATASET}_validation_mapped_indices.pt in the pruning_checkpoint_dir directory. These files contain the expert id for each prompt in the training and validation sets.
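These mapping files can be inspected directly; for instance, to count how many validation prompts were routed to each expert (assuming each file stores a tensor or list of expert ids, and replacing DATASET with your dataset's name):

import torch
from collections import Counter

mapped = torch.load("path/to/pruning_checkpoint_dir/DATASET_validation_mapped_indices.pt")
print(Counter(int(i) for i in mapped))  # number of validation prompts per expert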
Fine-tune an expert on the prompts assigned to it:
accelerate launch scripts/aptp/finetune.py \
--pruning_ckpt_dir path/to/pruning_checkpoint_dir \
--expert_id INDEX \
--base_config_path path/to/configs/finetuning/dataset.yaml \
--cache_dir /path/to/.cache/huggingface/ \
--wandb_run_name WANDB_FINETUNING_RUN_NAME
Generate images using the experts:
accelerate launch scripts/metrics/generate_fid_images.py \
--finetuning_ckpt_dir path/to/pruning_checkpoint_dir \
--expert_id INDEX \
--base_config_path path/to/configs/img_generation/dataset.yaml \
--cache_dir /path/to/.cache/huggingface/
The generated images will be saved in the {DATASET}_fid_images directory within the root finetuning checkpoint directory.
To evaluate APTP, we report the FID, CLIP Score, and CMMD.
We use clean-fid to calculate the FID score. The numbers reported in the paper are calculated using the legacy_pytorch mode.
We report FID on the validation set of Conceptual Captions, so we can reuse the validation mapped indices file created in the filtering step. First, resize the reference images to 256x256 using the provided script; it saves the images as numpy arrays in the dataset's root directory.
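For reference, the resizing step amounts to something like this (the paths, flat file layout, interpolation, and output format are assumptions; the provided script is authoritative):

import numpy as np
from pathlib import Path
from PIL import Image

src = Path("conceptual_captions/validation")
arrays = []
for p in sorted(src.iterdir()):
    if not p.is_file():
        continue
    img = Image.open(p).convert("RGB").resize((256, 256), Image.BICUBIC)
    arrays.append(np.asarray(img))
np.save("conceptual_captions/validation_256.npy", np.stack(arrays))  # (N, 256, 256, 3)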
For MS-COCO, we sample 30k images from the 2014 validation set; see the sample-and-resize script. This subset also needs a filtering step to assign its prompts to experts.
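Conceptually, the sampling step boils down to something like the following (the seed and the use of the annotation file are assumptions; the provided script defines the actual subset):

import json
import random

with open("coco/annotations/captions_val2014.json") as f:
    images = json.load(f)["images"]
random.seed(0)                        # illustrative seed only
subset = random.sample(images, 30000)
print(len(subset), subset[0]["file_name"])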
Generate custom statistics for both sets of reference images:
from cleanfid import fid
fid.make_custom_stats(dataset, dataset_path, mode="legacy_pytorch")  # mode can be "clean" too
Now we can calculate the FID score for the generated images using the provided script.
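With clean-fid, scoring the generated images against these custom statistics looks roughly like the following (the generated-image path is a placeholder mirroring the generation step above, and dataset is the name passed to make_custom_stats):

from cleanfid import fid

score = fid.compute_fid(
    "path/to/finetuning_checkpoint_dir/DATASET_fid_images",  # images generated by the experts
    dataset_name=dataset,        # the name used in make_custom_stats above
    mode="legacy_pytorch",
    dataset_split="custom",
)
print(score)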
To calculate the CLIP Score, we use this library. Extract features of the reference images with the clip feature extraction script and compute the score using the clip score script.
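Purely as an illustration of what a CLIP Score computation involves (this uses torchmetrics rather than the library linked above, and the image tensor is a random stand-in):

import torch
from torchmetrics.multimodal.clip_score import CLIPScore

metric = CLIPScore(model_name_or_path="openai/clip-vit-base-patch16")
images = torch.randint(0, 255, (4, 3, 256, 256), dtype=torch.uint8)  # stand-in generated images
prompts = ["a photo of a cat"] * 4                                   # matching prompts
print(metric(images, prompts))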
We use the cmmd-pytorch library to calculate CMMD. Refer to the save_refs.py and compute_cmmd.py scripts for details on reference-set feature extraction and distance calculation.
See here for config files for all baselines mentioned in the paper. The scripts to run these baselines are in this directory.
This project is licensed under the MIT License - see the LICENSE file for details.
If you find this work useful, please consider citing the following paper:
@article{2024aptp,
title={Not All Prompts Are Made Equal: Prompt-based Pruning of Text-to-Image Diffusion Models},
author={Ganjdanesh, Alireza and Shirkavand, Reza and Gao, Shangqian and Huang, Heng},
journal={arXiv preprint arXiv:2406.12042},
year={2024}
}
