Overfitting, Underfitting, Early Stopping, Restore Best Weights & Code in PyTorch

Overfitting, Underfitting

Overfitting occurs when a neural network learns the training data too well, including its noise and idiosyncratic details, and as a result generalizes poorly to new, unseen data. Overfitting does not mean that the weights themselves are "overweight." However, a model with a very large number of weights can fit the training data too closely, essentially memorizing it, which leads to poor generalization to new data. Underfitting is the opposite problem: the model has not learned enough from the data and performs poorly on both the training and validation sets.

Note, however, that neural networks are often overparameterized models, so they can easily overfit the data if they are not trained carefully. Early stopping is a common strategy to mitigate overfitting and promote better generalization.
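To make the two failure modes concrete, here is a small sketch that uses polynomial fitting instead of a neural network (an illustrative substitution, not part of the original example). The sine target, the noise level, the 10 training points, and the degrees 1, 3, and 9 are all made-up choices. A degree-9 polynomial has as many coefficients as there are training points, so it can pass through every noisy point; you will typically see near-zero training error for it but a larger error on unseen points, while the degree-1 line has high error on both (underfitting).

```python
import numpy as np
from numpy.polynomial import Polynomial

rng = np.random.default_rng(0)

# Hypothetical toy problem: 10 noisy training samples of sin(2*pi*x)
x_train = np.linspace(0, 1, 10)
y_train = np.sin(2 * np.pi * x_train) + rng.normal(scale=0.2, size=x_train.shape)

# "Unseen" test points from the same underlying (noise-free) function
x_test = np.linspace(0, 1, 200)
y_test = np.sin(2 * np.pi * x_test)


def mse(pred, target):
    """Mean squared error as a plain Python float."""
    return float(np.mean((pred - target) ** 2))


train_mse, test_mse = {}, {}
for degree in (1, 3, 9):
    # Least-squares polynomial fit; degree 9 has 10 coefficients for 10 points
    p = Polynomial.fit(x_train, y_train, deg=degree)
    train_mse[degree] = mse(p(x_train), y_train)
    test_mse[degree] = mse(p(x_test), y_test)
    print(f"degree {degree}: train MSE {train_mse[degree]:.4f}, test MSE {test_mse[degree]:.4f}")
```

Because the higher-degree models contain the lower-degree ones, training error can only go down as the degree grows; only the error on unseen data tells you when the model has started to memorize.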

Early Stopping

If you’re cooking pasta, you want it to be perfect: neither undercooked nor overcooked. Training a deep learning model is similar: you aim to achieve optimal performance without overfitting or underfitting.

If you stop cooking too early, the pasta is still hard and not pleasant to eat. This is similar to underfitting in deep learning, where if you stop training too soon, the model hasn’t learned enough from the data and performs poorly on both training and validation sets.

If you keep cooking the pasta for too long, it becomes mushy and loses its texture. In deep learning, this is akin to overfitting, where the model continues to train and starts to memorize the training data rather than generalize from it. This leads to poor performance on new, unseen data.

To get the pasta just right, you need to periodically taste it to check if it’s done. This is like monitoring the model’s performance on a validation set during training. You check periodically to see if the performance is improving.

When the pasta reaches the perfect texture, you stop cooking. In deep learning, early stopping is like this: you monitor the validation performance, and when it stops improving (or starts to get worse), you stop training. This helps ensure that the model has learned just enough to generalize well without overfitting.

Although the technique is called early ‘stopping,’ it does not mean stopping arbitrarily soon. Even though the loss on the training set may keep decreasing, we check the validation set (taste the pasta) and stop before the model is overcooked. So the term is really about proper timing, based on ‘tasting’ your validation set.

How Early Stopping Works

  1. Monitor Performance: Track the performance metric (e.g., validation loss or accuracy) on the validation dataset after each epoch.
  2. Patience: Define a patience parameter, which is the number of epochs to wait after the last improvement in the performance metric before stopping the training.
  3. Stop Training: If the performance metric does not improve for a specified number of epochs (patience), stop the training process.
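The three steps above can be sketched as a small reusable helper. The class name `EarlyStopping` and its `min_delta` parameter (the smallest decrease that counts as an improvement) are hypothetical additions for illustration; the example later in this post implements the same logic inline without them. The sequence of validation losses in the usage part is made up.

```python
class EarlyStopping:
    """Minimal patience-based early stopping for a metric to minimize (e.g. validation loss)."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience      # epochs to wait after the last improvement
        self.min_delta = min_delta    # hypothetical: minimum decrease that counts as improvement
        self.best = float("inf")      # best (lowest) value seen so far
        self.epochs_no_improve = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return self.epochs_no_improve >= self.patience


# Usage with a made-up sequence of per-epoch validation losses
stopper = EarlyStopping(patience=3)
losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.75, 0.69]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 6 0.7
```

With a patience of 3, training stops at epoch 6, so the slightly better loss at epoch 7 is never seen. This is the trade-off the patience parameter controls: a larger value tolerates longer plateaus at the cost of more training time.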

Restore Best Weights

Imagine you are writing a term paper for a class. Your goal is to submit the best possible version of your paper before the deadline (which represents the end of training for your deep learning model). Drafting and revising the paper is like training and evaluating your model.

As you write your term paper, you go through multiple drafts, revising and improving it each time. This is like training your model over several epochs, continually refining its performance. At some point, you have the best draft. This is akin to the epoch where your model performs best on the validation set, yielding the best weights.

Then, you continue revising your paper, but some of these revisions might make the paper worse rather than better. Similarly, continued training might lead to overfitting, where the model’s performance on the validation set deteriorates.

As the deadline approaches, you need to decide which version of your paper to submit. You realize that one of your previous drafts was the best. In deep learning, this is where you would restore the best weights.

You submit the version of the paper that received the best feedback. In deep learning, restoring the best weights means reverting to the epoch where the validation performance was optimal, ensuring the best possible model for new, unseen data.

Steps to Restore Best Weights:

  1. Whenever the validation loss decreases, save the current model weights.
  2. Keep track of the best validation loss encountered during training.
  3. Once early stopping is triggered, reload the best model weights.
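One detail matters when "saving" weights in memory: in PyTorch, `model.state_dict()` returns tensors that share storage with the live parameters, so keeping a plain reference does not preserve the old values. The sketch below, using a made-up single `nn.Linear` layer, shows why a `copy.deepcopy` snapshot is needed and how `load_state_dict` restores it.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(3, 1)

alias = model.state_dict()                    # shares storage with the live parameters
snapshot = copy.deepcopy(model.state_dict())  # independent copy of the current ("best") weights

# Simulate further training that changes the weights in place
with torch.no_grad():
    model.weight.add_(1.0)

alias_tracks_model = torch.equal(alias["weight"], model.weight)    # True: alias followed the update
snapshot_differs = not torch.equal(snapshot["weight"], model.weight)  # True: snapshot kept old values
print(alias_tracks_model, snapshot_differs)

# Restore the saved weights
model.load_state_dict(snapshot)
restored = torch.equal(snapshot["weight"], model.weight)
print(restored)  # True
```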

Implementation Example of Early Stopping and Restore Best Weights

Here’s an example of implementing early stopping and restoring the best weights in PyTorch:

First, import the necessary modules: PyTorch, plus Python’s copy module, which we use to snapshot the model weights.

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import copy

Next, create some toy data. Here, x is a tensor of 100 samples, where each sample is a 10-dimensional feature vector, and y is a tensor of 100 target values.

# Toy data
x = torch.randn(100, 10)  # 100 samples, 10 features each
y = torch.randn(100, 1)  # 100 targets

Now, split the data into training and validation sets. In this case, the first 80 samples are used for training and the last 20 samples are used for validation.

# Split into training and validation sets
x_train, x_val = x[:80], x[80:]
y_train, y_val = y[:80], y[80:]
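Slicing off the first 80 rows works here because the toy data is random. With real data whose rows may be ordered (for example, sorted by class), the last 20 rows might not be representative. One alternative, sketched below on the same kind of made-up tensors, is `torch.utils.data.random_split` with a seeded generator so the shuffled split is reproducible; the seed 42 is an arbitrary choice.

```python
import torch
from torch.utils.data import TensorDataset, random_split

x = torch.randn(100, 10)  # 100 samples, 10 features each
y = torch.randn(100, 1)   # 100 targets
dataset = TensorDataset(x, y)

# Shuffle the indices and split 80/20 reproducibly
generator = torch.Generator().manual_seed(42)
train_subset, val_subset = random_split(dataset, [80, 20], generator=generator)

print(len(train_subset), len(val_subset))  # 80 20
```

The resulting subsets can be passed directly to `DataLoader` in place of the sliced `TensorDataset`s.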

Wrap the training and validation data in Datasets and DataLoaders for minibatch training. The DataLoaders will allow us to iterate over the datasets in minibatches of size 20.

# Wrap in a Dataset and DataLoader for minibatch training
train_dataset = TensorDataset(x_train, y_train)
val_dataset = TensorDataset(x_val, y_val)
train_loader = DataLoader(train_dataset, batch_size=20, shuffle=True)  # Minibatch size of 20, reshuffled every epoch
val_loader = DataLoader(val_dataset, batch_size=20)  # Minibatch size of 20
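To see what "iterating in minibatches" means concretely, this sketch builds a loader over 80 made-up training samples and inspects what it yields. It passes `shuffle=True`, the usual choice for a training loader, so the order of samples changes every epoch.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

x_train = torch.randn(80, 10)
y_train = torch.randn(80, 1)
train_loader = DataLoader(TensorDataset(x_train, y_train), batch_size=20, shuffle=True)

# Each iteration yields an (inputs, targets) pair for one minibatch
shapes = [(bx.shape, by.shape) for bx, by in train_loader]
print(len(train_loader))  # 4 minibatches per epoch (80 / 20)
print(shapes[0])          # (torch.Size([20, 10]), torch.Size([20, 1]))
```

Because the model below uses batch normalization, a final minibatch containing a single sample would raise an error in training mode. Here 80 is divisible by 20, so that cannot happen; for other dataset sizes, `drop_last=True` avoids the problem.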

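Define a simple model that includes a batch normalization layer. The nn.BatchNorm1d(5) line adds a batch normalization layer that normalizes each of the 5 outputs of the preceding linear layer. During training it uses the statistics of the current minibatch; during evaluation it uses running averages collected during training, which is why the training loop switches between model.train() and model.eval().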

# Define a simple model with batch normalization
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.BatchNorm1d(5),  # Batch normalization
    nn.ReLU(),
    nn.Linear(5, 1)
)
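The difference between training and evaluation mode is easy to observe on a standalone `nn.BatchNorm1d(5)` layer. In this sketch, the input batch (made-up features with mean around 2 and standard deviation around 3) is normalized with its own statistics in train mode, the running mean is updated with PyTorch's default momentum of 0.1, and the same input produces a different output in eval mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(5)                # at init: weight=1, bias=0, running_mean=0, running_var=1
batch = torch.randn(20, 5) * 3 + 2    # 20 samples, 5 features, mean ~2, std ~3

bn.train()
out_train = bn(batch)                 # normalized with this minibatch's mean and variance
print(out_train.mean(dim=0))          # approximately 0 for every feature

# Train mode also updates the running statistics:
# running_mean = (1 - momentum) * running_mean + momentum * batch_mean, with momentum = 0.1
expected_running_mean = 0.1 * batch.mean(dim=0)
print(torch.allclose(bn.running_mean, expected_running_mean, atol=1e-6))  # True

bn.eval()
out_eval = bn(batch)                  # normalized with the running statistics instead
print(torch.allclose(out_train, out_eval))  # False: same input, different output
```

Forgetting `model.eval()` before validation would therefore make the validation loss depend on how the validation set happens to be batched.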

Define the loss function, mean squared error (a common choice for regression tasks), and the optimizer, here stochastic gradient descent (SGD) with a learning rate of 0.01.

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
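As a quick sanity check of what these two objects compute, the sketch below uses a made-up `nn.Linear(3, 1)` layer and random data. It verifies that `nn.MSELoss()` averages the squared errors over all elements (its default `reduction='mean'`) and that one step of plain SGD (no momentum, no weight decay) updates each weight as w ← w − lr · grad.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
layer = nn.Linear(3, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(layer.parameters(), lr=0.01)

inputs = torch.randn(4, 3)
targets = torch.randn(4, 1)

outputs = layer(inputs)
loss = criterion(outputs, targets)
manual_loss = ((outputs - targets) ** 2).mean()  # mean of squared errors
loss_matches = torch.allclose(loss, manual_loss)
print(loss_matches)  # True

weight_before = layer.weight.detach().clone()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Plain SGD update rule: w_new = w_old - lr * grad
expected = weight_before - 0.01 * layer.weight.grad
update_matches = torch.allclose(layer.weight, expected)
print(update_matches)  # True
```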

Set up variables for early stopping. In this case, if the validation loss doesn’t improve for 10 epochs, training will stop.

# Variables for early stopping
n_epochs_stop = 10
min_val_loss = float('inf')
epochs_no_improve = 0

Create a variable to store the best model parameters. This will be updated whenever the validation loss improves.

# Variable to store the best model parameters
best_model_params = copy.deepcopy(model.state_dict())

Next, define the training loop. For each epoch (a full pass over the training data), iterate over the training DataLoader, which yields minibatches of inputs and targets. For each minibatch, perform the forward pass, compute the loss, perform backpropagation to compute gradients, and then update the model’s parameters. After each epoch, compute the validation loss by iterating over the validation DataLoader. If the validation loss improves, save a copy of the model’s state dict. If the validation loss doesn’t improve for a number of epochs equal to n_epochs_stop, stop training.

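# Training loop with early stopping
for epoch in range(100):
    # Training
    model.train()  # BatchNorm uses minibatch statistics in train mode
    train_loss = 0.0
    for batch_x, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)  # average training loss over the minibatches

    # Validation
    model.eval()  # BatchNorm uses running statistics in eval mode
    val_loss = 0.0
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            outputs = model(batch_x)
            val_loss += criterion(outputs, batch_y).item()
    val_loss /= len(val_loader)  # average validation loss over the minibatches

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Check for early stopping
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        epochs_no_improve = 0
        best_model_params = copy.deepcopy(model.state_dict())  # snapshot of the best weights so far
    else:
        epochs_no_improve += 1
        if epochs_no_improve == n_epochs_stop:
            print(f'Early stopping at epoch {epoch+1}!')
            break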

Finally, load the best model parameters back into the model. This ensures that even if the model’s performance worsens during the last few epochs due to overfitting, we still have the parameters that produced the best performance on the validation set.

model.load_state_dict(best_model_params)
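Keeping the best weights in memory is enough for a single run, but they are lost if the process crashes. A common variant is to write them to disk as a checkpoint with `torch.save` and load them back with `torch.load`. The sketch below does this with the same architecture as above; the helper `make_model` and the file name `best_model.pt` are hypothetical, and a temporary directory is used so the example leaves no files behind.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical helper: same architecture as the model in this post
    return nn.Sequential(nn.Linear(10, 5), nn.BatchNorm1d(5), nn.ReLU(), nn.Linear(5, 1))


torch.manual_seed(0)
model = make_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "best_model.pt")   # hypothetical checkpoint file name
    torch.save(model.state_dict(), path)         # write the (best) weights to disk

    restored = make_model()                      # fresh model with different random weights
    restored.load_state_dict(torch.load(path))   # load the saved weights into it

model.eval()
restored.eval()
x = torch.randn(4, 10)
with torch.no_grad():
    same = torch.allclose(model(x), restored(x))
print(same)  # True
```

In a training loop, the `torch.save` call would go where the best weights are currently deep-copied, and the `torch.load` call would replace the final `load_state_dict` from memory.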

See another example of early stopping and restoring the best weights, using checkpoints, on the MNIST dataset in PyTorch.


Discover more from Science Comics
