Conversation
There was a problem hiding this comment.
Pull request overview
This PR introduces a major architectural refactoring to PhantomX v0.2, transitioning from scattered experiment scripts to a unified, configuration-driven framework using Hydra and WandB. The goal is to replace 25+ individual experiment scripts with a single entry point (train.py) and YAML configurations.
Changes:
- New configuration system using Hydra with YAML configs for models, datasets, trainers, and experiments
- WandB integration for experiment tracking and logging
- Unified training infrastructure with `src/trainer.py` supporting both standard and progressive training
- Modular data loading with `src/datamodules/` for MC_Maze and MC_RTT datasets
- Centralized utilities for seeding, metrics, logging, and augmentation
- Updated dependencies and package structure
Reviewed changes
Copilot reviewed 32 out of 33 changed files in this pull request and generated 41 comments.
Show a summary per file
| File | Description |
|---|---|
| train.py | Single unified entry point for all experiments with Hydra integration |
| src/trainer.py | Generic trainer supporting VQ-VAE progressive training and standard training loops |
| src/datamodules/ | Base datamodule and implementations for MC_Maze and MC_RTT datasets |
| src/utils/ | Logging (WandB), seeding, metrics, and augmentation utilities |
| configs/ | Comprehensive YAML configurations for models, datasets, trainers, and experiments |
| requirements.txt | Added Hydra, WandB, and other dependencies for v0.2 |
| pyproject.toml | Updated package metadata and dependencies for v0.2 |
| evaluate.py | Standalone evaluation script for trained models |
| docs/MIGRATION_GUIDE.md | Migration guide from v0.1 to v0.2 |
| README.md | Updated documentation with new v0.2 usage examples |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @property | ||
| def training(self) -> bool: | ||
| return True # Override in wrapper |
There was a problem hiding this comment.
The 'training' property always returns True, which means augmentation will always be applied even during validation/testing. This defeats the purpose of checking 'self.training'. Consider removing this property if there's no proper way to track training mode, or integrate these transforms with PyTorch modules that have a proper train()/eval() mode.
| if torch.rand(1).item() < self.p: | ||
| # Drop a contiguous block | ||
| block_size = int(n_channels * self.p * 2) | ||
| start = torch.randint(0, n_channels - block_size, (1,)).item() |
There was a problem hiding this comment.
When block_size is large or n_channels - block_size is small or zero, torch.randint will fail with an error. For example, if n_channels=10 and p=0.2, then block_size = int(10 * 0.2 * 2) = 4, and torch.randint(0, 10 - 4, (1,)) works. If n_channels=10 and p=0.3, then block_size = int(10 * 0.3 * 2) = 6, and n_channels - block_size = 4, which still works. But if p=0.5, then block_size = int(10 * 0.5 * 2) = 10, so n_channels - block_size = 0, and torch.randint(0, 0, (1,)) will raise an error. Add a check to ensure n_channels > block_size before calling torch.randint, or handle the edge case appropriately.
| start = torch.randint(0, n_channels - block_size, (1,)).item() | |
| # Ensure block_size is within a valid range to avoid invalid torch.randint arguments | |
| block_size = max(1, min(block_size, n_channels)) | |
| start = torch.randint(0, n_channels - block_size + 1, (1,), device=neural.device).item() |
|
|
||
| def _should_pin_memory(self) -> bool: | ||
| """Only pin memory if CUDA is available.""" | ||
| import torch |
There was a problem hiding this comment.
Redefining 'torch' within the method is unnecessary and could cause confusion. The torch module is already imported at the module level (line 12), so this line is redundant and should be removed.
| import torch |
| def compute_bits_per_spike( | ||
| predictions: Union[np.ndarray, torch.Tensor], | ||
| targets: Union[np.ndarray, torch.Tensor], | ||
| spike_counts: Union[np.ndarray, torch.Tensor], | ||
| ) -> float: | ||
| """ | ||
| Compute bits per spike metric (NLB standard). | ||
|
|
||
| Args: | ||
| predictions: Predicted spike rates (log-rates) | ||
| targets: Actual spike counts | ||
| spike_counts: Total spike counts for normalization | ||
|
|
||
| Returns: | ||
| Bits per spike | ||
| """ | ||
| # This is a placeholder - implement based on NLB evaluation code | ||
| raise NotImplementedError("Bits per spike not yet implemented") |
There was a problem hiding this comment.
The function compute_bits_per_spike raises NotImplementedError but is included in the public API. If this metric is not yet implemented, either implement it, remove it from this PR, or clearly document in the docstring that it's a placeholder for future work.
| def train(self) -> Dict[str, Any]: | ||
| """Run progressive training.""" | ||
| trainer_cfg = self.cfg.get('trainer', {}) | ||
|
|
||
| # Phase 1: Pre-train | ||
| print("\n[Phase 1/3] Pre-training encoder...") | ||
| if hasattr(self.model, 'use_vq'): | ||
| self.model.use_vq = False | ||
|
|
||
| pretrain_epochs = trainer_cfg.get('pretrain', {}).get('epochs', 30) | ||
| self.max_epochs = pretrain_epochs | ||
| pretrain_results = super().train() | ||
|
|
||
| # Phase 2: Init codebook | ||
| print("\n[Phase 2/3] Initializing codebook with k-means...") | ||
| if hasattr(self.model, 'init_codebook'): | ||
| self.model.init_codebook(self.train_loader) | ||
|
|
||
| # Phase 3: Fine-tune | ||
| print("\n[Phase 3/3] Fine-tuning with VQ...") | ||
| if hasattr(self.model, 'use_vq'): | ||
| self.model.use_vq = True | ||
|
|
||
| # Reset optimizer with lower LR | ||
| finetune_lr = trainer_cfg.get('finetune', {}).get('learning_rate', 3e-4) | ||
| finetune_epochs = trainer_cfg.get('finetune', {}).get('epochs', 50) | ||
|
|
||
| self.optimizer = AdamW( | ||
| self.model.parameters(), | ||
| lr=finetune_lr, | ||
| weight_decay=trainer_cfg.get('weight_decay', 1e-4), | ||
| ) | ||
| self.scheduler = CosineAnnealingLR(self.optimizer, T_max=finetune_epochs) | ||
|
|
||
| # Reset state for fine-tuning | ||
| self.max_epochs = finetune_epochs |
There was a problem hiding this comment.
The ProgressiveTrainer modifies self.max_epochs multiple times during training (lines 342, 367), which breaks the original contract and could confuse users checking trainer.max_epochs. Consider using separate variables like current_phase_epochs instead of mutating the public max_epochs attribute.
| """ | ||
|
|
||
| from pathlib import Path | ||
| from typing import Optional, Dict, Any, Callable |
There was a problem hiding this comment.
Import of 'Callable' is not used.
| from typing import Optional, Dict, Any, Callable | |
| from typing import Optional, Dict, Any |
| except ImportError: | ||
| pass | ||
| try: | ||
| from . import mc_rtt | ||
| except ImportError: |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except ImportError: | |
| pass | |
| try: | |
| from . import mc_rtt | |
| except ImportError: | |
| except ImportError: | |
| # Optional datamodule: ignore if mc_maze is not installed/available. | |
| pass | |
| try: | |
| from . import mc_rtt | |
| except ImportError: | |
| # Optional datamodule: ignore if mc_rtt is not installed/available. |
| except ImportError: | ||
| pass | ||
| try: | ||
| from . import mc_rtt | ||
| except ImportError: |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| except ImportError: | |
| pass | |
| try: | |
| from . import mc_rtt | |
| except ImportError: | |
| except ImportError: | |
| # mc_maze is an optional datamodule; ignore if it is not available. | |
| pass | |
| try: | |
| from . import mc_rtt | |
| except ImportError: | |
| # mc_rtt is an optional datamodule; ignore if it is not available. |
| ) | ||
| if result.returncode == 0: | ||
| return result.stdout.strip() | ||
| except (subprocess.SubprocessError, FileNotFoundError): |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
| ) | ||
| if result.returncode == 0 and result.stdout.strip(): | ||
| return result.stdout.strip() | ||
| except (subprocess.SubprocessError, FileNotFoundError): |
There was a problem hiding this comment.
'except' clause does nothing but pass and there is no explanatory comment.
No description provided.