Author: Chanoch, Send an Email
A simple project for text-to-image remote sensing image generation.
- Python >= 3.10
- uv package manager
- Install uv (if not already installed): refer to the official documentation.
- Clone the repository:

  ```bash
  git clone <repository-url>
  cd rs-diffusion
  ```

- Install dependencies:

  ```bash
  uv sync
  ```

- Activate the virtual environment:

  ```bash
  # On macOS/Linux
  source .venv/bin/activate

  # On Windows
  .venv\Scripts\activate
  ```
The framework uses a unified main.py script with YAML configuration files.
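Conceptually, the entry point parses `--config` and `--mode`, loads the YAML file, and dispatches to the matching mode. Below is a minimal sketch of such a dispatcher, assuming an argparse/PyYAML setup; the handler names `run_train`, `run_val`, and `run_inference` are hypothetical and may not match the project's actual internals:

```python
# Illustrative sketch of a unified entry point; the real main.py may differ.
import argparse
import yaml  # PyYAML

# Placeholder mode handlers; the real framework wires these to its engines.
def run_train(cfg):
    print("train mode; config sections:", list(cfg))

def run_val(cfg):
    print("val mode; config sections:", list(cfg))

def run_inference(cfg):
    print("inference mode; config sections:", list(cfg))

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True, help="Path to a YAML config file")
    parser.add_argument("--mode", required=True, choices=["train", "val", "inference"])
    args = parser.parse_args()

    # Load the YAML configuration that drives model, engine, data, and process setup
    with open(args.config) as f:
        config = yaml.safe_load(f)

    # Dispatch to the mode-specific entry point
    {"train": run_train, "val": run_val, "inference": run_inference}[args.mode](config)

if __name__ == "__main__":
    main()
```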
```bash
# Training
python main.py --config config/{model}/train_config.yaml --mode train

# Validation
python main.py --config config/{model}/val_config.yaml --mode val

# Inference
python main.py --config config/{model}/inference_config.yaml --mode inference
```

Configuration files are located in the config/ directory. See config/README.md for detailed documentation.
Key configuration sections (illustrated in the example after this list):
- Model: Model architecture parameters
- Engine: Model-specific engine parameters (e.g., diffusion timesteps)
- Data: Data loader configuration
- Process: Process-specific parameters (training, validation, inference)
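As an illustration only, a train config covering those four sections might look like the following. Every key and value below is an assumption made for this sketch, not the project's actual schema; consult config/README.md for the real options:

```yaml
# Hypothetical example; see config/README.md for the actual schema.
model:
  name: DDPM_UNet        # model architecture
  image_size: 256
  base_channels: 128

engine:
  timesteps: 1000        # diffusion timesteps
  beta_schedule: linear

data:
  root: data/
  batch_size: 32
  num_workers: 4

process:
  epochs: 100
  lr: 1.0e-4
  checkpoint_dir: checkpoints/
```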
Using torchrun (recommended):
```bash
torchrun --nproc_per_node=4 main.py --config config/DDPM_UNet/train_config.yaml --mode train
```

Or using the provided script:

```bash
./scripts/launch_distributed.sh config/DDPM_UNet/train_config.yaml train 4
```

For multi-node training, using torchrun:
```bash
# On node 0
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=4 \
  --master_addr="<node0_ip>" --master_port=29500 \
  main.py --config config/DDPM_UNet/train_config.yaml --mode train

# On node 1
torchrun --nnodes=2 --node_rank=1 --nproc_per_node=4 \
  --master_addr="<node0_ip>" --master_port=29500 \
  main.py --config config/DDPM_UNet/train_config.yaml --mode train
```

For SLURM clusters, use the provided script:
```bash
sbatch scripts/launch_distributed_slurm.sh
```

Distributed training is automatically detected via environment variables (RANK, WORLD_SIZE, LOCAL_RANK). The framework will (see the sketch after the list below):
- Automatically wrap the model with DDP
- Use `DistributedSampler` for data loading
- Only save checkpoints and log to TensorBoard on rank 0
- Synchronize losses across all processes
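For intuition, here is a minimal sketch of how environment-variable-driven DDP setup typically works in PyTorch. This is a generic illustration of the pattern, not the framework's actual implementation:

```python
# Generic illustration of env-var-based DDP setup, not the framework's internals.
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def setup_distributed(model: torch.nn.Module) -> torch.nn.Module:
    # torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK for each process
    if "WORLD_SIZE" in os.environ and int(os.environ["WORLD_SIZE"]) > 1:
        dist.init_process_group(backend="nccl")
        local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(local_rank)
        model = model.cuda(local_rank)
        # Wrap the model so gradients are synchronized across processes
        model = DDP(model, device_ids=[local_rank])
    return model

def is_main_process() -> bool:
    # Rank 0 is the only process that writes checkpoints and TensorBoard logs
    return not dist.is_initialized() or dist.get_rank() == 0
```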
You can configure distributed training options in the config file:
```yaml
distributed:
  enabled: false                 # Set to true to enable (or use environment variables)
  find_unused_parameters: false  # Set to true if the model has unused parameters
```

Note: the effective batch size is batch_size * num_gpus * num_nodes. Adjust your learning rate accordingly.
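For example, with batch_size: 32, 4 GPUs per node, and 2 nodes, the effective batch size is 32 × 4 × 2 = 256. Under the common linear scaling heuristic (an illustration, not a rule this framework enforces), a single-GPU learning rate of 1e-4 would then become 8e-4.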
The framework uses torchvision.datasets.ImageFolder for ImageNet-like directory structures. Your dataset should be organized as:
```
data/
  train/
    class1/
      img1.jpg
      img2.jpg
    class2/
      img3.jpg
    ...
  val/
    class1/
    ...
```
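For reference, a minimal sketch of loading this layout with torchvision follows. The paths and transform values are illustrative assumptions; in practice the framework's data pipeline is configured through the YAML files instead:

```python
# Minimal illustration of loading the layout above with torchvision.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.Resize(256),      # illustrative size; match your config
    transforms.CenterCrop(256),
    transforms.ToTensor(),
])

# ImageFolder maps each class subdirectory to an integer label
train_set = datasets.ImageFolder("data/train", transform=transform)
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=4)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. torch.Size([32, 3, 256, 256]) torch.Size([32])
```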