
Example of Knowledge Distillation (KD) using PyTorch on the MNIST dataset

This example demonstrates Knowledge Distillation, a technique where a small “student” model is trained to mimic a larger, pre-trained “teacher” model.

Let’s have a brief introduction to Knowledge Distillation first.

🎓 What is Knowledge Distillation?

Knowledge Distillation (KD) is a machine learning technique used to transfer knowledge from a large, high-performing model (the “teacher”) to a smaller, more efficient model (the “student”).

The Goal: The main objective is to create a small, fast model (for example, one that runs on a smartphone or edge device) that achieves accuracy close to that of the much larger, slower teacher model.

How it Works:

Instead of training the student only on the correct answers (known as “hard labels”), the student is also trained to match the teacher’s full output distribution (known as “soft labels”: the probabilities obtained by applying a softmax, usually with a raised temperature, to the teacher’s output logits).

  • Hard Label: The answer is “9”.
  • Soft Label (from Teacher): “I’m 90% sure this is a 9, 8% sure it’s a 7, and 2% sure it’s a 4.”

This “soft” information, often called “dark knowledge,” is much richer. It tells the student that a 9 looks a bit like a 7 or a 4, which often leads to a better student model than one trained from scratch on hard labels alone.
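To make the difference concrete, here is a small framework-free sketch (plain Python, not PyTorch) that scores two hypothetical students against the hard label and against the teacher’s soft label from the example above. Both students are equally confident that the digit is a 9, so the hard-label loss cannot tell them apart; only the soft label rewards the student whose second guess (a 7) matches the teacher’s. The student probabilities and the helper names are made up for illustration.

```python
import math

NUM_CLASSES = 10

def distribution(mass_by_class):
    """Build a 10-class probability vector from {class: probability} (illustrative helper)."""
    dist = [0.0] * NUM_CLASSES
    for cls, prob in mass_by_class.items():
        dist[cls] = prob
    return dist

def cross_entropy(target, predicted):
    """H(target, predicted) = -sum_k target_k * log(predicted_k), skipping zero-mass target classes."""
    return -sum(t * math.log(p) for t, p in zip(target, predicted) if t > 0)

hard_label = distribution({9: 1.0})                     # "The answer is 9"
soft_label = distribution({9: 0.90, 7: 0.08, 4: 0.02})  # the teacher's belief from the example

# Two hypothetical students, both 90% sure the digit is a 9
students = {
    "A (second guess: 7)": distribution({9: 0.90, 7: 0.08, 4: 0.02}),
    "B (second guess: 1)": distribution({9: 0.90, 1: 0.07, 7: 0.01, 4: 0.02}),
}

losses = {}
for name, probs in students.items():
    losses[name] = (cross_entropy(hard_label, probs), cross_entropy(soft_label, probs))
    print(f"{name}: hard-label loss {losses[name][0]:.4f}, soft-label loss {losses[name][1]:.4f}")
```

Both students get the same hard-label loss (about 0.105), but student A gets a lower soft-label loss (about 0.375 vs. about 0.541) because its mistakes resemble the teacher’s.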

In short: You use a big model to “teach” a small model, transferring its “intelligence” into a much more efficient package.

In the following code, we will

  1. Define a larger Teacher model and a smaller Student model.
  2. Define a utility function to count model parameters.
  3. Train the Teacher model on MNIST and evaluate its high accuracy.
  4. Define the distillation loss function.
  5. Train the Student model using knowledge distillation, learning from both the teacher and the true labels.
  6. Train an identical Student model from scratch (without distillation) for comparison.
  7. Print the final accuracies and parameter counts to show the benefit.

Python

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

These are the standard imports for a PyTorch project:

  • torch: The core PyTorch library.
  • torch.nn: Contains the building blocks for neural networks: the nn.Module base class, layers such as nn.Conv2d and nn.Linear, and loss modules such as nn.KLDivLoss.
  • torch.nn.functional (as F): Provides functions that don’t have learnable parameters, such as activation functions (F.relu) and loss functions (F.cross_entropy).
  • torch.optim: Includes optimization algorithms like optim.SGD (Stochastic Gradient Descent).
  • torchvision: A library for computer vision tasks. We use datasets to load MNIST and transforms to preprocess the images.
  • torch.utils.data.DataLoader: A utility that helps load data in batches, shuffle it, and load it in parallel.

Python

# --- 1. Settings and Hyperparameters ---

# Training settings
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# Knowledge Distillation (KD) settings
TEMPERATURE = 10  # Temperature for softening probabilities
ALPHA = 0.1       # Weight for the "hard" (true label) loss

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

This section defines all the configuration variables for the experiment.

  • Training Settings: These control the basic training process. BATCH_SIZE is the number of images processed in one step. EPOCHS_... defines how many times the model sees the entire dataset. LEARNING_RATE and MOMENTUM are parameters for the SGD optimizer.
  • KD Settings: These are specific to Knowledge Distillation.
    • TEMPERATURE (T): The logits are divided by T before the softmax, and a high temperature “softens” the resulting probabilities. For example, logits that give [0.05, 0.90, 0.05] at T = 1 give roughly [0.30, 0.40, 0.30] at T = 10. This exposes how the teacher model “thinks” about the wrong classes (e.g., “it’s pretty sure it’s a 7, but it also looks a bit like a 9”), which is the “dark knowledge” the student learns from.
    • ALPHA: This variable balances the two loss functions for the student. The student is trained with a combined loss: (ALPHA * hard_loss) + ((1 - ALPHA) * soft_loss). With ALPHA = 0.1, the student pays 10% attention to the true labels (hard loss) and 90% attention to matching the teacher’s softened outputs (soft loss).
  • Device: This code checks if a GPU (cuda) is available. If so, it sets the device to “cuda” to accelerate training; otherwise, it uses the “cpu”.
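The softening effect of TEMPERATURE can be checked with a few lines of plain Python (no PyTorch needed). This is a minimal sketch of a temperature-scaled softmax, softmax(z / T); the logits are hypothetical values chosen so that T = 1 gives exactly [0.05, 0.90, 0.05].

```python
import math

def softmax_with_temperature(logits, T=1.0):
    """Softmax of logits / T; subtracting the max keeps exp() numerically stable."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical logits: the middle class is 18x more likely than the others at T = 1
logits = [0.0, math.log(18.0), 0.0]

for T in (1, 2, 5, 10):
    probs = softmax_with_temperature(logits, T)
    print(f"T={T:>2}: " + ", ".join(f"{p:.3f}" for p in probs))
```

As T grows, the largest probability falls (0.90 at T = 1, about 0.40 at T = 10) while the ranking of the classes stays the same; that relative information is what the soft loss passes to the student.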

Python

# --- 2. Data Loading (MNIST) ---

# Standard MNIST transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean and std
])

# Load datasets
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

This chunk handles loading and preparing the MNIST dataset of handwritten digits.

  • transform: This defines a preprocessing pipeline.
    • transforms.ToTensor(): Converts the input images (which are PIL images) into PyTorch tensors and scales their pixel values from the [0, 255] range to the [0.0, 1.0] range.
    • transforms.Normalize((0.1307,), (0.3081,)): This normalizes the tensor’s values. It subtracts the mean (0.1307) and divides by the standard deviation (0.3081) of the MNIST dataset. This helps the model train more stably and faster.
  • datasets.MNIST: This downloads the training (train=True) and testing (train=False) sets to the ./data folder if they don’t already exist. It applies the transform to every image as it’s loaded.
  • DataLoader: This wraps the datasets and turns them into iterables that yield data in batches (batch_size=BATCH_SIZE). shuffle=True for the training loader is important: it reshuffles the data every epoch, so each mini-batch is a different random sample and the model cannot exploit the order of the training data. The test loader does not need shuffling.
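The numbers behind this pipeline can be checked by hand. The sketch below (plain Python; the helper names are illustrative, not part of torchvision) reproduces what ToTensor() followed by Normalize((0.1307,), (0.3081,)) does to a single 8-bit grayscale pixel, and how many batches the loaders yield with BATCH_SIZE = 64 on MNIST’s 60,000 training and 10,000 test images, assuming DataLoader’s default drop_last=False.

```python
import math

MNIST_MEAN, MNIST_STD = 0.1307, 0.3081
TRAIN_SIZE, TEST_SIZE = 60_000, 10_000  # standard MNIST split sizes
BATCH_SIZE = 64

def to_tensor_then_normalize(pixel):
    """Mimic ToTensor() (uint8 0..255 -> float 0..1) followed by Normalize(mean, std)."""
    x = pixel / 255.0
    return (x - MNIST_MEAN) / MNIST_STD

def batch_layout(num_samples, batch_size):
    """Batches yielded with drop_last=False, and the size of the final (possibly smaller) batch."""
    num_batches = math.ceil(num_samples / batch_size)
    last_batch = num_samples - (num_batches - 1) * batch_size
    return num_batches, last_batch

print(f"black pixel -> {to_tensor_then_normalize(0):.4f}, white pixel -> {to_tensor_then_normalize(255):.4f}")
print("train batches (count, last size):", batch_layout(TRAIN_SIZE, BATCH_SIZE))
print("test batches  (count, last size):", batch_layout(TEST_SIZE, BATCH_SIZE))
```

So normalized pixels lie roughly in [-0.42, 2.82], and each training epoch consists of 938 optimizer steps, the last of which uses a batch of only 32 images.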

Python

# --- 3. Model Definitions ---

class TeacherNet(nn.Module):
    """A larger CNN for MNIST (The 'Teacher')"""
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Calculate the flattened size after conv and pool layers
        # Input 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # -> Conv2 -> 14x14 -> Pool2 -> 7x7
        # Flattened size = 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7) # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Output raw logits
        return x

This defines the Teacher model. It’s a relatively large and complex Convolutional Neural Network (CNN).

  • __init__: The constructor defines the layers.
    • nn.Conv2d: Two convolutional layers. The first takes 1 channel (grayscale image) and outputs 32 channels; the second takes 32 and outputs 64. With kernel_size=3, padding=1 keeps the spatial size unchanged (28×28 for the first layer, 14×14 for the second).
    • nn.MaxPool2d: A pooling layer that halves the image dimensions (28×28 -> 14×14, then 14×14 -> 7×7).
    • nn.Linear: Two fully-connected (dense) layers. The first takes the flattened 7×7 image with 64 channels (64 * 7 * 7) and maps it to 256 features. The second maps the 256 features to 10 output values (called logits), one for each digit (0-9).
  • forward: This method defines how data flows through the layers. It’s a sequence of (Conv -> ReLU -> Pool) twice, then a view operation to flatten the 3D feature map into a 1D vector, followed by the two linear layers.
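The 64 * 7 * 7 in fc1 follows from the standard output-size formula for convolution and pooling layers, floor((n + 2 * padding - kernel_size) / stride) + 1 (with dilation 1). A small plain-Python sketch that walks through the teacher’s layers:

```python
def conv2d_out(size, kernel_size, stride=1, padding=0):
    """Spatial output size of Conv2d / MaxPool2d with dilation=1: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1

size = 28                                                    # MNIST images are 28x28
size = conv2d_out(size, kernel_size=3, stride=1, padding=1)  # conv1: 28 -> 28
size = conv2d_out(size, kernel_size=2, stride=2)             # pool:  28 -> 14
size = conv2d_out(size, kernel_size=3, stride=1, padding=1)  # conv2: 14 -> 14
size = conv2d_out(size, kernel_size=2, stride=2)             # pool:  14 -> 7
flattened = 64 * size * size                                 # 64 channels after conv2
print(size, flattened)
```

If the architecture is changed (e.g., a different kernel size or no padding), recomputing the flattened size this way avoids a shape-mismatch error in fc1.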

Python

class StudentNet(nn.Module):
    """A much smaller CNN for MNIST (The 'Student')"""
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Input 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # Flattened size = 16 * 14 * 14
        self.fc1 = nn.Linear(16 * 14 * 14, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14) # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Output raw logits
        return x

def count_parameters(model):
    """Utility function to count model parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

This section defines the Student model, which is the model we ultimately want to use (e.g., on a mobile device) because it’s small and fast.

  • StudentNet: This model has a similar structure to the teacher but is much smaller and simpler. It has only one convolutional layer (with 16 channels, vs. 32 and 64 for the teacher) and a smaller hidden layer (32 features, vs. 256). As a result it has roughly 8× fewer parameters (about 101K vs. about 824K) and should be considerably faster at inference time.
  • count_parameters: This is a helper function that simply counts the total number of learnable parameters (weights and biases) in a model. This is used at the end to prove that the student is, in fact, much smaller than the teacher.
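As a sanity check on what count_parameters should report, the counts can be worked out by hand: a Conv2d layer has out_channels × in_channels × k × k weights plus out_channels biases, and a Linear layer has in_features × out_features weights plus out_features biases (both use bias=True by default). A plain-Python sketch of that arithmetic, with illustrative helper names:

```python
def conv2d_params(in_ch, out_ch, k):
    """Weights plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch

def linear_params(in_features, out_features):
    """Weight matrix plus bias vector."""
    return in_features * out_features + out_features

teacher = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(64 * 7 * 7, 256),
    "fc2": linear_params(256, 10),
}
student = {
    "conv1": conv2d_params(1, 16, 3),
    "fc1": linear_params(16 * 14 * 14, 32),
    "fc2": linear_params(32, 10),
}

teacher_total = sum(teacher.values())
student_total = sum(student.values())
print(teacher, f"total={teacher_total:,}")
print(student, f"total={student_total:,}")
print(f"teacher/student ratio: {teacher_total / student_total:.1f}x")
```

This gives 824,458 parameters for the teacher and 100,874 for the student. Both networks flatten to 3,136 features (64·7·7 = 16·14·14), so most of the saving comes from the much narrower hidden layer (32 units instead of 256) in fc1, which holds the bulk of both models’ parameters.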

Python

# --- 4. Standard Training and Evaluation Functions ---

def train_standard(model, train_loader, optimizer, epoch):
    """Standard training loop for a classifier."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        
        # Standard Cross-Entropy Loss
        loss = F.cross_entropy(output, target)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")

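This is a standard training function. It will be used to train the Teacher model and the “baseline” Student model (the one trained without distillation). The loss printed at the end of each epoch is that of the last mini-batch only, not an average over the epoch.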

  1. model.train(): Puts the model in “training mode” (which activates layers like Dropout, if any).
  2. data, target = data.to(device), target.to(device): Moves the batch of images and labels to the GPU/CPU.
  3. optimizer.zero_grad(): Clears any old gradients from the previous step.
  4. output = model(data): Performs a forward pass to get the model’s predictions (logits).
  5. loss = F.cross_entropy(output, target): Calculates the cross-entropy loss. This loss compares the model’s logits to the true labels (e.g., “this image is a 7”).
  6. loss.backward(): Performs backpropagation, calculating the gradient of the loss with respect to every model parameter.
  7. optimizer.step(): Updates the model’s parameters (weights) using the gradients and the optimizer’s algorithm (SGD).
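For reference, F.cross_entropy applies log-softmax to the logits internally, so for a single sample the loss is log(sum_j exp(z_j)) - z_target, and with the default reduction='mean' it is averaged over the batch. A minimal plain-Python sketch of that computation, with made-up logits and illustrative function names:

```python
import math

def cross_entropy_from_logits(logits, target):
    """Per-sample cross-entropy: log-sum-exp(logits) - logits[target]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def mean_cross_entropy(batch_logits, targets):
    """reduction='mean': average the per-sample losses over the batch."""
    losses = [cross_entropy_from_logits(z, t) for z, t in zip(batch_logits, targets)]
    return sum(losses) / len(losses)

batch_logits = [
    [2.0, 1.0, 0.1],  # sample 1: the correct class (0) has the largest logit
    [0.5, 2.5, 0.3],  # sample 2: confidently wrong, the correct class is 2
]
targets = [0, 2]

batch_loss = mean_cross_entropy(batch_logits, targets)
print(f"batch loss: {batch_loss:.4f}")
```

The first sample costs about 0.42 and the confidently wrong second sample about 2.42, so the batch loss is about 1.42.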

Python

def evaluate(model, test_loader):
    """Standard evaluation loop."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

This function evaluates the model’s performance on the test dataset.

  1. model.eval(): Puts the model in “evaluation mode” (disabling Dropout, etc.).
  2. with torch.no_grad(): This is a critical optimization. It tells PyTorch not to track gradients, which saves memory and speeds up computation since we are only testing, not training.
  3. Inside the loop, it calculates the total loss and counts the number of correct predictions.
  4. pred = output.argmax(dim=1, ...): This finds the index of the highest logit value for each image. This index is the model’s final prediction (e.g., 0, 1, 2… 9).
  5. correct += ...: It compares the model’s predictions (pred) to the true labels (target) and sums up the correct ones.
  6. Finally, it prints the average loss and the overall accuracy percentage.
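The two bookkeeping steps in evaluate() can be illustrated without PyTorch. The sketch below uses hypothetical logits for four images and three classes to compute accuracy from argmax predictions, then shows why the function sums per-sample losses (reduction='sum') and divides by the dataset size instead of averaging per-batch means: when the last batch is smaller, a mean of batch means over-weights it.

```python
def argmax(values):
    """Index of the largest value, like output.argmax(dim=1) for one row."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy_percent(all_logits, targets):
    correct = sum(argmax(z) == t for z, t in zip(all_logits, targets))
    return 100.0 * correct / len(targets)

# Hypothetical logits for 4 test images (3 classes to keep it short)
logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.3, 0.2, 0.9], [1.1, 1.4, 0.0]]
targets = [0, 1, 2, 0]  # the last image is misclassified as class 1
acc = accuracy_percent(logits, targets)
print(f"accuracy: {acc:.2f}%")

# Mean of per-batch means vs. the true per-sample mean
batch_sizes = [4, 2]            # a full batch and a smaller final batch
batch_mean_losses = [1.0, 4.0]  # hypothetical per-batch mean losses
mean_of_means = sum(batch_mean_losses) / len(batch_mean_losses)
true_mean = sum(n * l for n, l in zip(batch_sizes, batch_mean_losses)) / sum(batch_sizes)
print(f"mean of batch means: {mean_of_means}, true per-sample mean: {true_mean}")
```

Three of the four predictions are correct (75%), and the per-sample mean (2.0) differs from the naive mean of batch means (2.5).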

Python

# --- 5. Knowledge Distillation Loss and Training ---

def distillation_loss(student_logits, labels, teacher_logits, T, alpha):
    """
    Calculates the knowledge distillation loss.
    ...
    """
    
    # 1. Distillation Loss (Soft Loss)
    # Use KLDivLoss, which expects log-probabilities (log_softmax) as input
    # and probabilities (softmax) as target.
    
    soft_loss = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1)
    ) * (T * T) # Scale the loss by T^2 as proposed in the original paper

    # 2. Student Loss (Hard Loss)
    # Standard cross-entropy loss between student logits and true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # 3. Combine the losses
    combined_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    return combined_loss

This is the heart of the Knowledge Distillation logic. This custom loss function combines two separate losses:

  1. Soft Loss: This is the “distillation” part.
    • student_logits / T and teacher_logits / T: The outputs of both models are divided by the TEMPERATURE (T=10) to “soften” them.
    • F.log_softmax and F.softmax: We create probability distributions from the softened logits.
    • nn.KLDivLoss: The Kullback-Leibler Divergence Loss. This measures how different the student’s probability distribution is from the teacher’s. The goal is to minimize this difference, forcing the student to mimic the teacher’s “thought process.”
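    • * (T * T): This scaling factor comes from the original distillation paper (Hinton et al., 2015). Dividing the logits by T shrinks the gradients of the soft loss by roughly a factor of 1/T², so multiplying by T² keeps the soft loss’s contribution comparable to the hard loss when the temperature is changed.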
  2. Hard Loss: This is the F.cross_entropy loss we used before. It compares the student’s logits to the actual, true labels.
  3. Combined Loss: The final loss is a weighted sum controlled by ALPHA. With ALPHA = 0.1, the hard loss (true labels) gets weight 0.1 and the T²-scaled soft loss (mimicking the teacher) gets weight 0.9.
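Here is a minimal one-sample sketch of the same computation in plain Python, meant for checking the formula rather than for training. It computes alpha * CE + (1 - alpha) * T² * KL(teacher_T || student_T), which is what distillation_loss does per sample (reduction='batchmean' then averages over the batch). It also computes the gradient of the T²-scaled soft loss with respect to the student logits, which works out to T * (q_T - p_T), where q_T and p_T are the student’s and teacher’s softened probabilities. The function names and logits are illustrative assumptions.

```python
import math

def log_softmax(logits, T=1.0):
    """log(softmax(logits / T)), computed stably with the log-sum-exp trick."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - lse for z in scaled]

def kd_loss_single(student_logits, label, teacher_logits, T, alpha):
    """One-sample analogue of distillation_loss (illustrative, not the article's code)."""
    log_q = log_softmax(student_logits, T)  # student's softened log-probabilities
    log_p = log_softmax(teacher_logits, T)  # teacher's softened log-probabilities
    kl = sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))  # KL(p || q)
    soft_loss = kl * T * T                  # same T^2 scaling as distillation_loss
    hard_loss = -log_softmax(student_logits)[label]  # ordinary cross-entropy at T = 1
    return alpha * hard_loss + (1 - alpha) * soft_loss

def scaled_soft_grad(student_logits, teacher_logits, T):
    """Gradient of T^2 * KL(p_T || q_T) w.r.t. the student logits: T * (q_T - p_T)."""
    q = [math.exp(v) for v in log_softmax(student_logits, T)]
    p = [math.exp(v) for v in log_softmax(teacher_logits, T)]
    return [T * (qi - pi) for qi, pi in zip(q, p)]

# Made-up logits for a 3-class problem
teacher_logits_demo = [2.0, 0.0, -2.0]
student_logits_demo = [1.0, 0.0, -1.0]

print("student == teacher:", kd_loss_single(teacher_logits_demo, 0, teacher_logits_demo, T=10, alpha=0.1))
print("student != teacher:", kd_loss_single(student_logits_demo, 0, teacher_logits_demo, T=10, alpha=0.1))
for T in (1, 5, 10, 50):
    grad = scaled_soft_grad(student_logits_demo, teacher_logits_demo, T)
    print(f"T={T:>2}: scaled soft-loss gradient {[round(g, 4) for g in grad]}")
```

When the student’s logits equal the teacher’s, the soft term is exactly zero and only alpha * CE remains. The scaled gradient for the first logit is about -0.20 at T = 1 and settles near -1/3 for large T (about -0.35 at T = 10, -0.34 at T = 50). Without the T² factor it would instead shrink like 1/T².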

Python

def train_distillation(student, teacher, train_loader, optimizer, epoch, T, alpha):
    """Training loop for knowledge distillation."""
    student.train()
    teacher.eval() # Inference behaviour; weights stay fixed via no_grad below and a student-only optimizer
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Get student's output
        student_logits = student(data)
        
        # Get teacher's output (with no_grad to freeze it)
        with torch.no_grad():
            teacher_logits = teacher(data)
        
        # Calculate the distillation loss
        loss = distillation_loss(student_logits, target, teacher_logits, T, alpha)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")

This is the special training loop for the distilled student.

  • student.train(): The student model is in training mode.
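  • teacher.eval(): Puts the teacher in evaluation mode, so layers such as Dropout or BatchNorm (none in this model) behave as they do at inference time. eval() by itself does not freeze any weights; the teacher stays fixed because its forward pass runs under no_grad and its parameters are not in the student’s optimizer.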
  • with torch.no_grad(): This context is wrapped only around the teacher’s forward pass (teacher(data)). This tells PyTorch not to calculate gradients for the teacher, saving computation and ensuring its weights remain frozen.
  • loss = distillation_loss(...): It calls the custom loss function defined above, using both models’ outputs.
  • loss.backward(): This calculates gradients, but only for the student model (because the teacher was in no_grad mode).
  • optimizer.step(): This updates only the student’s weights, because the optimizer passed in was created from the student’s parameters alone.

Python

# --- 6. Main Execution ---

# --- Step A: Train the Teacher ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_TEACHER + 1):
    train_standard(teacher_model, train_loader, optimizer_teacher, epoch)

print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)

This is Step A of the main script.

  1. An instance of the large TeacherNet is created and moved to the device.
  2. An optimizer is created for the teacher’s parameters.
  3. It trains the teacher model for EPOCHS_TEACHER (5 epochs) using the train_standard function (i.e., using only the true labels and standard cross-entropy loss).
  4. It evaluates the fully-trained teacher and stores its final accuracy in teacher_acc.
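The optimizer here is optim.SGD with LEARNING_RATE = 0.01 and MOMENTUM = 0.9. As a rough illustration of what momentum does, the sketch below implements the momentum update that torch.optim.SGD documents (with its defaults dampening=0 and nesterov=False) on a hypothetical one-dimensional loss (w - 3)², in plain Python. Real training uses noisy mini-batch gradients over hundreds of thousands of parameters, so this only shows the mechanism.

```python
def sgd_momentum(grad_fn, w0, lr=0.01, momentum=0.9, steps=200):
    """Heavy-ball update: buf = momentum * buf + grad; w = w - lr * buf.
    The buffer starts as the first gradient; momentum=0 reduces to plain SGD."""
    w, buf = w0, None
    for _ in range(steps):
        g = grad_fn(w)
        buf = g if buf is None else momentum * buf + g
        w = w - lr * buf
    return w

def toy_grad(w):
    """Gradient of the hypothetical toy loss (w - 3)^2, whose minimum is at w = 3."""
    return 2.0 * (w - 3.0)

w_plain = sgd_momentum(toy_grad, w0=0.0, momentum=0.0)
w_momentum = sgd_momentum(toy_grad, w0=0.0, momentum=0.9)
print(f"after 200 steps: plain SGD w={w_plain:.4f}, momentum 0.9 w={w_momentum:.4f}")
```

With the same learning rate and number of steps, the momentum version ends much closer to the minimum at w = 3, because the buffer accumulates gradients that keep pointing in the same direction.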

Python

# --- Step B: Train Student with Knowledge Distillation ---
print("\n--- 2. Training Student with Distillation ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_distillation(
        student=student_model_kd,
        teacher=teacher_model,
        train_loader=train_loader,
        optimizer=optimizer_student_kd,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )

print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)

This is Step B, the core of the experiment.

  1. A new instance of the small StudentNet is created (student_model_kd).
  2. An optimizer is created for this student’s parameters.
  3. It trains the student for EPOCHS_STUDENT (10 epochs) using the special train_distillation function.
  4. The trained teacher model is passed in to act as the guide; its weights are not changed during this step.
  5. It evaluates the “distilled” student and stores its accuracy in student_kd_acc.

Python

# --- Step C: Train Student from Scratch (for comparison) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)

print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)

This is Step C, the “control” or “baseline” for comparison.

  1. A third model, another new instance of StudentNet, is created (student_model_scratch).
  2. An optimizer is created for its parameters.
  3. It trains this student for the same number of epochs (EPOCHS_STUDENT) as the distilled student.
  4. However, it uses the train_standard function, meaning it only learns from the true labels, with no help from the teacher.
  5. It evaluates this “from scratch” student and stores its accuracy in student_scratch_acc.

Python

# --- 4. Final Comparison ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)

This final chunk prints the results of the whole experiment.

  • It uses count_parameters to show that the Teacher model is much larger (more parameters) than the two Student models (which are identical in size).
  • It prints the final accuracy for all three models.
  • The expected result is that student_kd_acc (Distilled) will be higher than student_scratch_acc (Scratch), although this is not guaranteed on every run: on an easy dataset like MNIST the gap can be small, and because no random seed is fixed, results vary between runs. When the distilled student does come out ahead, it suggests that learning from the teacher’s softened probabilities, in addition to the hard labels, transferred useful knowledge from the large model to the small one.
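One way to judge whether a gap between the two students is meaningful is to compare it with the sampling noise of the test set. The sketch below is a rough estimate that treats each of MNIST’s 10,000 test images as an independent trial (binomial standard error); the 98.0% accuracy is a hypothetical input, not a result from this script. It covers only test-set noise, not run-to-run variation from random initialization and data shuffling, so repeating each run with several seeds gives a more reliable comparison.

```python
import math

def accuracy_standard_error_pp(accuracy_pct, n_test):
    """Binomial standard error of a test accuracy, in percentage points."""
    p = accuracy_pct / 100.0
    return 100.0 * math.sqrt(p * (1.0 - p) / n_test)

def approx_95_interval(accuracy_pct, n_test):
    """Normal-approximation 95% interval: accuracy +/- 1.96 standard errors."""
    se = accuracy_standard_error_pp(accuracy_pct, n_test)
    return accuracy_pct - 1.96 * se, accuracy_pct + 1.96 * se

# Hypothetical accuracy, not a result from this script
se = accuracy_standard_error_pp(98.0, 10_000)
low, high = approx_95_interval(98.0, 10_000)
print(f"SE = {se:.2f} pp, approx. 95% interval = [{low:.2f}, {high:.2f}]")
```

At 98% accuracy the standard error is about 0.14 percentage points, so differences of only a few tenths of a percent between the students should be read with caution.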

Full code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- 1. Settings and Hyperparameters ---

# Training settings
BATCH_SIZE = 64
EPOCHS_TEACHER = 5
EPOCHS_STUDENT = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9

# Knowledge Distillation (KD) settings
TEMPERATURE = 10  # Temperature for softening probabilities
ALPHA = 0.1       # Weight for the "hard" (true label) loss

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# --- 2. Data Loading (MNIST) ---

# Standard MNIST transforms
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean and std
])

# Load datasets
train_dataset = datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

# Create data loaders
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

# --- 3. Model Definitions ---

class TeacherNet(nn.Module):
    """A larger CNN for MNIST (The 'Teacher')"""
    def __init__(self):
        super(TeacherNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Calculate the flattened size after conv and pool layers
        # Input 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # -> Conv2 -> 14x14 -> Pool2 -> 7x7
        # Flattened size = 64 * 7 * 7
        self.fc1 = nn.Linear(64 * 7 * 7, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7) # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Output raw logits
        return x

class StudentNet(nn.Module):
    """A much smaller CNN for MNIST (The 'Student')"""
    def __init__(self):
        super(StudentNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
        # Input 28x28 -> Conv1 -> 28x28 -> Pool1 -> 14x14
        # Flattened size = 16 * 14 * 14
        self.fc1 = nn.Linear(16 * 14 * 14, 32)
        self.fc2 = nn.Linear(32, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 16 * 14 * 14) # Flatten
        x = F.relu(self.fc1(x))
        x = self.fc2(x) # Output raw logits
        return x

def count_parameters(model):
    """Utility function to count model parameters"""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

# --- 4. Standard Training and Evaluation Functions ---

def train_standard(model, train_loader, optimizer, epoch):
    """Standard training loop for a classifier."""
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        output = model(data)
        
        # Standard Cross-Entropy Loss
        loss = F.cross_entropy(output, target)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tLoss: {loss.item():.6f}")

def evaluate(model, test_loader):
    """Standard evaluation loop."""
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'Test set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.2f}%)')
    return accuracy

# --- 5. Knowledge Distillation Loss and Training ---

def distillation_loss(student_logits, labels, teacher_logits, T, alpha):
    """
    Calculates the knowledge distillation loss.
    
    :param student_logits: Raw logits from the student model
    :param labels: True labels (for the hard loss)
    :param teacher_logits: Raw logits from the teacher model
    :param T: Temperature
    :param alpha: Weighting factor
    :return: The combined distillation loss
    """
    
    # 1. Distillation Loss (Soft Loss)
    # Use KLDivLoss, which expects log-probabilities (log_softmax) as input
    # and probabilities (softmax) as target.
    soft_loss = nn.KLDivLoss(reduction='batchmean')(  # target below is probabilities, so log_target must stay False
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1)
    ) * (T * T) # Scale the loss by T^2 as proposed in the original paper

    # 2. Student Loss (Hard Loss)
    # Standard cross-entropy loss between student logits and true labels
    hard_loss = F.cross_entropy(student_logits, labels)

    # 3. Combine the losses
    combined_loss = alpha * hard_loss + (1 - alpha) * soft_loss
    return combined_loss

def train_distillation(student, teacher, train_loader, optimizer, epoch, T, alpha):
    """Training loop for knowledge distillation."""
    student.train()
    teacher.eval() # Inference behaviour; weights stay fixed via no_grad below and a student-only optimizer
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        optimizer.zero_grad()
        
        # Get student's output
        student_logits = student(data)
        
        # Get teacher's output (with no_grad to freeze it)
        with torch.no_grad():
            teacher_logits = teacher(data)
        
        # Calculate the distillation loss
        loss = distillation_loss(student_logits, target, teacher_logits, T, alpha)
        
        loss.backward()
        optimizer.step()
        
    print(f"Train Epoch: {epoch} \tKD Loss: {loss.item():.6f}")

# --- 6. Main Execution ---

# --- Step A: Train the Teacher ---
print("--- 1. Training Teacher Model ---")
teacher_model = TeacherNet().to(device)
optimizer_teacher = optim.SGD(teacher_model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_TEACHER + 1):
    train_standard(teacher_model, train_loader, optimizer_teacher, epoch)

print("\n--- Evaluating Teacher Model ---")
teacher_acc = evaluate(teacher_model, test_loader)

# --- Step B: Train Student with Knowledge Distillation ---
print("\n--- 2. Training Student with Distillation ---")
student_model_kd = StudentNet().to(device)
optimizer_student_kd = optim.SGD(student_model_kd.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_distillation(
        student=student_model_kd,
        teacher=teacher_model,
        train_loader=train_loader,
        optimizer=optimizer_student_kd,
        epoch=epoch,
        T=TEMPERATURE,
        alpha=ALPHA
    )

print("\n--- Evaluating Distilled Student Model ---")
student_kd_acc = evaluate(student_model_kd, test_loader)

# --- Step C: Train Student from Scratch (for comparison) ---
print("\n--- 3. Training Student from Scratch (Baseline) ---")
student_model_scratch = StudentNet().to(device)
optimizer_student_scratch = optim.SGD(student_model_scratch.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)

for epoch in range(1, EPOCHS_STUDENT + 1):
    train_standard(student_model_scratch, train_loader, optimizer_student_scratch, epoch)

print("\n--- Evaluating 'Scratch' Student Model ---")
student_scratch_acc = evaluate(student_model_scratch, test_loader)

# --- 4. Final Comparison ---
print("\n" + "="*30)
print("--- Final Results ---")
print(f"Teacher Model:\t\tParams: {count_parameters(teacher_model):,}\tAccuracy: {teacher_acc:.2f}%")
print(f"Student (Distilled):\tParams: {count_parameters(student_model_kd):,}\tAccuracy: {student_kd_acc:.2f}%")
print(f"Student (Scratch):\tParams: {count_parameters(student_model_scratch):,}\tAccuracy: {student_scratch_acc:.2f}%")
print("="*30)
