Expectation Maximization (EM) & implementation


Expectation Maximization (EM) is an iterative algorithm used for finding maximum likelihood estimates of parameters in statistical models, particularly when the model involves latent variables (variables that are not directly observed). The algorithm is commonly used in scenarios where the data is incomplete or has missing values. Here’s a step-by-step explanation of how EM works:

Overview

  1. Initialization: Start with initial guesses for the parameters.
  2. Expectation Step (E-Step): Estimate the missing data given the current parameters.
  3. Maximization Step (M-Step): Update the parameters to maximize the likelihood given the estimated data from the E-Step.
  4. Iteration: Repeat the E-Step and M-Step until convergence (i.e., the parameters no longer change significantly).

Detailed Steps

  1. Initialization: Choose initial values for the parameters, often randomly or based on some heuristics.
  2. Expectation Step (E-Step):
  • Calculate the expected value of the log-likelihood function with respect to the conditional distribution of the latent variables, given the observed data and the current estimates of the parameters (this defines the Q function sketched after this list).
  • This involves computing the probability distribution over the latent variables.
  3. Maximization Step (M-Step):
  • Find the parameters that maximize the expected log-likelihood found in the E-Step.
  • Update the parameters to these new values.
  4. Iteration:
  • Alternate between the E-Step and the M-Step.
  • Check for convergence, typically by seeing if the change in the log-likelihood or the parameters is below a certain threshold.
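
In symbols, with X the observed data, Z the latent variables, and \theta^{(t)} the current parameter estimates, the two steps can be summarized in the standard textbook form (not specific to any particular model):

Q(\theta \mid \theta^{(t)}) = \mathbb{E}_{Z \mid X, \theta^{(t)}}\left[ \log p(X, Z \mid \theta) \right] \quad \text{(E-Step)}

\theta^{(t+1)} = \arg\max_{\theta} Q(\theta \mid \theta^{(t)}) \quad \text{(M-Step)}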

Applications

EM is widely used in various fields, including:

  • Gaussian Mixture Models (GMMs): EM is used to estimate the parameters of GMMs, which are used for clustering and density estimation.
  • Hidden Markov Models (HMMs): EM (in the form of the Baum-Welch algorithm) is used to estimate transition probabilities and emission probabilities.
  • Missing Data Imputation: EM can handle datasets with missing values by iteratively estimating the missing values and updating the parameters.
  • Topic Models: In natural language processing, EM and its variational variants are used to fit topic models such as Latent Dirichlet Allocation (LDA) and find topics in a collection of documents.

Convergence

The EM algorithm is guaranteed to increase the likelihood at each iteration, but it is not guaranteed to find the global maximum. It may converge to a local maximum depending on the initial parameter values.
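
A compact way to see the monotonicity guarantee: for any \theta, the observed-data log-likelihood decomposes as

\log p(X \mid \theta) = Q(\theta \mid \theta^{(t)}) + H(\theta \mid \theta^{(t)}), \quad \text{where } H(\theta \mid \theta^{(t)}) = -\mathbb{E}_{Z \mid X, \theta^{(t)}}\left[\log p(Z \mid X, \theta)\right].

The M-Step ensures Q(\theta^{(t+1)} \mid \theta^{(t)}) \ge Q(\theta^{(t)} \mid \theta^{(t)}), and Gibbs' inequality ensures H(\theta^{(t+1)} \mid \theta^{(t)}) \ge H(\theta^{(t)} \mid \theta^{(t)}), so \log p(X \mid \theta^{(t+1)}) \ge \log p(X \mid \theta^{(t)}). In practice, the local-maximum issue is usually handled by running EM from several random initializations and keeping the run with the highest final log-likelihood.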

Example: Gaussian Mixture Model

Consider a dataset with n observations, assumed to be generated from a mixture of k Gaussian distributions. The parameters we want to estimate are the means, covariances, and mixing coefficients of the Gaussians.

  1. Initialization: Initialize the means, covariances, and mixing coefficients.
  2. E-Step: Calculate the responsibility that each Gaussian component takes for each observation, i.e., the probability that a given observation belongs to that component (the formula is given after this list).
  3. M-Step: Update the means, covariances, and mixing coefficients using the responsibilities calculated in the E-Step.
  4. Iteration: Repeat the E-Step and M-Step until the parameters converge.
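
Using the standard notation that the code below follows, with mixing coefficients \pi_j, means \mu_j, covariances \Sigma_j, and responsibilities \gamma_{ij} for observation x_i and component j:

\text{E-Step:} \quad \gamma_{ij} = \frac{\pi_j \, \mathcal{N}(x_i \mid \mu_j, \Sigma_j)}{\sum_{l=1}^{k} \pi_l \, \mathcal{N}(x_i \mid \mu_l, \Sigma_l)}

\text{M-Step:} \quad N_j = \sum_{i=1}^{n} \gamma_{ij}, \qquad \pi_j = \frac{N_j}{n}, \qquad \mu_j = \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} x_i, \qquad \Sigma_j = \frac{1}{N_j} \sum_{i=1}^{n} \gamma_{ij} (x_i - \mu_j)(x_i - \mu_j)^\top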

Code: In the following example, synthetic data is generated from a mixture of three Gaussian distributions, and the EM algorithm is applied to estimate the parameters of the GMM. The results include the weights, means, and covariances of the Gaussian components. You can experiment with the code by changing the parameters.

import numpy as np
from scipy.stats import multivariate_normal

def initialize_parameters(X, k):
    # Start with uniform weights, k randomly chosen data points as means,
    # and identity covariance matrices.
    n, d = X.shape
    weights = np.ones(k) / k
    means = X[np.random.choice(n, k, replace=False)]
    covariances = np.array([np.eye(d)] * k)
    return weights, means, covariances

def e_step(X, weights, means, covariances):
    # Compute the responsibility of each component for each observation:
    # weight_j * N(x_i | mu_j, Sigma_j), then normalize over the components.
    n, d = X.shape
    k = len(weights)
    responsibilities = np.zeros((n, k))

    for i in range(k):
        responsibilities[:, i] = weights[i] * multivariate_normal.pdf(X, means[i], covariances[i])

    responsibilities /= responsibilities.sum(axis=1)[:, np.newaxis]
    return responsibilities

def m_step(X, responsibilities):
    # Re-estimate weights, means, and covariances from the responsibilities.
    n, d = X.shape
    k = responsibilities.shape[1]

    nk = responsibilities.sum(axis=0)  # effective number of points per component
    weights = nk / n
    means = np.dot(responsibilities.T, X) / nk[:, np.newaxis]

    covariances = np.zeros((k, d, d))
    for i in range(k):
        # Responsibility-weighted scatter matrix of the deviations from the mean
        diff = X - means[i]
        covariances[i] = np.dot(responsibilities[:, i] * diff.T, diff) / nk[i]

    return weights, means, covariances

def log_likelihood(X, weights, means, covariances):
    # Observed-data log-likelihood: sum over points of the log of the mixture density.
    k = len(weights)
    mixture_density = np.zeros(X.shape[0])

    for i in range(k):
        mixture_density += weights[i] * multivariate_normal.pdf(X, means[i], covariances[i])

    return np.sum(np.log(mixture_density))

def em_algorithm(X, k, max_iter=100, tol=1e-4):
    weights, means, covariances = initialize_parameters(X, k)
    log_likelihoods = []

    for iteration in range(max_iter):
        # Alternate the E-Step and the M-Step
        responsibilities = e_step(X, weights, means, covariances)
        weights, means, covariances = m_step(X, responsibilities)

        ll = log_likelihood(X, weights, means, covariances)
        log_likelihoods.append(ll)

        # Stop once the improvement in log-likelihood falls below the tolerance
        if iteration > 0 and np.abs(log_likelihoods[-1] - log_likelihoods[-2]) < tol:
            break

    return weights, means, covariances, log_likelihoods

# Example usage
np.random.seed(42)
X = np.vstack([np.random.multivariate_normal(mean, np.eye(2), 100) for mean in [[0, 0], [5, 5], [0, 5]]])

k = 3
weights, means, covariances, log_likelihoods = em_algorithm(X, k)

print("Weights:\n", weights)
print("Means:\n", means)
print("Covariances:\n", covariances)

The output is

Weights:
 [0.33897411 0.32778969 0.3332362 ]
Means:
 [[-0.05306686  4.80730254]
 [-0.11766118 -0.00522756]
 [ 5.13881385  5.06920179]]
Covariances:
 [[[ 0.97884739  0.01651067]
  [ 0.01651067  0.99516519]]

 [[ 0.73668391  0.02409119]
  [ 0.02409119  0.9102382 ]]

 [[ 1.01449867 -0.15134066]
  [-0.15134066  0.86954933]]]
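
The recovered means are close to the true cluster centers [0, 0], [5, 5], and [0, 5] used to generate the data (the ordering of the components is arbitrary). As a quick sanity check, and to mitigate the local-maximum issue discussed above, you can compare against scikit-learn's GaussianMixture, which runs several random restarts via its n_init parameter and keeps the best one. A minimal sketch, assuming scikit-learn is installed and reusing the X generated above:

from sklearn.mixture import GaussianMixture

# Fit a 3-component GMM, trying 5 random initializations and keeping the best fit
gm = GaussianMixture(n_components=3, n_init=5, random_state=42).fit(X)

print("sklearn weights:\n", gm.weights_)
print("sklearn means:\n", gm.means_)
print("sklearn covariances:\n", gm.covariances_)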

