PyTorch Custom Neural Network with Transfer Learning
Advanced PyTorch model with transfer learning, layer freezing, and mixed precision training
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
# Custom dataset
class CustomDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample, label
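# Minimal usage sketch (the random tensors and labels below are placeholders for real
# data, not part of the original example): build a transform pipeline and the
# `train_loader` consumed by the training loop further down.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
dummy_images = [torch.rand(3, 224, 224) for _ in range(32)]  # placeholder image tensors
dummy_labels = torch.randint(0, 10, (32,)).tolist()          # placeholder class labels
train_dataset = CustomDataset(dummy_images, dummy_labels, transform=train_transform)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)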
# Transfer learning with ResNet
class TransferLearningModel(nn.Module):
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        # Load pre-trained ResNet (the `weights` API replaces the deprecated `pretrained` flag)
        weights = models.ResNet50_Weights.DEFAULT if pretrained else None
        self.backbone = models.resnet50(weights=weights)
        # Freeze all but the last 10 parameter tensors of the backbone
        for param in list(self.backbone.parameters())[:-10]:
            param.requires_grad = False
        # Replace the final fully connected layer with a custom classifier head
        num_features = self.backbone.fc.in_features
        self.backbone.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(num_features, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        return self.backbone(x)
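# Optional sanity check (not in the original): report how many parameters remain
# trainable after freezing, to confirm the backbone is mostly frozen while the new
# classifier head stays trainable.
def count_parameters(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    return trainable, total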
# Training setup (move the model to GPU if available; mixed precision below only takes effect on CUDA)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = TransferLearningModel(num_classes=10, pretrained=True).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)
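# Optional variation (an assumption, not in the original): discriminative learning rates
# via parameter groups, giving the pre-trained backbone a smaller LR than the freshly
# initialized head. The lines below override the optimizer and scheduler created above.
head_params = list(model.backbone.fc.parameters())
backbone_params = [p for n, p in model.backbone.named_parameters()
                   if p.requires_grad and not n.startswith('fc.')]
optimizer = optim.AdamW([
    {'params': backbone_params, 'lr': 1e-5},
    {'params': head_params, 'lr': 1e-4},
], weight_decay=1e-4)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)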
# Mixed precision training
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()
# Training loop
num_epochs = 20  # adjust to the desired number of epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        with autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # Scale the loss, backpropagate, then step the optimizer through the scaler
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item()
    epoch_loss = running_loss / len(train_loader)
    scheduler.step(epoch_loss)  # ReduceLROnPlateau steps on the monitored metric
    print(f'Epoch {epoch+1}, Loss: {epoch_loss:.4f}')
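# Optional extension (an assumption, not part of the original loop): evaluate on a
# held-out set each epoch and drive ReduceLROnPlateau with the validation loss instead
# of the training loss. Assumes a `val_loader` built the same way as `train_loader`.
def evaluate(model, loader, criterion, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
    return total_loss / len(loader)
# Example usage: val_loss = evaluate(model, val_loader, criterion, device); scheduler.step(val_loss)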