Comments

[WIP] wandb + hydra #4

Draft

yelabb wants to merge 2 commits into main from configuration-driven-experiments-hydra

Conversation

@yelabb
Owner

@yelabb yelabb commented Jan 22, 2026

No description provided.

Copilot AI review requested due to automatic review settings January 22, 2026 20:54
@yelabb yelabb marked this pull request as draft January 22, 2026 20:54

Copilot AI left a comment


Pull request overview

This PR introduces a major architectural refactoring to PhantomX v0.2, transitioning from scattered experiment scripts to a unified, configuration-driven framework using Hydra and WandB. The goal is to replace 25+ individual experiment scripts with a single entry point (train.py) and YAML configurations.

Changes:

  • New configuration system using Hydra with YAML configs for models, datasets, trainers, and experiments
  • WandB integration for experiment tracking and logging
  • Unified training infrastructure with src/trainer.py supporting both standard and progressive training
  • Modular data loading with src/datamodules/ for MC_Maze and MC_RTT datasets
  • Centralized utilities for seeding, metrics, logging, and augmentation
  • Updated dependencies and package structure
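
The bullet points above amount to a Hydra config tree plus one entry point. A minimal sketch of what one experiment config in such a tree might look like (all file names, group names, and keys here are illustrative assumptions, not the PR's actual configs):

```yaml
# configs/experiment/example.yaml (hypothetical illustration)
defaults:
  - model: vqvae          # selects configs/model/vqvae.yaml
  - dataset: mc_maze      # selects configs/dataset/mc_maze.yaml
  - trainer: progressive  # selects configs/trainer/progressive.yaml

wandb:
  project: phantomx
  mode: online            # "disabled" is handy for local debugging

seed: 42
```

With a layout like this, each of the 25+ old scripts collapses into a command-line override, e.g. `python train.py experiment=example trainer.finetune.epochs=100`.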

Reviewed changes

Copilot reviewed 32 out of 33 changed files in this pull request and generated 41 comments.

Summary per file:

  • train.py: Single unified entry point for all experiments with Hydra integration
  • src/trainer.py: Generic trainer supporting VQ-VAE progressive training and standard training loops
  • src/datamodules/: Base datamodule and implementations for MC_Maze and MC_RTT datasets
  • src/utils/: Logging (WandB), seeding, metrics, and augmentation utilities
  • configs/: Comprehensive YAML configurations for models, datasets, trainers, and experiments
  • requirements.txt: Added Hydra, WandB, and other dependencies for v0.2
  • pyproject.toml: Updated package metadata and dependencies for v0.2
  • evaluate.py: Standalone evaluation script for trained models
  • docs/MIGRATION_GUIDE.md: Migration guide from v0.1 to v0.2
  • README.md: Updated documentation with new v0.2 usage examples


Comment on lines +87 to +89
    @property
    def training(self) -> bool:
        return True  # Override in wrapper

Copilot AI Jan 22, 2026


The 'training' property always returns True, which means augmentation will always be applied even during validation/testing. This defeats the purpose of checking 'self.training'. Consider removing this property if there's no proper way to track training mode, or integrate these transforms with PyTorch modules that have a proper train()/eval() mode.

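One way to honor the contract this comment describes is to give the transform its own train()/eval() toggle mirroring nn.Module's semantics. A minimal dependency-free sketch (the class name, parameter, and behavior here are illustrative, not the PR's actual API):

```python
import random

class ChannelDropout:
    """Augmentation with an explicit training flag, mirroring nn.Module's
    train()/eval() contract (illustrative sketch, not the PR's class)."""

    def __init__(self, p: float = 0.2):
        self.p = p
        self.training = True  # mutable mode flag, like nn.Module.training

    def train(self):
        self.training = True
        return self

    def eval(self):
        self.training = False
        return self

    def __call__(self, channels):
        if not self.training:
            # Validation/testing: pass data through untouched.
            return channels
        # Training: zero out each channel independently with probability p.
        return [0.0 if random.random() < self.p else c for c in channels]
```

A datamodule could then call `.eval()` on its transforms when building the val/test loaders, so augmentation never leaks into evaluation.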
        if torch.rand(1).item() < self.p:
            # Drop a contiguous block
            block_size = int(n_channels * self.p * 2)
            start = torch.randint(0, n_channels - block_size, (1,)).item()

Copilot AI Jan 22, 2026


When block_size is large, n_channels - block_size can be zero (or negative), and torch.randint will fail with an error. For example, if n_channels=10 and p=0.2, then block_size=int(10*0.2*2)=4 and torch.randint(0, 6, (1,)) works; but if p=0.5, then block_size=int(10*0.5*2)=10, n_channels - block_size = 0, and torch.randint(0, 0, (1,)) raises an error. Add a check to ensure n_channels > block_size before calling torch.randint, or handle the edge case appropriately.

Suggested change
start = torch.randint(0, n_channels - block_size, (1,)).item()
# Ensure block_size is within a valid range to avoid invalid torch.randint arguments
block_size = max(1, min(block_size, n_channels))
start = torch.randint(0, n_channels - block_size + 1, (1,), device=neural.device).item()

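The clamp in the suggestion can be illustrated with the stdlib's random.randrange, which raises the same empty-range error as torch.randint; function and variable names here are illustrative, not the PR's code:

```python
import random

def sample_block_start(n_channels: int, p: float):
    """Pick a contiguous block to drop, guarding the empty-range edge case
    (stdlib sketch of the suggested fix, not the PR's implementation)."""
    block_size = int(n_channels * p * 2)
    # Clamp so the sampling range is never empty: without this,
    # p=0.5 gives block_size == n_channels and randrange(0, 0) raises.
    block_size = max(1, min(block_size, n_channels))
    # The +1 makes the last valid start position reachable.
    start = random.randrange(0, n_channels - block_size + 1)
    return block_size, start
```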

    def _should_pin_memory(self) -> bool:
        """Only pin memory if CUDA is available."""
        import torch

Copilot AI Jan 22, 2026


Re-importing 'torch' inside the method is unnecessary and could cause confusion. The torch module is already imported at the module level (line 12), so this line is redundant and should be removed.

Suggested change
import torch

Comment on lines +103 to +120
def compute_bits_per_spike(
    predictions: Union[np.ndarray, torch.Tensor],
    targets: Union[np.ndarray, torch.Tensor],
    spike_counts: Union[np.ndarray, torch.Tensor],
) -> float:
    """
    Compute bits per spike metric (NLB standard).

    Args:
        predictions: Predicted spike rates (log-rates)
        targets: Actual spike counts
        spike_counts: Total spike counts for normalization

    Returns:
        Bits per spike
    """
    # This is a placeholder - implement based on NLB evaluation code
    raise NotImplementedError("Bits per spike not yet implemented")

Copilot AI Jan 22, 2026


The function compute_bits_per_spike raises NotImplementedError but is included in the public API. If this metric is not yet implemented, either implement it, remove it from this PR, or clearly document in the docstring that it's a placeholder for future work.

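For reference, one common formulation of bits per spike in Poisson-likelihood benchmarks like NLB compares the model's Poisson log-likelihood against a null model that predicts a constant mean rate; the log(y!) terms cancel in the difference. A minimal sketch under those assumptions, not the PR's (unimplemented) version:

```python
import math

def bits_per_spike(rates, targets):
    """Bits per spike: Poisson log-likelihood gain of `rates` over a
    mean-rate null model, per spike, in bits. Illustrative sketch only;
    assumes all rates are strictly positive, and uses the mean of
    `targets` as the null rate (NLB derives it from training data)."""
    mean_rate = sum(targets) / len(targets)
    # Poisson log-likelihood up to the log(y!) constant, which cancels below.
    ll_model = sum(y * math.log(r) - r for r, y in zip(rates, targets))
    ll_null = sum(y * math.log(mean_rate) - mean_rate for y in targets)
    total_spikes = sum(targets)
    return (ll_model - ll_null) / (total_spikes * math.log(2))
```

Predicting exactly the mean rate yields 0 bits per spike; better-than-null rate predictions yield positive values.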
Comment on lines +332 to +367
    def train(self) -> Dict[str, Any]:
        """Run progressive training."""
        trainer_cfg = self.cfg.get('trainer', {})

        # Phase 1: Pre-train
        print("\n[Phase 1/3] Pre-training encoder...")
        if hasattr(self.model, 'use_vq'):
            self.model.use_vq = False

        pretrain_epochs = trainer_cfg.get('pretrain', {}).get('epochs', 30)
        self.max_epochs = pretrain_epochs
        pretrain_results = super().train()

        # Phase 2: Init codebook
        print("\n[Phase 2/3] Initializing codebook with k-means...")
        if hasattr(self.model, 'init_codebook'):
            self.model.init_codebook(self.train_loader)

        # Phase 3: Fine-tune
        print("\n[Phase 3/3] Fine-tuning with VQ...")
        if hasattr(self.model, 'use_vq'):
            self.model.use_vq = True

        # Reset optimizer with lower LR
        finetune_lr = trainer_cfg.get('finetune', {}).get('learning_rate', 3e-4)
        finetune_epochs = trainer_cfg.get('finetune', {}).get('epochs', 50)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=finetune_lr,
            weight_decay=trainer_cfg.get('weight_decay', 1e-4),
        )
        self.scheduler = CosineAnnealingLR(self.optimizer, T_max=finetune_epochs)

        # Reset state for fine-tuning
        self.max_epochs = finetune_epochs

Copilot AI Jan 22, 2026


The ProgressiveTrainer modifies self.max_epochs multiple times during training (lines 342, 367), which breaks the original contract and could confuse users checking trainer.max_epochs. Consider using separate variables like current_phase_epochs instead of mutating the public max_epochs attribute.

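The suggestion above can be sketched by passing each phase's length as a local argument instead of mutating the public attribute; all names below are illustrative, not the PR's trainer:

```python
class ProgressiveTrainer:
    """Sketch of the reviewer's suggestion: keep max_epochs immutable and
    pass each phase's length explicitly (illustrative names only)."""

    def __init__(self, max_epochs: int = 80):
        self.max_epochs = max_epochs  # public contract: never mutated

    def _run_phase(self, name: str, phase_epochs: int):
        # The phase length is a local argument, not shared mutable state,
        # so trainer.max_epochs stays meaningful to outside observers.
        return {"phase": name, "epochs": phase_epochs}

    def train(self, pretrain_epochs: int = 30, finetune_epochs: int = 50):
        return [
            self._run_phase("pretrain", pretrain_epochs),
            self._run_phase("finetune", finetune_epochs),
        ]
```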
"""

from pathlib import Path
from typing import Optional, Dict, Any, Callable

Copilot AI Jan 22, 2026


Import of 'Callable' is not used.

Suggested change
from typing import Optional, Dict, Any, Callable
from typing import Optional, Dict, Any

Comment on lines +61 to +65
except ImportError:
    pass
try:
    from . import mc_rtt
except ImportError:

Copilot AI Jan 22, 2026


'except' clause does nothing but pass and there is no explanatory comment.

Suggested change (replacing the original lines with commented versions):

except ImportError:
    # Optional datamodule: ignore if mc_maze is not installed/available.
    pass
try:
    from . import mc_rtt
except ImportError:
    # Optional datamodule: ignore if mc_rtt is not installed/available.
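Beyond adding a comment, the optional-import pattern can also record which modules actually loaded, which makes "why is my datamodule missing?" debuggable. A stdlib sketch; the module names are stand-ins ("json" for an available module, the other deliberately nonexistent), not the PR's datamodules:

```python
import importlib
import logging

logger = logging.getLogger(__name__)

# Register optional modules instead of silently passing.
AVAILABLE = {}
for name in ("json", "definitely_not_a_real_module_xyz"):
    try:
        AVAILABLE[name] = importlib.import_module(name)
    except ImportError:
        # Optional dependency: skip it, but leave a trace for debugging.
        logger.debug("optional module %s not available", name)
```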
        )
        if result.returncode == 0:
            return result.stdout.strip()
    except (subprocess.SubprocessError, FileNotFoundError):

Copilot AI Jan 22, 2026


'except' clause does nothing but pass and there is no explanatory comment.

        )
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()
    except (subprocess.SubprocessError, FileNotFoundError):

Copilot AI Jan 22, 2026


'except' clause does nothing but pass and there is no explanatory comment.

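Both of these subprocess comments ask for an explanatory comment on the bare `pass`. A self-contained sketch of the pattern, assuming (as the snippets suggest) this is a best-effort git-hash helper for run metadata; the function name is illustrative:

```python
import subprocess
from typing import Optional

def get_git_commit() -> Optional[str]:
    """Best-effort git commit hash for run metadata; None outside a repo."""
    try:
        result = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            capture_output=True, text=True, timeout=5,
        )
        if result.returncode == 0 and result.stdout.strip():
            return result.stdout.strip()
    except (subprocess.SubprocessError, FileNotFoundError):
        # git is missing or the call failed; commit tracking is optional
        # metadata, so fall through and return None rather than crash.
        pass
    return None
```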