Gradient Descent Algorithm & Code in PyTorch

Gradient Descent is an optimization algorithm that iteratively adjusts the model’s parameters (weights and biases) to find the values that minimize the loss function.
The intuition behind gradient descent is that of moving from your current position towards the lowest point in a loss landscape by taking iterative steps guided by the slope of the terrain.

Recall that a derivative measures the rate of change of a function with respect to its variable. For example, the first-order derivative of a function ( f(x) ) in one dimension is:
f'(x) = \frac{df(x)}{dx} = \lim_{\Delta x \to 0} \frac{f(x + \Delta x) - f(x)}{\Delta x}

So the derivative of a function at a given point is the instantaneous rate of change of the function at that point, obtained as the limit of the average rate of change over a vanishingly small interval. If ( f ) describes position over time, this is the velocity at that point (including the direction of the movement). It is also the slope of the tangent line to the function at that point.
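
As a quick illustration (this snippet is not part of the original post; the function ( f(x) = x^2 ) and the step size are arbitrary choices), we can approximate the derivative numerically with a small ( \Delta x ) and compare it with the analytic slope:

# Finite-difference approximation of the derivative (illustrative example)
def f(x):
    return x ** 2

def numerical_derivative(f, x, dx=1e-6):
    return (f(x + dx) - f(x)) / dx

x0 = 3.0
print(numerical_derivative(f, x0))  # approximately 6.0
print(2 * x0)                       # analytic derivative f'(x) = 2x gives 6.0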

The gradient of the loss function with respect to the model’s parameters indicates the direction and rate of change of the loss function. It points in the direction of the steepest ascent, so to minimize the loss, we move in the opposite direction (update the weights using the negative gradient).

Repeat the process of checking the slope and taking steps until you reach the bottom or a position where the slope is very gentle (convergence).

Concretely, given a loss function ( J(\theta) ) where ( \theta ) represents the parameters, the update rule for gradient descent is:
\theta := \theta - \eta \nabla J(\theta)
Where ( \eta ) is the learning rate; ( \nabla J(\theta) ) is the gradient of the loss function with respect to the parameters.

Formally, we have:

Gradient Descent Algorithm

  1. Initialize: Start with random values for the model parameters.
  2. Compute Gradient: Calculate the gradient of the loss function with respect to each parameter.
  3. Update Parameters: Adjust the parameters in the opposite direction of the gradient by a factor proportional to the learning rate.
  4. Repeat: Repeat the process until the loss function converges (i.e., changes very little between iterations) or a set number of iterations is reached.
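
Before turning to PyTorch, here is a minimal hand-rolled sketch of these four steps on a one-dimensional toy loss ( J(\theta) = (\theta - 3)^2 ), whose gradient ( \nabla J(\theta) = 2(\theta - 3) ) we compute by hand (the loss, learning rate, and number of iterations are illustrative choices, not part of the original example):

# Gradient descent by hand on the toy loss J(theta) = (theta - 3)**2
theta = 0.0   # 1. initialize the parameter
eta = 0.1     # learning rate
for step in range(50):
    grad = 2 * (theta - 3)      # 2. compute the gradient
    theta = theta - eta * grad  # 3. update in the opposite direction of the gradient
print(theta)  # 4. after enough repeats, theta is close to the minimizer 3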

PyTorch code

First, import the necessary modules from PyTorch.

import torch
import torch.optim as optim

Set up some simple toy data, where y is twice x:

# Toy data
x = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y = torch.tensor([2.0, 4.0, 6.0, 8.0, 10.0])  # y = 2x

Initialize the weight w to 1.0. The requires_grad=True argument means that PyTorch will track gradients for this tensor during optimization.

# Initialize the weight to 1.0 and track its gradient
w = torch.tensor([1.0], requires_grad=True)

Define a simple linear model, which just multiplies the input x by the weight w.

# Define model
def model(x):
    return x * w

Define the loss function as mean squared error, which is a common loss function for regression tasks.

# Define loss function
def loss_fn(y_hat, y):
    return ((y_hat - y)**2).mean()
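
Written out for this particular model and loss, the objective as a function of the single weight ( w ) is:
J(w) = \frac{1}{N}\sum_{i=1}^{N} (w x_i - y_i)^2
and its gradient, which loss.backward() will compute for us via autograd, is:
\frac{dJ(w)}{dw} = \frac{2}{N}\sum_{i=1}^{N} x_i (w x_i - y_i)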

Set up a stochastic gradient descent (SGD) optimizer, which will adjust w to minimize the loss:

# Define optimizer
optimizer = optim.SGD([w], lr=0.01)

Define the training loop. For each epoch (a full pass over the dataset), the loop computes the model’s predictions y_hat, calculates the loss, and then calls loss.backward() to compute the gradient of the loss with respect to the model’s parameters (in this case, just w). optimizer.step() then adjusts w to reduce the loss based on this gradient, and optimizer.zero_grad() clears the gradient so it does not accumulate into the next iteration. Every 10 epochs, the loop prints the current epoch, the current value of w, and the current loss.

# Training loop
for epoch in range(100):
    y_hat = model(x)
    loss = loss_fn(y_hat, y)

    loss.backward()  # Compute gradients
    optimizer.step()  # Update weights
    optimizer.zero_grad()  # Clear gradients

    if epoch % 10 == 0:
        print(f'Epoch {epoch+1}: w = {w.item():.3f}, loss = {loss.item():.3f}')
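
For intuition, optimizer.step() with plain SGD is exactly the update rule ( \theta := \theta - \eta \nabla J(\theta) ) applied to w. An equivalent loop that performs the update manually, without the optimizer (an illustrative variant of the code above, not part of the original post), would look like this:

# Manual parameter update, equivalent to optimizer.step() + optimizer.zero_grad()
lr = 0.01
for epoch in range(100):
    y_hat = model(x)
    loss = loss_fn(y_hat, y)
    loss.backward()            # compute d(loss)/dw and store it in w.grad
    with torch.no_grad():      # update the weight outside the autograd graph
        w -= lr * w.grad       # w := w - lr * gradient
    w.grad.zero_()             # clear the gradient for the next iteration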

Finally, print the model’s prediction for x = 5 after training. If the model has learned the true relationship y = 2x, this should be close to 10.

print(f'Prediction after training: f(5) = {model(5).item():.3f}')

Download the code on GitHub

