This project is an implementation of the paper "Reducing Divergence in Batch Normalization for Domain Adaptation."
Project structure:

```
project/
├── configs/                # Configuration files
│   └── default.yaml        # Default configuration
├── data/                   # Dataset directory
│   ├── office31/           # Office-31 dataset
│   ├── imageclef/          # ImageCLEF-DA dataset
│   ├── officehome/         # Office-Home dataset
│   └── visda2017/          # VisDA-2017 dataset
├── datasets/               # Dataset loading modules
│   ├── office31.py         # Office-31 dataset class
│   ├── imageclef.py        # ImageCLEF-DA dataset class
│   ├── officehome.py       # Office-Home dataset class
│   ├── visda2017.py        # VisDA-2017 dataset class
│   └── transforms.py       # Data transforms
├── models/                 # Model definitions
│   ├── backbone.py         # Feature extractor (ResNet)
│   ├── classifier.py       # Classifier head
│   ├── discriminator.py    # Domain discriminator
│   ├── rbn.py              # Refined Batch Normalization
│   └── cdan.py             # CDAN model
├── utils/                  # Utility functions
│   ├── logger.py           # Logging utilities
│   ├── losses.py           # Loss functions
│   └── utils.py            # Common utility functions
├── output/                 # Output directory
│   ├── checkpoints/        # Model checkpoints
│   └── logs/               # Training logs
├── scripts/                # Run scripts
│   ├── train.sh            # Training script
│   └── test.sh             # Testing script
├── train.py                # Main training program
├── test.py                 # Main testing program
├── requirements.txt        # Project dependencies
└── README.md               # Project documentation
```
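The models/ package separates the backbone, the classifier head, and the domain discriminator, which models/cdan.py presumably combines into the CDAN model. For orientation, the sketch below shows the conditioning step that defines CDAN (Long et al., 2018): instead of the raw feature vector f, the discriminator receives the flattened outer product of f and the classifier's softmax output g, so that alignment can depend on the predicted class. This is a plain-Python, single-sample illustration; the name `multilinear_map` is hypothetical and not taken from cdan.py.

```python
from typing import List


def multilinear_map(features: List[float], probs: List[float]) -> List[float]:
    """Flattened outer product of f and g, the input CDAN feeds to its discriminator.

    features: feature vector f from the backbone (length d)
    probs:    softmax prediction g from the classifier (length C)
    returns:  list of length d * C whose entry i * C + j equals f[i] * g[j]
    """
    return [f_i * g_j for f_i in features for g_j in probs]


# Example: a 2-dim feature and a 2-class prediction give a 4-dim discriminator input.
print(multilinear_map([1.0, 2.0], [0.25, 0.75]))  # [0.25, 0.75, 0.5, 1.5]
```

For a batch in PyTorch the same map is typically computed with `torch.bmm` on tensors of shape (B, d, 1) and (B, 1, C). The CDAN paper also proposes a randomized multilinear map for cases where d * C is too large; whether this repository uses it is not stated here.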
Requirements:

- Python >= 3.7
- PyTorch >= 1.7.0
- CUDA >= 10.1 (for GPU training)
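To confirm an environment meets these minimums, a small script along the following lines can help. It is a hypothetical helper, not part of the repository: `version_tuple` drops local build tags such as `+cu117` and pre-release suffixes before comparing, and the PyTorch and CUDA checks are skipped when PyTorch is not installed.

```python
import sys
from typing import Tuple


def version_tuple(version: str) -> Tuple[int, ...]:
    """Convert e.g. '1.13.1+cu117' or '2.0.0a0' to (1, 13, 1) or (2, 0, 0)."""
    base = version.split("+")[0]  # drop local build tags such as '+cu117'
    numbers = []
    for piece in base.split("."):
        digits = ""
        for ch in piece:  # keep only the leading digits ('0a0' -> '0')
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        numbers.append(int(digits))
    return tuple(numbers)


def check_environment() -> None:
    """Print whether Python, PyTorch and CUDA meet the minimum versions listed above."""
    py_ok = sys.version_info >= (3, 7)
    print("Python", sys.version.split()[0], "OK" if py_ok else "too old (need >= 3.7)")
    try:
        import torch
    except ImportError:
        print("PyTorch is not installed")
        return
    torch_ok = version_tuple(torch.__version__) >= (1, 7, 0)
    print("PyTorch", torch.__version__, "OK" if torch_ok else "too old (need >= 1.7.0)")
    # torch.version.cuda is the CUDA version PyTorch was built against (None for CPU builds).
    if torch.cuda.is_available() and torch.version.cuda:
        cuda_ok = version_tuple(torch.version.cuda) >= (10, 1)
        print("CUDA", torch.version.cuda, "OK" if cuda_ok else "too old (need >= 10.1)")
    else:
        print("CUDA not available: GPU training is not possible in this environment")


if __name__ == "__main__":
    check_environment()
```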
Place each dataset under data/. Office-31 directory structure (three domains, each with 31 class folders under images/):

```
data/office31/
├── amazon/
│   └── images/
│       ├── back_pack/
│       ├── bike/
│       └── ...
├── dslr/
│   └── images/
│       ├── back_pack/
│       ├── bike/
│       └── ...
└── webcam/
    └── images/
        ├── back_pack/
        ├── bike/
        └── ...
```
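ImageCLEF-DA directory structure (domains i: ImageNet ILSVRC 2012, p: Pascal VOC 2012, c: Caltech-256; all three share the same 12 classes):

```
data/imageclef/
├── i/
│   ├── class1/
│   ├── class2/
│   └── ...
├── p/
│   ├── class1/
│   ├── class2/
│   └── ...
└── c/
    ├── class1/
    ├── class2/
    └── ...
```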
Office-Home directory structure (all four domains contain the same 65 class folders as Art):

```
data/officehome/
├── Art/
│   ├── Alarm_Clock/
│   ├── Backpack/
│   └── ...
├── Clipart/
├── Product/
└── Real_World/
```
VisDA-2017 directory structure (train/: synthetic renderings, validation/: real images; 12 classes):

```
data/visda2017/
├── train/
│   ├── aeroplane/
│   ├── bicycle/
│   └── ...
└── validation/
    ├── aeroplane/
    ├── bicycle/
    └── ...
```
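All four layouts follow a one-folder-per-class convention, with an extra images/ level inside each Office-31 domain. The sketch below indexes a domain folder into (path, label) pairs under that assumption. It is a hypothetical stand-in for the dataset classes in datasets/, whose actual loading code may differ (for example, in the accepted file extensions or in how class indices are assigned).

```python
from pathlib import Path
from typing import Dict, List, Tuple

# Assumed image extensions; adjust if a dataset uses other formats.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}


def domain_root(domain_dir: str) -> Path:
    """Return the folder whose subfolders are the class folders.

    Office-31 domains nest their classes under images/ (e.g. data/office31/amazon/images/);
    the other datasets keep class folders directly under the domain folder.
    """
    root = Path(domain_dir)
    nested = root / "images"
    return nested if nested.is_dir() else root


def index_domain(domain_dir: str) -> Tuple[List[Tuple[Path, int]], Dict[str, int]]:
    """Return (image_path, label) pairs and the class-name -> label mapping."""
    root = domain_root(domain_dir)
    # Sorted class names give a deterministic label order that does not depend on the domain.
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        for path in sorted((root / name).rglob("*")):
            if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS:
                samples.append((path, class_to_idx[name]))
    return samples, class_to_idx
```

Because labels come from sorted folder names, `index_domain("data/office31/amazon")` and `index_domain("data/office31/webcam")` assign the same index to each class, provided both domains contain the same set of class folders, as they do in these benchmarks.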
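Set the configuration parameters in configs/default.yaml, for example:

```yaml
data:
  dataset_type: 'office31'
  num_classes: 31
model:
  use_rbn: true
  replace_layer: 3
```

Train with the provided script:

```bash
bash scripts/train.sh
```

Or run train.py directly, here adapting from Amazon (source) to Webcam (target):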
```bash
python train.py \
    --source_data data/office31/amazon \
    --target_data data/office31/webcam \
    --dataset_type office31 \
    --num_classes 31 \
    --use_rbn \
    --replace_layer 3 \
    --epochs 50
```

Evaluate a trained model with the provided script:

```bash
bash scripts/test.sh
```

Or run test.py directly:
```bash
python test.py \
    --test_data data/office31/webcam \
    --dataset_type office31 \
    --num_classes 31 \
    --model_path output/model_final.pth \
    --use_rbn \
    --replace_layer 3
```

Key features:

- Replaces the later-stage BN layers of the network with RBN (Refined Batch Normalization) to reduce cumulative estimation bias.
- Built on the CDAN (Conditional Domain Adversarial Network) framework for domain adaptation.
- Uses ResNet-50 as the feature extractor.
- Supports several commonly used domain adaptation datasets: Office-31, ImageCLEF-DA, Office-Home, and VisDA-2017.
- Provides a complete training and evaluation pipeline.
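The RBN feature and the `--replace_layer` flag suggest that BN layers are swapped for RBN from some ResNet-50 stage onward. The sketch below lists which BatchNorm2d modules of torchvision's `resnet50` would be affected under one assumed reading: `--replace_layer k` replaces BN in stages `layer{k}` through `layer4` and leaves the stem untouched. That interpretation is an assumption, not something this README confirms; models/rbn.py and models/backbone.py define the actual rule.

```python
from typing import List

# Bottleneck blocks per stage (layer1..layer4) in torchvision's resnet50.
RESNET50_BLOCKS = [3, 4, 6, 3]


def resnet50_bn_names() -> List[str]:
    """Module names of every BatchNorm2d in torchvision's resnet50, in forward order."""
    names = ["bn1"]  # stem BN after conv1
    for stage, num_blocks in enumerate(RESNET50_BLOCKS, start=1):
        for block in range(num_blocks):
            prefix = f"layer{stage}.{block}"
            names += [f"{prefix}.bn1", f"{prefix}.bn2", f"{prefix}.bn3"]
            if block == 0:  # the first block of each stage has a downsample branch
                names.append(f"{prefix}.downsample.1")  # Sequential(conv1x1, BN)
    return names


def bn_layers_to_replace(replace_layer: int) -> List[str]:
    """ASSUMPTION: --replace_layer k means "replace BN in layer{k} .. layer4"."""
    if not 1 <= replace_layer <= 4:
        raise ValueError("replace_layer must be between 1 and 4")
    return [
        name for name in resnet50_bn_names()
        if name.startswith("layer") and int(name[len("layer")]) >= replace_layer
    ]


print(len(bn_layers_to_replace(3)))  # 29
```

Under this reading, `--replace_layer 3` touches 29 of the 53 BN layers. In PyTorch, each selected name could then be resolved to its parent module and the BN replaced with `setattr`.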
Accuracy (%) on the Office-31 dataset (A: Amazon, D: DSLR, W: Webcam; S → T denotes source → target):
| Method | A → W | D → W | W → D | A → D | D → A | W → A | Avg |
|---|---|---|---|---|---|---|---|
| CDAN | 94.1 | 98.6 | 100.0 | 92.9 | 70.1 | 69.3 | 87.7 |
| CDAN+RBN | 95.9 | 99.1 | 100.0 | 95.7 | 76.1 | 74.5 | 90.2 |
This project is licensed under the MIT License. See the LICENSE file for details.