This is the code for the paper "Learning to Prune in Training via Dynamic Channel Propagation", accepted at ICPR-2020. In this paper, we propose a novel network training mechanism called "dynamic channel propagation" that prunes deep neural networks during training. Note that this is a research project and is therefore not guaranteed to be stable. Please contact us if you find anything incorrect or unexpected. We share the code on the condition that any reproduction of it, in full or in part, cites the paper.
This repository provides the source code of our scheme. It is built on the PyTorch deep-learning framework and draws heavily on PyTorch's official examples. Requirements:
- Python >= 3.5
- CUDA >= 10.0
- PyTorch (torch) >= 1.3.0 and torchvision
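The helper below is a minimal sketch, not part of this repository, for checking an environment against the requirements listed above before running the scripts. The names `parse_version`, `at_least` and `check_environment` are hypothetical. The check relies only on `sys.version_info`, `torch.__version__` and `torch.version.cuda`, which is `None` for CPU-only PyTorch builds. Note that it reports the CUDA version PyTorch was built with, which is not necessarily the version of the installed toolkit or driver.

```python
import re
import sys


def parse_version(text):
    """Turn a version string such as '1.13.1+cu117' into a tuple like (1, 13, 1)."""
    parts = re.findall(r"\d+", text.split("+")[0])
    return tuple(int(p) for p in parts[:3])


def at_least(found, required):
    """True if the version string `found` is >= the `required` tuple."""
    return parse_version(found) >= required


def check_environment():
    """Return a list of human-readable problems; empty if every requirement is met."""
    problems = []
    if sys.version_info < (3, 5):
        problems.append("Python >= 3.5 required, found %d.%d" % sys.version_info[:2])
    try:
        import torch
        import torchvision  # noqa: F401  (only its presence is checked)
    except ImportError as exc:
        problems.append("missing package: %s" % exc.name)
        return problems
    if not at_least(torch.__version__, (1, 3, 0)):
        problems.append("torch >= 1.3.0 required, found " + torch.__version__)
    cuda = torch.version.cuda  # CUDA version torch was built with; None on CPU-only builds
    if cuda is None or not at_least(cuda, (10, 0)):
        problems.append("CUDA >= 10.0 required, found %s" % cuda)
    return problems


if __name__ == "__main__":
    issues = check_environment()
    print("\n".join(issues) if issues else "All requirements satisfied.")
```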
For CIFAR-10, you can download the dataset directly through the torchvision API:

```python
from torchvision.datasets.cifar import CIFAR10 as dataset

# training set
train_set = dataset(root='../data', train=True, download=True)
# test set
test_set = dataset(root='../data', train=False, download=True)
```

As for ILSVRC-2012 (ImageNet), you have to download it from the URL, unpack it, and move the validation images into per-class subfolders with a shell script.
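The official PyTorch ImageNet example loads data with `torchvision.datasets.ImageFolder`, which expects one subfolder per class, whereas the ILSVRC-2012 validation images are distributed in a single folder. The sketch below is one possible Python alternative to the shell step, not the authors' script. It assumes a hypothetical plain-text mapping file with one `filename class_id` pair per line (for example, a WordNet ID as the class). The mapping file is not provided here, and the function names are illustrative.

```python
import os
import shutil


def load_val_mapping(mapping_file):
    """Read 'filename class_id' pairs, one per line (assumed file format)."""
    mapping = {}
    with open(mapping_file) as f:
        for line in f:
            fields = line.split()
            if len(fields) >= 2:
                mapping[fields[0]] = fields[1]
    return mapping


def sort_val_images(val_dir, mapping):
    """Move each image in val_dir into val_dir/<class_id>/ and return how many moved."""
    moved = 0
    for filename, class_id in mapping.items():
        src = os.path.join(val_dir, filename)
        if not os.path.isfile(src):
            continue  # already moved on a previous run, or missing
        dst_dir = os.path.join(val_dir, class_id)
        os.makedirs(dst_dir, exist_ok=True)
        shutil.move(src, os.path.join(dst_dir, filename))
        moved += 1
    return moved
```

For example, `sort_val_images('/path/to/val', load_val_mapping('val_labels.txt'))`, where both paths are placeholders. Files that are already in place are skipped, so the step can be rerun safely.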
First, enter the root directory of the project and create a folder to store the results:

```shell
cd [root directory of the project]
mkdir model
```
To train and prune on CIFAR-10, run:

```shell
python3 main.py -architecture [Vgg or ResNet] -decay [initial value of decay factor] -pr [global pruning rate of channels] -data_dir [path to the dataset]
```
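The `-pr` flag sets a global pruning rate of channels, presumably a fraction counted over the whole network rather than per layer. The pure-Python sketch below illustrates only that idea, under the assumption that every channel already has a scalar importance score. How the scores are obtained, and the role of `-decay` in dynamic channel propagation, are described in the paper and are not modeled here. `global_channel_masks` is a hypothetical name, not a function from `main.py`.

```python
def global_channel_masks(scores, pruning_rate):
    """Return per-layer keep-masks that drop the lowest-scoring channels network-wide.

    scores: list of lists; scores[l][c] is an importance score for output
            channel c of layer l (how it is computed is not specified here).
    pruning_rate: fraction of all channels to remove, in [0, 1).
    """
    if not 0.0 <= pruning_rate < 1.0:
        raise ValueError("pruning_rate must be in [0, 1)")
    # Flatten to (score, layer, channel) triples and rank them globally, lowest first.
    ranked = sorted(
        (s, l, c) for l, layer in enumerate(scores) for c, s in enumerate(layer)
    )
    n_prune = int(pruning_rate * len(ranked))  # rounded down
    masks = [[True] * len(layer) for layer in scores]
    for _, l, c in ranked[:n_prune]:
        masks[l][c] = False
    # Never empty a layer completely: restore its highest-scoring channel.
    for l, layer in enumerate(scores):
        if layer and not any(masks[l]):
            best = max(range(len(layer)), key=lambda c: layer[c])
            masks[l][best] = True
    return masks


scores = [[0.9, 0.1, 0.5, 0.3], [0.05, 0.8, 0.2, 0.7]]
print(global_channel_masks(scores, 0.5))
# → [[True, False, True, False], [False, True, False, True]]
```

Ranking globally lets layers end up with different widths. The guard that keeps at least one channel per layer is a design choice of this sketch, so the number of channels actually removed can fall slightly below the requested rate.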
To prune ResNet on ILSVRC-2012, run:

```shell
python3 main2.py -pr [global pruning rate of channels] -data_dir [path to the dataset]
```
The main results are as follows:

| Dataset | Architecture | Top-1 accuracy | Top-5 accuracy | FLOPs pruned |
|---|---|---|---|---|
| CIFAR-10 | VGG-16 | 93.50% | - | 73.3% |
| CIFAR-10 | ResNet-32 | 92.60% | - | 50.2% |
| ImageNet | ResNet-50 | 74.25% | 92.05% | 41.1% |
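To see why the "FLOPs pruned" percentages can exceed the fraction of channels removed, the sketch below counts multiply-accumulates (MACs) for a plain stack of convolutions. A k×k convolution costs H_out · W_out · k² · C_in · C_out MACs, and pruning a layer's output channels also shrinks the next layer's input. This is a simplified illustration with hypothetical function names. It ignores fully connected layers, batch normalization, pooling and, for ResNet, the constraints residual connections place on which channels can be pruned, so it is not meant to reproduce the figures in the table.

```python
def conv_macs(h_out, w_out, k, c_in, c_out):
    """Multiply-accumulate count of one k x k convolution (bias ignored)."""
    return h_out * w_out * k * k * c_in * c_out


def chain_macs(layers, widths, in_channels=3):
    """Total MACs of a plain stack of convolutions.

    layers: list of (h_out, w_out, k) per layer; widths: output channels per layer.
    Each layer's input width is the previous layer's output width.
    """
    total, c_in = 0, in_channels
    for (h, w, k), c_out in zip(layers, widths):
        total += conv_macs(h, w, k, c_in, c_out)
        c_in = c_out
    return total


def flops_pruned(layers, widths, kept):
    """Fraction of convolution MACs removed when layer i keeps kept[i] channels."""
    before = chain_macs(layers, widths)
    after = chain_macs(layers, kept)
    return 1.0 - after / before


layers = [(32, 32, 3), (32, 32, 3)]  # two 3x3 convolutions on 32x32 feature maps
print(round(flops_pruned(layers, [64, 64], [32, 32]), 3))  # → 0.739
```

Keeping half the channels in both layers removes about 74% of the MACs here, because the second layer loses both input and output channels, while the first layer's 3 RGB input channels stay fixed.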
Please refer to our paper for more details.