Deconvolution by Optimal Transport for Spatial Transcriptomics
A Python implementation of the DOT algorithm for transferring cell type annotations from single-cell RNA-seq reference data to spatial transcriptomics data using multi-objective optimization.
- 🚀 GPU acceleration using PyTorch for fast computation
- 🧬 AnnData integration - seamlessly works with scanpy workflows
- 🎯 Multi-objective optimization using Frank-Wolfe algorithm
- 📊 High & low resolution support - works with both subcellular (Xenium, MERFISH, CosMx) and spot-based (Visium, ST) technologies
- 🎨 Built-in visualization tools for spatial cell type mapping
- 💾 Checkpointing for long-running optimizations
- ⚡ Mixed precision support for memory-efficient GPU computation
git clone https://github.com/earmingol/DOTpy.git
cd DOTpy
pip install -e .

- Python >= 3.8
- PyTorch >= 1.10.0 (with CUDA support for GPU acceleration)
- scanpy >= 1.9.0
- anndata >= 0.8.0
- numpy >= 1.20.0
- matplotlib >= 3.5.0
- scikit-learn >= 1.0.0
- scipy >= 1.7.0
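As a quick sanity check after installation, the minimum versions listed above can be compared against what is installed. This is a minimal sketch using only the standard library. The helper names `version_tuple` and `check_requirements` are illustrative and not part of DOTpy, and the distribution name `torch` for PyTorch is the usual PyPI name.

```python
import re
from importlib import metadata

# Minimum versions copied from the requirements list above.
MINIMUM_VERSIONS = {
    "torch": "1.10.0",
    "scanpy": "1.9.0",
    "anndata": "0.8.0",
    "numpy": "1.20.0",
    "matplotlib": "3.5.0",
    "scikit-learn": "1.0.0",
    "scipy": "1.7.0",
}


def version_tuple(version, width=3):
    """Return the leading numeric components, zero-padded to `width` parts."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    parts = [int(p) for p in match.group().split(".")] if match else []
    parts += [0] * (width - len(parts))  # "1.10" is treated like "1.10.0"
    return tuple(parts)


def check_requirements(minimums):
    """Map each distribution name to 'ok', 'too old (...)' or 'missing'."""
    report = {}
    for name, minimum in minimums.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        if version_tuple(installed) >= version_tuple(minimum):
            report[name] = "ok"
        else:
            report[name] = f"too old ({installed} < {minimum})"
    return report


for name, status in check_requirements(MINIMUM_VERSIONS).items():
    print(f"{name}: {status}")
```

Only the leading numeric part of each version string is compared, so local build tags such as `+cu118` and pre-release suffixes are ignored.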
import scanpy as sc
from dotpy import DOT, setup_reference, setup_spatial, plot_spatial_weights
# Load data
ref_adata = sc.read_h5ad('reference.h5ad')
spatial_adata = sc.read_h5ad('spatial.h5ad')
# Process reference and spatial data
ref_processed = setup_reference(
ref_adata,
cell_type_key='cell_type',
subcluster_size=10,
max_genes=5000,
verbose=True
)
spatial_processed = setup_spatial(
spatial_adata,
spatial_key='spatial',
th_spatial=0.84,
verbose=True
)
# Run DOT with batching
dot = DOT(
spatial_processed,
ref_processed,
batch_size=500 # Adjust for your GPU memory
)
dot.fit(
mode='highres',
iterations=100,
checkpoint_dir='./checkpoints', # Save checkpoints
checkpoint_freq=10,
verbose=True
)
# Get results
weights = dot.get_weights(normalize=True)
cell_types = dot.get_cell_types()
# Visualize results
plot_spatial_weights(
spatial_adata.obsm['spatial'],
weights,
cell_types=cell_types,
ncols=4,
save_path='cell_type_maps.png'
)

# Resume from a saved checkpoint
dot.fit(
mode='highres',
iterations=100,
resume_from='./checkpoints/checkpoint_iter_50.pkl',
verbose=True
)

For subcellular-resolution data, where each spot typically contains one cell:
dot.fit(
mode='highres',
ratios_weight=0.0,
iterations=100,
verbose=True
)

For spot-based technologies where spots contain multiple cells:
dot.fit(
mode='lowres',
max_spot_size=20, # Maximum cells per spot
ratios_weight=0.3, # Weight for matching cell type proportions
iterations=100,
verbose=True
)

DOT uses multi-objective optimization to find cell type assignments that:
- Match gene expression - Predicted expression should match observed spatial data
- Preserve spatial coherence - Neighboring spots should have similar composition
- Respect cell type abundances - Overall proportions should match reference (optional)
- Enforce sparsity - Limit mixing of cell types per spot
The optimization uses the Frank-Wolfe algorithm, a projection-free method that needs only a gradient and a linear minimization over the constraint set at each iteration, so it handles the constrained problem efficiently on GPUs.
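To make the Frank-Wolfe step concrete, here is a minimal NumPy sketch for a single spot. It is not the DOTpy implementation: it assumes a single least-squares expression term instead of DOT's full multi-objective loss, and it constrains the cell type weights to the probability simplex. The function name `frank_wolfe_simplex` and the synthetic `signatures` and `spot` arrays are illustrative. The `gap_threshold` argument stops on the Frank-Wolfe duality gap, which is presumably the role of the parameter of the same name in `dot.fit`.

```python
import numpy as np


def frank_wolfe_simplex(A, b, iterations=100, gap_threshold=1e-3):
    """Minimize 0.5 * ||A @ w - b||^2 over the probability simplex."""
    n_types = A.shape[1]
    w = np.full(n_types, 1.0 / n_types)  # start at the simplex centre
    gap = np.inf
    for _ in range(iterations):
        grad = A.T @ (A @ w - b)
        # Linear minimization oracle: over the simplex, the minimizer of
        # <grad, s> is the vertex at the smallest gradient entry.
        vertex = np.zeros(n_types)
        vertex[np.argmin(grad)] = 1.0
        direction = vertex - w
        # Frank-Wolfe duality gap, an upper bound on f(w) - f(w*).
        gap = float(-grad @ direction)
        if gap <= gap_threshold:
            break
        # Exact line search for a quadratic objective, clipped to [0, 1].
        denom = float(np.sum((A @ direction) ** 2))
        step = 1.0 if denom == 0.0 else min(1.0, gap / denom)
        w = w + step * direction  # convex combination, stays feasible
    return w, gap


# Synthetic example: 50 genes, 4 cell types, a spot mixing two types.
rng = np.random.default_rng(0)
signatures = rng.random((50, 4))
true_weights = np.array([0.7, 0.3, 0.0, 0.0])
spot = signatures @ true_weights
weights, gap = frank_wolfe_simplex(signatures, spot, iterations=500)
print(np.round(weights, 3), gap)
```

Each update is a convex combination of the current point and a simplex vertex, so the weights stay non-negative and sum to one without any projection step. The per-iteration work is a pair of matrix-vector products, which batch naturally across spots.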
# Setup reference with custom parameters
ref_processed = setup_reference(
ref_adata,
cell_type_key='cell_type',
subcluster_size=15, # More subclusters per cell type
max_genes=10000, # Use more genes
remove_mt=True, # Remove mitochondrial genes
th_inner_logfold=0.75, # Log-fold threshold for gene selection
random_state=42, # For reproducibility
verbose=True
)
# Setup spatial with custom thresholds
spatial_processed = setup_spatial(
spatial_adata,
spatial_key='spatial',
th_spatial=0.80, # Adjust spatial similarity threshold
th_gene_low=0.01, # Minimum gene expression frequency
th_gene_high=0.99, # Maximum gene expression frequency
radius='auto', # Or specify numeric value
remove_mt=True, # Remove mitochondrial genes
verbose=True
)
# DOT with custom device and optimization settings
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dot = DOT(
spatial_processed,
ref_processed,
batch_size=500, # Adjust for GPU memory
device=device # Explicitly set device
)
# Fine-tune optimization
dot.fit(
mode='highres',
ratios_weight=0.2, # Weight for abundance matching
iterations=200, # More iterations
gap_threshold=0.001, # Tighter convergence
use_mixed_precision=True, # Use float16 on GPU
verbose=True
)

# Check if CUDA is available
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
# Pass device to DOT
dot = DOT(
spatial_processed,
ref_processed,
device=device
)

# Add results to spatial AnnData
spatial_adata.obsm['dot_weights'] = weights
# Add individual cell type columns
for i, ct in enumerate(cell_types):
    spatial_adata.obs[f'dot_{ct}'] = weights[:, i]
# Save
spatial_adata.write('spatial_with_deconvolution.h5ad')

from dotpy.visualization import plot_spatial_weights
fig = plot_spatial_weights(
coords=spatial_adata.obsm['spatial'],
weights=weights,
cell_types=cell_types,
ncols=4,
point_size=10,
cmap='magma',
flip_y=True,
save_path='cell_type_maps.png',
dpi=300
)

from dotpy.visualization import plot_optimization_history
fig = plot_optimization_history(
dot.history,
save_path='optimization_history.png'
)

from dotpy.visualization import plot_cell_type_proportions
fig = plot_cell_type_proportions(
weights,
cell_types=cell_types,
save_path='proportions.png'
)

DOTpy automatically uses CUDA if available. For best performance:
# Check GPU memory
import torch
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

For very large datasets:
# Reduce number of genes
ref_processed = setup_reference(
ref_adata,
max_genes=2000, # Use fewer genes
# ... (other arguments)
)
# Reduce subclustering
ref_processed = setup_reference(
ref_adata,
subcluster_size=5, # Fewer subclusters
# ... (other arguments)
)
# Use smaller batch size
dot = DOT(spatial, ref, batch_size=100)
# Enable mixed precision on GPU
dot.fit(
mode='highres',
use_mixed_precision=True,
iterations=100
)

# Faster (fewer iterations)
dot.fit(mode='highres', iterations=50)
# More accurate (more iterations, tighter convergence)
dot.fit(
mode='highres',
iterations=200,
gap_threshold=0.001
)

Compared with the original R implementation, this PyTorch implementation provides:
- ✅ Faster computation through GPU acceleration
- ✅ Same algorithm and mathematical formulation
- ✅ AnnData integration for Python/scanpy workflows
- ✅ Memory efficiency through PyTorch's optimized operations
Key differences:
- Uses PyTorch tensors instead of R matrices
- Integrates with scanpy/AnnData instead of Seurat
- Supports GPU acceleration out of the box
If you use DOT in your research, please cite:
Rahimi, A., Vale-Silva, L.A., Fälth Savitski, M. et al.
DOT: a flexible multi-objective optimization framework for transferring features across single-cell and spatial omics.
Nat Commun 15, 4994 (2024). https://doi.org/10.1038/s41467-024-48868-z
# Solution 1: Reduce batch size
dot = DOT(spatial, ref, batch_size=100)
# Solution 2: Enable mixed precision
dot.fit(mode='highres', use_mixed_precision=True)
# Solution 3: Use CPU
dot = DOT(spatial, ref, device='cpu')

# Solution: Reduce data size
ref = setup_reference(adata, max_genes=2000, subcluster_size=5)

# Check gene names
print(f"Ref genes: {ref_adata.var_names[:10]}")
print(f"Spatial genes: {spatial_adata.var_names[:10]}")
# Ensure gene names match (e.g., both use same gene ID system)

# Solution: More iterations or looser threshold
dot.fit(iterations=200, gap_threshold=0.05)

For questions and issues, please open an issue on GitHub.
Contributions are welcome! Please feel free to submit a Pull Request.
This library was written in Python using Claude Sonnet 4.5 and GPT-5.2 models.
MIT License