In this post, we will talk about feature-based knowledge distillation. One of the pioneering papers is "FitNets: Hints for Thin Deep Nets" by Adriana Romero, Nicolas Ballas, Samira Ebrahimi Kahou, Antoine Chassang, Carlo Gatta, and Yoshua Bengio. The paper introduces a training method, FitNets, that compresses large, wide "teacher" neural networks into "student" networks that are deeper and thinner. The core problem it addresses is that while deep networks are very powerful, they are difficult to train, and this is especially true for networks that are both very deep and "thin" (meaning they have far fewer parameters per layer).
The paper’s solution is to extend the concept of Knowledge Distillation (KD).
- Standard KD involves training a small student network to mimic the final “soft” outputs (the probabilities) of a large, pre-trained teacher network.
- FitNets extends this by using not only the teacher’s final output but also its intermediate hidden layers as “hints” to guide the student’s training process.
This “hint-based training” acts as a form of pre-training, guiding the student network into a good starting place, which makes the difficult task of training a very deep, thin network possible.
🧠 How FitNets Work: A Two-Stage Process
The FitNet training method is divided into two main stages, as shown in Figure 1 of the paper:

Stage 1: Hint-Based Training (Fig. 1b)
This stage pre-trains the first half of the student network.
- Select Layers: A “hint layer” (usually a middle layer) is chosen from the large teacher network. A corresponding “guided layer” is chosen in the thin student network.
- Add Regressor: Because the student is thinner, its guided layer is smaller (has fewer outputs) than the teacher’s hint layer. To make them comparable, a small regressor network is attached to the student’s guided layer to expand its output to match the size of the teacher’s hint.
- Train: The parameters of the first half of the student network (up to the guided layer) and the regressor are trained to make the student’s guided output match the teacher’s hint output.
This initial stage effectively forces the student's early layers to learn intermediate representations similar to the teacher's, providing a crucial "hint" that makes the deep, thin network easier to optimize.
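As a rough sketch of what the Stage 1 hint loss looks like in code, consider the toy example below. The tensor shapes, the 1×1 convolutional regressor, and the 0.5 scaling are illustrative assumptions rather than the paper's exact configuration (the paper does use convolutional regressors, but trains them jointly with the student's first half on real features).
Python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative tensors standing in for real layer outputs:
# the teacher's hint layer is wider than the student's guided layer.
teacher_hint   = torch.randn(8, 128, 8, 8)                     # pretend hint-layer output
student_guided = torch.randn(8, 32, 8, 8, requires_grad=True)  # pretend guided-layer output

# Convolutional regressor that maps the student's guided features up to the teacher's hint size.
regressor = nn.Conv2d(32, 128, kernel_size=1)

# Hint loss: squared L2 distance between the teacher's hint and the regressed student features.
hint_loss = 0.5 * F.mse_loss(regressor(student_guided), teacher_hint)
hint_loss.backward()   # gradients reach the regressor and (in a real setup) the student's first half
print(hint_loss.item())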
Stage 2: Knowledge Distillation (Fig. 1c)
After Stage 1 is complete, the regressor is discarded. The entire student network (using the pre-trained parameters from Stage 1 for its first half) is then trained using a standard Knowledge Distillation loss function.
This loss function trains the student to do two things simultaneously:
- Match the true labels of the data (standard classification training).
- Match the softened outputs (the teacher’s probability distribution) from the final layer of the teacher network.
🏆 The Result
By using this hint-based, two-stage process, the authors were able to successfully train very deep and thin networks that standard training (plain backpropagation on the labels, or even regular KD) failed to train effectively.
The resulting FitNets are:
- Highly Compressed: For example, on the CIFAR-10 dataset, the FitNet student outperformed its teacher (91.61% accuracy vs. 90.18%) while having 10.4 times fewer parameters.
- Fast and Efficient: The “thin” nature of the networks means they require significantly fewer computations and are much faster at inference time.
In short, FitNets provides a way to get the performance benefits of depth while maintaining the speed and efficiency of a thin, compact model.
In the following code, we will implement feature-based knowledge distillation. However, we won't follow the stage-wise training algorithm that was the paper's key proposal for getting very deep, thin networks to train, because our example student is a fairly shallow network trained on MNIST. The code can, however, easily be adapted to the FitNets training style, as sketched below.
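For readers who do want the stage-wise flavour, here is a minimal sketch of how the code below could be adapted. It reuses the StudentNet, TeacherNet, train_loader, and device defined later in this post; treating conv1 plus the feature_adapter as the "guided half", the learning rate, and the number of hint epochs are my own assumptions.
Python
import torch
import torch.nn.functional as F
import torch.optim as optim

def train_fitnets_style(student, teacher, train_loader, device, hint_epochs=3):
    """Hypothetical two-stage, FitNets-style schedule for the models defined below."""
    teacher.eval()
    student.train()
    # --- Stage 1: hint-based pre-training ---
    # Only the student's first half (conv1) and the feature adapter are optimized,
    # using the feature-matching (hint) loss alone.
    stage1_params = list(student.conv1.parameters()) + list(student.feature_adapter.parameters())
    opt1 = optim.SGD(stage1_params, lr=0.01, momentum=0.9)
    for _ in range(hint_epochs):
        for data, _labels in train_loader:
            data = data.to(device)
            with torch.no_grad():
                _, teacher_features = teacher(data)
            _, student_features = student(data)
            hint_loss = F.mse_loss(student_features, teacher_features)
            opt1.zero_grad()
            hint_loss.backward()
            opt1.step()
    # --- Stage 2: standard knowledge distillation over the whole student ---
    # Reuse train_distillation from this post with an optimizer over student.parameters().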
This specific example combines two types of knowledge in a single training loss:
- Logit Distillation: The student learns from the teacher’s final output probabilities (the “soft labels”).
- Feature Distillation: The student also learns to mimic the teacher’s intermediate feature maps (internal representations).
Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
Imports
This block imports all the necessary libraries from PyTorch.
- torch: The core PyTorch library.
- torch.nn: Contains the building blocks for neural networks (layers, models, loss functions).
- torch.nn.functional (as F): Provides common functions such as activations (e.g., relu) and loss functions (e.g., cross_entropy).
- torch.optim: Contains optimization algorithms like SGD.
- torchvision: A computer-vision package that includes popular datasets (like MNIST) and image transformations.
- torch.utils.data.DataLoader: A utility to easily load, batch, and shuffle data.
Python
# --- 1. Settings and Hyperparameters ---
# Training settings
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9
# Knowledge Distillation (KD) settings
TEMPERATURE = 10 # Temperature for softening probabilities
ALPHA = 0.1 # Weight for the "hard" (true label) loss. (1-ALPHA) is for the "soft" (logit) loss.
GAMMA = 1.0 # Weight for the "feature" (latent) loss.
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
1. Settings and Hyperparameters
This section defines all the key configuration variables for the experiment.
- Standard training: BATCH_SIZE, LEARNING_RATE, and MOMENTUM are standard settings for the SGD optimizer. The teacher is trained for 5 epochs and the student models for 10.
- Knowledge Distillation (KD):
  - TEMPERATURE: a key KD parameter. Applied to the logits before the softmax, a higher temperature (T > 1) "softens" the probability distribution, which encourages the student to learn the subtle relationships between classes that the teacher has picked up (e.g., "this '7' looks a bit like a '1'").
  - ALPHA: controls the balance between the two logit losses. A small ALPHA (like 0.1) means the final loss is 10% "hard" loss (student vs. true labels) and 90% "soft" loss (student vs. teacher logits).
  - GAMMA: the weight of the new loss in this script, the feature loss. It controls how strongly the student is penalized for producing intermediate features that differ from the teacher's.
- Device: the code selects a CUDA-enabled GPU if one is available ("cuda"); otherwise it falls back to the CPU ("cpu").
Python
# --- 2. Data Loading (MNIST) ---
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
2. Data Loading (MNIST)
This block prepares the MNIST handwritten digit dataset.
- transform: Defines the preprocessing pipeline.
  - transforms.ToTensor(): Converts the input PIL images into PyTorch tensors.
  - transforms.Normalize((0.1307,), (0.3081,)): Normalizes the pixel values; 0.1307 and 0.3081 are the pre-computed mean and standard deviation of the MNIST dataset.
- datasets.MNIST: Downloads (if not already present) and loads the training and test sets, applying the transform to each image.
- DataLoader: Wraps the datasets, feeding the models data in batches of BATCH_SIZE and shuffling the training data (shuffle=True) at every epoch to improve generalization.
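As an optional sanity check (not part of the training pipeline), you can peek at one batch to confirm the shapes and normalization the models will receive:
Python
# Optional: inspect one batch from the training loader defined above.
images, labels = next(iter(train_loader))
print(images.shape)    # torch.Size([64, 1, 28, 28]) -> 64 grayscale 28x28 images
print(labels.shape)    # torch.Size([64])            -> one integer class label per image
print(images.mean().item(), images.std().item())     # roughly 0 and 1 after normalization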
Python
# --- 3. Model Definitions (Updated) ---
class TeacherNet(nn.Module):
"""Teacher model now returns final logits AND intermediate features."""
def __init__(self):
super(TeacherNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 7 * 7, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
# --- Extract intermediate features ---
# These are the features we will force the student to mimic
features = self.pool(F.relu(self.conv2(x))) # Shape: [B, 64, 7, 7]
x_flat = features.view(-1, 64 * 7 * 7) # Flatten
x = F.relu(self.fc1(x_flat))
logits = self.fc2(x) # Output raw logits
return logits, features
3. Model Definition: TeacherNet
This defines the large, complex “Teacher” model. It’s a standard Convolutional Neural Network (CNN) with one crucial modification.
- Architecture: It has two convolutional layers (conv1 with 32 filters, conv2 with 64 filters) followed by two fully-connected layers (fc1 with 256 hidden units, fc2 outputting 10 class scores).
- forward method: Defines the data flow.
  - The key change is that instead of returning only the final logits, the model also returns the intermediate features from after the second convolutional block.
  - This features tensor (shape [Batch, 64, 7, 7]) represents the teacher's "understanding" of the image at a deep level, and it will be used to train the student.
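To see where the [B, 64, 7, 7] shape comes from (28×28 → 14×14 after the first conv + pool, then → 7×7 after the second), a quick check with a dummy batch can be run once the class is defined; the dummy tensor below is purely illustrative.
Python
import torch

dummy = torch.randn(4, 1, 28, 28)          # four fake MNIST-sized images
logits, features = TeacherNet()(dummy)     # assumes TeacherNet as defined above
print(logits.shape)     # torch.Size([4, 10])       -> one score per class
print(features.shape)   # torch.Size([4, 64, 7, 7]) -> the features the student will mimic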
Python
class StudentNet(nn.Module):
"""Student model now includes a 'feature adapter' and returns adapted features."""
def __init__(self):
super(StudentNet, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# --- Feature Adapter ---
# This module's job is to transform the student's features
# (16 channels, 14x14) into the teacher's feature shape (64 channels, 7x7).
self.feature_adapter = nn.Sequential(
nn.Conv2d(16, 64, kernel_size=1), # 1x1 conv to match channels
nn.AdaptiveAvgPool2d((7, 7)) # Avg pool to match spatial dimensions
)
self.fc1 = nn.Linear(16 * 14 * 14, 32) # FC layer still uses the original 14x14 features
self.fc2 = nn.Linear(32, 10)
def forward(self, x):
# --- Student's base features ---
x_conv = self.pool(F.relu(self.conv1(x))) # Shape: [B, 16, 14, 14]
# --- Adapt features for comparison ---
features_adapted = self.feature_adapter(x_conv) # Shape: [B, 64, 7, 7]
# --- Continue to classification ---
x_flat = x_conv.view(-1, 16 * 14 * 14) # Flatten original features
x = F.relu(self.fc1(x_flat))
logits = self.fc2(x) # Output raw logits
return logits, features_adapted
3. Model Definition: StudentNet
This defines the small, simple “Student” model.
- Architecture: It’s much smaller than the teacher: only one convolutional layer (
conv1with 16 filters) and smaller fully-connected layers (fc1with 32 hidden units). feature_adapter: This is the most important part. The student’s internal features (x_conv) have a shape of[B, 16, 14, 14], which is different from the teacher’s features ([B, 64, 7, 7]). We cannot directly compare them.- The
feature_adapteris a mini-network whose only job is to transform the student’s features into the teacher’s feature shape. nn.Conv2d(16, 64, kernel_size=1): A 1×1 convolution changes the number of channels from 16 to 64.nn.AdaptiveAvgPool2d((7, 7)): This forcefully resizes the spatial dimensions from 14×14 down to 7×7.
- The
forwardmethod:- It calculates its own internal features (
x_conv). - It passes
x_convthrough thefeature_adapterto getfeatures_adapted. - It uses the original
x_convto continue its own classification path to producelogits. - It returns both its
logits(for logit-loss) and itsfeatures_adapted(for feature-loss).
- It calculates its own internal features (
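A similar shape check (again assuming the StudentNet class above) makes the adapter's role explicit:
Python
import torch

student = StudentNet()                       # assumes StudentNet as defined above
raw = torch.randn(4, 16, 14, 14)             # shape of the student's own conv features
print(student.feature_adapter(raw).shape)    # torch.Size([4, 64, 7, 7]) -> matches the teacher

logits, features_adapted = student(torch.randn(4, 1, 28, 28))
print(logits.shape, features_adapted.shape)  # torch.Size([4, 10]) torch.Size([4, 64, 7, 7])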
Python
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
3. Model Definition: count_parameters
This is a simple helper function. It iterates through all parameters (weights and biases) in a model and sums up the total number of elements, giving the total trainable parameter count. This is used to show how much smaller the student is than the teacher.
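For example, calling it on fresh instances of the two models gives roughly the following counts (estimated from the layer definitions above; exact values depend on the architectures):
Python
# Assumes TeacherNet, StudentNet, and count_parameters are defined as above.
print(count_parameters(TeacherNet()))   # ~824k trainable parameters for this teacher
print(count_parameters(StudentNet()))   # ~102k for this student, roughly an 8x reduction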
Python
# --- 4. Standard Training and Evaluation (Updated) ---
def train_standard(model, train_loader, optimizer, epoch):
"""Standard training loop (now ignores the second model output)."""
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# Models now return (logits, features). We only need logits here.
output, _ = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")
def evaluate(model, test_loader):
"""Standard evaluation loop (now ignores the second model output)."""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
# Models now return (logits, features). We only need logits here.
output, _ = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
return accuracy
4. Standard Training and Evaluation
These are the functions for training a model “from scratch” (i.e., without a teacher).
- train_standard: A normal training loop. It iterates through the train_loader, gets the model's predictions, computes the standard cross-entropy loss against the true target labels, and updates the model weights (optimizer.step()).
- evaluate: A normal test loop. It iterates through the test_loader, accumulates the total loss, and counts correct predictions to report accuracy. with torch.no_grad() disables gradient calculation, saving memory and time.
- Key update: Both functions now use output, _ = model(data). This unpacks the (logits, features) tuple returned by the models and discards the features (assigning them to _), because standard training only needs the final logits.
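If you want to smoke-test these two functions before committing to the full run, one option (my own addition, not part of the original script) is a single quick epoch on a small subset:
Python
from torch.utils.data import DataLoader, Subset

# Hypothetical smoke test on 1,000 training images; reuses objects defined elsewhere in this post.
small_loader = DataLoader(Subset(train_dataset, range(1000)), batch_size=BATCH_SIZE, shuffle=True)
tiny_teacher = TeacherNet().to(device)
opt = optim.SGD(tiny_teacher.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
train_standard(tiny_teacher, small_loader, opt, epoch=1)
evaluate(tiny_teacher, test_loader)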
Python
# --- 5. Knowledge Distillation Loss and Training (Updated) ---
def distillation_loss(
student_logits,
labels,
teacher_logits,
student_features,
teacher_features,
T,
alpha,
gamma
):
"""
Calculates the *combined* knowledge distillation loss.
Now includes:
1. Hard Loss (Cross-Entropy)
2. Soft Loss (KL Divergence on logits)
3. Feature Loss (MSE on feature maps)
"""
# 1. Soft Loss (Logit Distillation)
    soft_loss = nn.KLDivLoss(reduction='batchmean', log_target=True)(
        F.log_softmax(student_logits / T, dim=1),
        F.log_softmax(teacher_logits / T, dim=1)  # log_target=True expects the target in log-space
    ) * (T * T)
# 2. Hard Loss (Standard Cross-Entropy)
hard_loss = F.cross_entropy(student_logits, labels)
# 3. Feature Loss (Latent Layer Distillation)
# We use Mean Squared Error (MSE) loss
feature_loss = F.mse_loss(student_features, teacher_features)
# Combine the losses
# Logit/Hard loss combo is weighted by (1-alpha) and alpha
# The feature loss is then added, weighted by gamma
combined_loss = (alpha * hard_loss) + ((1 - alpha) * soft_loss) \
+ (gamma * feature_loss)
return combined_loss
5. Distillation Loss Function
This is the core logic of the experiment. This function calculates a single, combined loss value from three different components:
- soft_loss: This is the classic KD loss. It measures the difference (using KL divergence) between the student's and teacher's "softened" probability distributions. Dividing the logits by T (temperature) makes the distributions softer, and multiplying by (T * T) scales the gradients back to a reasonable magnitude.
- hard_loss: This is the standard cross-entropy loss, comparing the student's logits to the true labels. This ensures the student still learns to predict the correct answer.
- feature_loss: This is the new component. It uses mean squared error (mse_loss) to force the student's features_adapted to be as numerically close as possible to the teacher's features.
The function then returns a weighted sum of these three losses, using alpha to balance the hard/soft logit losses and gamma to control the strength of the feature-matching loss.
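Since all three terms reduce to scalars, the function can be sanity-checked in isolation with random tensors of the right shapes. This is a smoke test only, not meaningful training, and it assumes the distillation_loss function and hyperparameter constants defined in this post:
Python
import torch

B = 8
student_logits = torch.randn(B, 10, requires_grad=True)
teacher_logits = torch.randn(B, 10)
student_feats  = torch.randn(B, 64, 7, 7, requires_grad=True)
teacher_feats  = torch.randn(B, 64, 7, 7)
labels         = torch.randint(0, 10, (B,))

loss = distillation_loss(student_logits, labels, teacher_logits,
                         student_feats, teacher_feats,
                         T=TEMPERATURE, alpha=ALPHA, gamma=GAMMA)
print(loss.item())   # one scalar combining hard, soft, and feature terms
loss.backward()      # gradients flow only into the student-side tensors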
Python
def train_distillation(
student,
teacher,
train_loader,
optimizer,
epoch,
T,
alpha,
gamma
):
"""Training loop for full knowledge distillation (logits + features)."""
student.train()
teacher.eval() # Teacher is in eval mode
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# Get student outputs
student_logits, student_features = student(data)
# Get teacher outputs (with no_grad)
with torch.no_grad():
teacher_logits, teacher_features = teacher(data)
# Calculate the combined distillation loss
loss = distillation_loss(
student_logits, target, teacher_logits,
student_features, teacher_features,
T, alpha, gamma
)
loss.backward()
optimizer.step()
print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")
5. Distillation Training Loop
This is the training function that uses the distillation_loss.
- student.train(): Puts the student model in training mode.
- teacher.eval(): Puts the teacher model in evaluation mode, so layers such as batch norm and dropout behave deterministically. Note that eval() does not by itself freeze weights; the teacher is not updated because only the student's parameters were given to the optimizer and the teacher's forward pass runs under no_grad.
- In the loop:
  - Outputs are obtained from both the student and the teacher.
  - with torch.no_grad() wraps the teacher's forward pass so no gradients are computed for it, saving memory.
  - distillation_loss is called with all the required inputs (logits and features from both models, plus the labels and hyperparameters).
  - loss.backward() computes gradients only for the student model's parameters (including the feature adapter).
  - optimizer.step() updates only the student model's weights.
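An equivalent, slightly more defensive variation (my own, not used in this script) is to switch off gradients for the teacher once and for all, so it cannot be updated even if the no_grad block were later removed:
Python
# Optional: explicitly freeze the teacher's parameters. Not required here, because the
# optimizer only ever receives the student's parameters, but it makes the intent explicit.
for p in teacher_model.parameters():
    p.requires_grad_(False)
teacher_model.eval()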
Python
# --- 6. Main Execution ---
# --- Step A: Train the Teacher ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_TEACHER + 1):
# Note: We update train_standard to handle the new model output
train_standard(teacher_model, train_loader, optimizer_teacher, epoch)
print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)
6. Main Execution (Step A: Train Teacher)
This is the first step of the experiment.
- An instance of TeacherNet is created and moved to the GPU/CPU (device).
- An SGD optimizer is created to update the teacher's parameters.
- The train_standard function is called for EPOCHS_TEACHER (5) epochs, training the teacher normally using only the true labels.
- Finally, the trained teacher is evaluated on the test set, and its accuracy is stored in teacher_acc.
Python
# --- Step B: Train Student with Knowledge Distillation ---
print("\n--- 2. Training Student with Full Distillation (Logits + Features) ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_STUDENT + 1):
train_distillation(
student=student_model_kd,
teacher=teacher_model,
train_loader=train_loader,
optimizer=optimizer_student_kd,
epoch=epoch,
T=TEMPERATURE,
alpha=ALPHA,
gamma=GAMMA
)
print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)
6. Main Execution (Step B: Train Distilled Student)
This is the second step, where the knowledge transfer happens.
- An instance of StudentNet is created (student_model_kd).
- An SGD optimizer is created for the student's parameters.
- The special train_distillation function is called for EPOCHS_STUDENT (10) epochs, passing in both the student (to be trained) and the already-trained teacher_model (to provide guidance).
- The resulting "distilled" student is evaluated, and its accuracy is stored in student_kd_acc.
Python
# --- Step C: Train Student from Scratch (for comparison) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_STUDENT + 1):
train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)
print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)
6. Main Execution (Step C: Train Baseline Student)
This is the control experiment, crucial for a fair comparison.
- A new, separate instance of StudentNet is created (student_model_scratch).
- A new optimizer is created for its parameters.
- This student is trained with train_standard, the same way the teacher was trained. It never sees the teacher model; it learns only from the true labels.
- This "scratch" student is evaluated, and its accuracy is stored in student_scratch_acc.
Python
# --- 4. Final Comparison ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)
6. Main Execution (Final Comparison)
This final block prints a summary of the results. It uses the count_parameters function to show the model sizes and prints the final accuracies.
The expected outcome is:
- Teacher: High parameter count, high accuracy.
- Student (Scratch): Low parameter count, lower accuracy.
- Student (Distilled): Low parameter count (same as scratch), but higher accuracy than the scratch student, demonstrating that it successfully learned “dark knowledge” from the teacher.
Full code
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# --- 1. Settings and Hyperparameters ---
# Training settings
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9
# Knowledge Distillation (KD) settings
TEMPERATURE = 10 # Temperature for softening probabilities
ALPHA = 0.1 # Weight for the "hard" (true label) loss. (1-ALPHA) is for the "soft" (logit) loss.
GAMMA = 1.0 # Weight for the "feature" (latent) loss.
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# --- 2. Data Loading (MNIST) ---
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# --- 3. Model Definitions (Updated) ---
class TeacherNet(nn.Module):
"""Teacher model now returns final logits AND intermediate features."""
def __init__(self):
super(TeacherNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(64 * 7 * 7, 256)
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
# --- Extract intermediate features ---
# These are the features we will force the student to mimic
features = self.pool(F.relu(self.conv2(x))) # Shape: [B, 64, 7, 7]
x_flat = features.view(-1, 64 * 7 * 7) # Flatten
x = F.relu(self.fc1(x_flat))
logits = self.fc2(x) # Output raw logits
return logits, features
class StudentNet(nn.Module):
"""Student model now includes a 'feature adapter' and returns adapted features."""
def __init__(self):
super(StudentNet, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
# --- Feature Adapter ---
# This module's job is to transform the student's features
# (16 channels, 14x14) into the teacher's feature shape (64 channels, 7x7).
self.feature_adapter = nn.Sequential(
nn.Conv2d(16, 64, kernel_size=1), # 1x1 conv to match channels
nn.AdaptiveAvgPool2d((7, 7)) # Avg pool to match spatial dimensions
)
self.fc1 = nn.Linear(16 * 14 * 14, 32) # FC layer still uses the original 14x14 features
self.fc2 = nn.Linear(32, 10)
def forward(self, x):
# --- Student's base features ---
x_conv = self.pool(F.relu(self.conv1(x))) # Shape: [B, 16, 14, 14]
# --- Adapt features for comparison ---
features_adapted = self.feature_adapter(x_conv) # Shape: [B, 64, 7, 7]
# --- Continue to classification ---
x_flat = x_conv.view(-1, 16 * 14 * 14) # Flatten original features
x = F.relu(self.fc1(x_flat))
logits = self.fc2(x) # Output raw logits
return logits, features_adapted
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
# --- 4. Standard Training and Evaluation (Updated) ---
def train_standard(model, train_loader, optimizer, epoch):
"""Standard training loop (now ignores the second model output)."""
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# Models now return (logits, features). We only need logits here.
output, _ = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")
def evaluate(model, test_loader):
"""Standard evaluation loop (now ignores the second model output)."""
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
# Models now return (logits, features). We only need logits here.
output, _ = model(data)
test_loss += F.cross_entropy(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
accuracy = 100. * correct / len(test_loader.dataset)
print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
return accuracy
# --- 5. Knowledge Distillation Loss and Training (Updated) ---
def distillation_loss(
student_logits,
labels,
teacher_logits,
student_features,
teacher_features,
T,
alpha,
gamma
):
"""
Calculates the *combined* knowledge distillation loss.
Now includes:
1. Hard Loss (Cross-Entropy)
2. Soft Loss (KL Divergence on logits)
3. Feature Loss (MSE on feature maps)
"""
# 1. Soft Loss (Logit Distillation)
    soft_loss = nn.KLDivLoss(reduction='batchmean', log_target=True)(
        F.log_softmax(student_logits / T, dim=1),
        F.log_softmax(teacher_logits / T, dim=1)  # log_target=True expects the target in log-space
    ) * (T * T)
# 2. Hard Loss (Standard Cross-Entropy)
hard_loss = F.cross_entropy(student_logits, labels)
# 3. Feature Loss (Latent Layer Distillation)
# We use Mean Squared Error (MSE) loss
feature_loss = F.mse_loss(student_features, teacher_features)
# Combine the losses
# Logit/Hard loss combo is weighted by (1-alpha) and alpha
# The feature loss is then added, weighted by gamma
combined_loss = (alpha * hard_loss) + ((1 - alpha) * soft_loss) \
+ (gamma * feature_loss)
return combined_loss
def train_distillation(
student,
teacher,
train_loader,
optimizer,
epoch,
T,
alpha,
gamma
):
"""Training loop for full knowledge distillation (logits + features)."""
student.train()
teacher.eval() # Teacher is in eval mode
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
# Get student outputs
student_logits, student_features = student(data)
# Get teacher outputs (with no_grad)
with torch.no_grad():
teacher_logits, teacher_features = teacher(data)
# Calculate the combined distillation loss
loss = distillation_loss(
student_logits, target, teacher_logits,
student_features, teacher_features,
T, alpha, gamma
)
loss.backward()
optimizer.step()
print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")
# --- 6. Main Execution ---
# --- Step A: Train the Teacher ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_TEACHER + 1):
# Note: We update train_standard to handle the new model output
train_standard(teacher_model, train_loader, optimizer_teacher, epoch)
print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)
# --- Step B: Train Student with Knowledge Distillation ---
print("\n--- 2. Training Student with Full Distillation (Logits + Features) ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_STUDENT + 1):
train_distillation(
student=student_model_kd,
teacher=teacher_model,
train_loader=train_loader,
optimizer=optimizer_student_kd,
epoch=epoch,
T=TEMPERATURE,
alpha=ALPHA,
gamma=GAMMA
)
print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)
# --- Step C: Train Student from Scratch (for comparison) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
for epoch in range(1, EPOCHS_STUDENT + 1):
train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)
print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)
# --- 4. Final Comparison ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)