Hyperparameter tuning by train-validation-test split – process & example

An example:

Code (R + Python)

Python code

To tune the Lasso regularization parameter manually with a train-validation-test split, follow these steps:

  1. Split the dataset into training, validation, and test sets.
  2. Train the Lasso model on the training set using different alpha values (scikit-learn calls the Lasso regularization parameter, usually written lambda, alpha).
  3. Evaluate the model on the validation set to find the best alpha.
  4. Retrain the model on the combined training and validation sets using the best alpha.
  5. Finally, evaluate the model on the test set.
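For reference, the objective minimized by scikit-learn's Lasso is shown below. Here n is the number of training samples and alpha multiplies the L1 penalty, which is the role lambda plays in most textbook formulations. The intercept b is fitted but not penalized.

```latex
\min_{w,\, b} \; \frac{1}{2n} \lVert y - Xw - b\mathbf{1} \rVert_2^2 + \alpha \lVert w \rVert_1
```

Larger alpha pushes more coefficients to exactly zero. The search in Steps 2 and 3 is a search over how strong that push should be.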

Here is the Python code to implement this:

# Import necessary libraries
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error

# Generate some sample data
np.random.seed(42)
X = np.random.randn(100, 5)
y = X[:, 0] * 3 + np.random.randn(100)

# Step 1: Split data into training (60%), validation (20%), and test (20%) sets
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Step 2: Define a range of alpha (lambda) values for Lasso
alpha_values = np.logspace(-4, 1, 50)  # A range of 50 alpha values between 10^-4 and 10^1
validation_errors = []  # To store validation errors for each alpha

# Step 3: Train Lasso on training set and evaluate on validation set for each alpha
for alpha in alpha_values:
    lasso = Lasso(alpha=alpha)
    lasso.fit(X_train, y_train)  # Train on the training set
    y_val_pred = lasso.predict(X_val)  # Predict on validation set
    val_mse = mean_squared_error(y_val, y_val_pred)  # Calculate MSE on validation set
    validation_errors.append(val_mse)

# Find the best alpha that minimizes the validation MSE
best_alpha = alpha_values[np.argmin(validation_errors)]
print(f"Best alpha (lambda) found: {best_alpha}")

# Step 4: Retrain the model using the best alpha on the combined train and validation sets
X_train_val = np.vstack((X_train, X_val))
y_train_val = np.hstack((y_train, y_val))

best_lasso_model = Lasso(alpha=best_alpha)
best_lasso_model.fit(X_train_val, y_train_val)

# Step 5: Evaluate the model on the test set
y_test_pred = best_lasso_model.predict(X_test)
test_mse = mean_squared_error(y_test, y_test_pred)

print(f"Test Mean Squared Error: {test_mse}")

Explanation:

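Data Splitting: We split the dataset into 60% training, 20% validation, and 20% test sets with two calls to train_test_split. The first call holds out 40% of the data, and the second divides that 40% equally into validation and test sets.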
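The 60/20/20 proportions come from chaining the two calls. A quick size check on the same synthetic data confirms the arithmetic: 100 rows give 60 for training and 40 held out, and the 40 held-out rows split into 20 and 20.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Same synthetic data as in the example above: 100 rows, 5 features
np.random.seed(42)
X = np.random.randn(100, 5)
y = X[:, 0] * 3 + np.random.randn(100)

# First call: hold out 40% of the rows (these become validation + test)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
# Second call: halve the held-out 40%, i.e. 20% + 20% of the original data
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

print(len(X_train), len(X_val), len(X_test))  # 60 20 20
```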

Alpha Search:

  • A loop is used to train the Lasso model with different values of alpha (regularization parameter) on the training set.
  • For each alpha, we evaluate the model’s performance on the validation set and store the validation MSE.
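To see what the loop is searching over, it helps to count how many coefficients survive at each alpha. Small alphas keep essentially all five features, while large alphas shrink every coefficient to exactly zero. The sketch below reuses the same data and training split and prints every seventh grid point. The printout is only for inspection and is not part of the original procedure.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso

np.random.seed(42)
X = np.random.randn(100, 5)
y = X[:, 0] * 3 + np.random.randn(100)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)

alpha_values = np.logspace(-4, 1, 50)  # same grid as the main example
nonzero_counts = []
for alpha in alpha_values:
    lasso = Lasso(alpha=alpha).fit(X_train, y_train)
    # Count coefficients that the L1 penalty has not set to exactly zero
    nonzero_counts.append(int(np.count_nonzero(lasso.coef_)))

# Indices 0, 7, ..., 49, so the first and last grid points are both shown
for alpha, k in zip(alpha_values[::7], nonzero_counts[::7]):
    print(f"alpha={alpha:.4g}  non-zero coefficients={k}")
```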

Finding Best Alpha: The best alpha is the one that minimizes the validation MSE.
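One caveat when taking the minimum: if the winning alpha sits at either end of the grid, the true optimum may lie outside the range searched. A minimal sketch of that check follows. The helper name best_alpha_from_grid and the toy error values are hypothetical, chosen only for illustration.

```python
import warnings

import numpy as np


def best_alpha_from_grid(alpha_values, validation_errors):
    """Return the grid value with the smallest validation error.

    np.argmin returns the first minimum if there are ties. A winner at either
    end of the grid triggers a warning, since the optimum may lie outside it.
    """
    idx = int(np.argmin(validation_errors))
    if idx in (0, len(alpha_values) - 1):
        warnings.warn("Best alpha is at the edge of the grid; consider widening the range.")
    return alpha_values[idx]


# Toy numbers (made up): the minimum is in the interior, so no warning
alphas = np.array([0.01, 0.1, 1.0])
print(best_alpha_from_grid(alphas, [2.0, 1.5, 1.8]))  # 0.1
```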

Model Retraining: Once the best alpha is identified, we retrain the Lasso model on the combined training and validation sets using this best alpha.
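As an alternative sketch, scikit-learn can encode the single fixed validation split with PredefinedSplit. This is not the method used in the code above. Rows marked -1 are only ever used for training, and rows marked 0 form the one validation fold. GridSearchCV then searches the alpha grid on that fold. With refit=True, it refits the best model on all rows passed to fit, which here is the combined training and validation data, matching Steps 2 to 4.

```python
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV, PredefinedSplit
from sklearn.linear_model import Lasso
from sklearn.metrics import mean_squared_error

np.random.seed(42)
X = np.random.randn(100, 5)
y = X[:, 0] * 3 + np.random.randn(100)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Stack train and validation; -1 = always a training row, 0 = the single validation fold
X_train_val = np.vstack((X_train, X_val))
y_train_val = np.hstack((y_train, y_val))
test_fold = np.concatenate([np.full(len(X_train), -1), np.zeros(len(X_val), dtype=int)])

search = GridSearchCV(
    Lasso(),
    param_grid={"alpha": np.logspace(-4, 1, 50)},
    cv=PredefinedSplit(test_fold),
    scoring="neg_mean_squared_error",  # scikit-learn maximizes scores, so MSE is negated
    refit=True,  # refit the best alpha on every row passed to fit (train + validation)
)
search.fit(X_train_val, y_train_val)

print("Best alpha:", search.best_params_["alpha"])
print("Test MSE:", mean_squared_error(y_test, search.predict(X_test)))
```

The fits and validation rows are the same as in the explicit loop, so the selected alpha should agree with the loop's best_alpha. The gain is mainly convenience, for example parallel fitting through n_jobs and the cv_results_ attribute for inspecting every grid point.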

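Evaluation on Test Set: Finally, the trained model is evaluated on the test set, and the test MSE is reported. The test set played no part in choosing alpha or fitting the model, so this MSE is an unbiased estimate of the final model’s error on new data.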

R code

To perform Lasso regression with a train-validation-test split and find the optimal value of the regularization parameter lambda, we follow these steps:

  1. Split the data into train, validation, and test sets.
  2. Train Lasso models with different lambda values on the training set.
  3. Use the validation set to find the best lambda.
  4. Retrain the model with the best lambda on the combined training and validation sets.
  5. Evaluate the final model on the test set.

We’ll use the glmnet package, which is popular for regularized regression, and the createDataPartition function from caret to split the data.
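For the Gaussian family, glmnet fits the elastic-net objective below. Setting alpha = 1 in the call removes the ridge term and leaves the Lasso, which is why glmnet's alpha means something different from scikit-learn's alpha. glmnet also standardizes the columns of X by default (standardize = TRUE) and reports coefficients on the original scale. As a result, its lambda values are not directly interchangeable with scikit-learn's alpha on the same data.

```latex
\min_{\beta_0,\, \beta} \; \frac{1}{2N} \sum_{i=1}^{N} \left(y_i - \beta_0 - x_i^\top \beta\right)^2
  + \lambda \left[ \frac{1 - \alpha}{2} \lVert \beta \rVert_2^2 + \alpha \lVert \beta \rVert_1 \right]
```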

R Code:

# Load necessary libraries
library(glmnet)
library(caret)

# Generate sample data
set.seed(42)
X <- matrix(rnorm(100 * 5), 100, 5)
y <- 3 * X[, 1] + rnorm(100)

# Step 1: Split the data into train (60%), validation (20%), and test (20%) sets
set.seed(42)
train_index <- createDataPartition(y, p = 0.6, list = FALSE)
X_train <- X[train_index, ]
y_train <- y[train_index]
X_temp <- X[-train_index, ]
y_temp <- y[-train_index]

# Split the remaining 40% into validation (20%) and test (20%)
set.seed(42)
val_index <- createDataPartition(y_temp, p = 0.5, list = FALSE)
X_val <- X_temp[val_index, ]
y_val <- y_temp[val_index]
X_test <- X_temp[-val_index, ]
y_test <- y_temp[-val_index]

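# Step 2: Train Lasso models with different lambda values on the training set
# alpha = 1 selects the pure Lasso (L1) penalty in glmnet
lambda_values <- 10^seq(-4, 1, length.out = 50)  # 50 log-spaced values between 10^-4 and 10^1
lasso_model <- glmnet(X_train, y_train, alpha = 1, lambda = lambda_values)

# Step 3: Find the best lambda by evaluating on the validation set
# predict() returns one column of predictions per fitted lambda
val_predictions <- predict(lasso_model, newx = X_val)
val_errors <- apply(val_predictions, 2, function(pred) mean((y_val - pred)^2))

# Get the best lambda with the minimum validation error.
# glmnet stores lambda in decreasing order, so the columns of val_predictions follow
# lasso_model$lambda, not the increasing lambda_values; index the fitted sequence.
best_lambda <- lasso_model$lambda[which.min(val_errors)]
cat("Best lambda found:", best_lambda, "\n")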

# Step 4: Retrain the Lasso model on the combined train and validation sets.
# The glmnet documentation advises against supplying a single lambda, so refit
# the whole sequence and extract the solution at best_lambda when predicting.
X_train_val <- rbind(X_train, X_val)
y_train_val <- c(y_train, y_val)
final_model <- glmnet(X_train_val, y_train_val, alpha = 1, lambda = lambda_values)

# Step 5: Evaluate the model on the test set
y_test_pred <- predict(final_model, newx = X_test, s = best_lambda)
test_mse <- mean((y_test - y_test_pred)^2)

cat("Test Mean Squared Error:", test_mse, "\n")

Explanation:

  1. Data Splitting: The createDataPartition function from caret is used to split the data into training, validation, and test sets.
  2. Lasso Training: The glmnet function, with alpha = 1 (the Lasso penalty), fits the model for the whole sequence of lambda values on the training set in a single call.
  3. Validation: We use the predict function to make predictions on the validation set for each lambda value and calculate the validation MSE.
  4. Finding Best Lambda: The best lambda is selected as the one with the smallest MSE on the validation set.
  5. Retraining and Testing: Finally, we retrain the model on the combined training and validation sets with the best lambda and evaluate its performance on the test set.
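glmnet fits the whole lambda path in one call and stores lasso_model$lambda in decreasing order, whatever order you pass. scikit-learn's lasso_path behaves the same way, which makes the ordering pitfall easy to demonstrate in Python. Errors computed along the path must be indexed with the returned grid, not the input grid. The sketch below assumes the same synthetic data. lasso_path fits no intercept, so the data are centred with the training means, which is what Lasso(fit_intercept=True) does internally.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import lasso_path

np.random.seed(42)
X = np.random.randn(100, 5)
y = X[:, 0] * 3 + np.random.randn(100)
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.4, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Centre with training means because lasso_path does not fit an intercept
X_mean, y_mean = X_train.mean(axis=0), y_train.mean()
alpha_values = np.logspace(-4, 1, 50)  # increasing, like lambda_values in the R code
path_alphas, coefs, _ = lasso_path(X_train - X_mean, y_train - y_mean, alphas=alpha_values)

# coefs has shape (n_features, n_alphas): one column of validation predictions per alpha
val_pred = (X_val - X_mean) @ coefs + y_mean
val_errors = ((y_val[:, None] - val_pred) ** 2).mean(axis=0)

print(path_alphas[:3])  # largest alphas first: the path is returned in decreasing order
best_alpha = path_alphas[np.argmin(val_errors)]  # correct: index the returned grid
```

Indexing alpha_values with the same position would pick the mirror-image grid point, which is the mistake to avoid when reading results off glmnet's predict() matrix.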

Discover more from Science Comics
