These notes provide a theoretical treatment of Expectation Maximization, an iterative parameter estimation algorithm used to find local maxima of the likelihood function in the presence of hidden variables. Introductory textbooks [murphy:mlapp, bishop:prml] typically state the algorithm without explanation and expect students to work blindly through derivations. We find this approach to be unsatisfying, and instead choose to tackle the theory head-on, followed by plenty of examples. Following [neal1998:em], we view expectation-maximization as coordinate ascent on the Evidence Lower Bound. This perspective takes much of the mystery out of the algorithm and allows us to easily derive variants like Hard EM and Variational EM.
Problem Setting
Suppose we observe data X generated from a model p with true parameters θ∗ in the presence of hidden variables Z. As usual, we wish to compute the maximum likelihood estimate
θ̂_ML = argmax_θ ℓ(θ ∣ X) = argmax_θ log p(X ∣ θ)
of the parameters given our observed data. In some cases, we also seek to infer the values z of the hidden variables Z. In the Bayesian spirit, we will treat the parameter θ∗ as a realization of some random variable Θ.
The observed data log-likelihood ℓ(θ∣X)=logp(X∣θ) of the parameters given the observed data is useful for both inference and parameter estimation, in which we must grapple with uncertainty about the hidden variables. Working directly with this quantity is often difficult in latent variable models because the inner sum cannot be brought out of the logarithm when we marginalize over the latent variables:
ℓ(θ ∣ X) = log p(X ∣ θ) = log ∑_z p(X, z ∣ θ)
In general, this likelihood is non-convex with many local maxima. In contrast, [murphy:mlapp] shows that when the p(x_n, z_n ∣ θ) are exponential family distributions, the complete data log-likelihood is concave, so learning is much easier. Expectation maximization exploits the fact that learning is easy when we observe all variables. We will alternate between inferring the values of the latent variables and re-estimating the parameters, assuming we have complete data.
Evidence Lower Bound
Our general approach will be to reason about the hidden variables through a proxy distribution q, which we use to compute a lower-bound on the log-likelihood. This section is devoted to deriving one such bound, called the Evidence Lower Bound (ELBO).
We can expand the data log-likelihood by marginalizing over the hidden variables:
ℓ(θ ∣ X) = log p(X ∣ θ) = log ∑_z p(X, z ∣ θ)
Through Jensen's inequality, we obtain the following bound, valid for any proxy distribution q:

ℓ(θ ∣ X) = log ∑_z q(z) p(X, z ∣ θ) / q(z)
         ≥ ∑_z q(z) log [ p(X, z ∣ θ) / q(z) ]
         = E_q[log p(X, Z ∣ θ)] + H(q)
         ≡ L(q, θ)

where H(q) = −E_q[log q(Z)] is the entropy of q.
The first term in the last line above closely resembles the (negative) cross entropy between q(Z) and the joint distribution p(X,Z) of the observed and hidden variables. However, the variables X are fixed to our observations, and so p(X,Z) is an unnormalized [ref]In this case, ∫p(X,z)dz = p(X) ≤ 1.[/ref] distribution over Z. It is easy to see that this does not set us back too far; in fact, the lower bound L(q,θ) differs from a Kullback–Leibler divergence only by a constant with respect to Z:

L(q, θ) = −D_KL(q ∥ p(Z ∣ X, θ)) + log p(X ∣ θ)
This yields a second proof of the evidence lower bound, following from the nonnegativity of relative entropy. In fact, this is the proof given in [tzikas2008:variational] and [murphy:mlapp].
log p(X ∣ θ) = D_KL(q ∥ p(Z ∣ X, θ)) + L(q, θ) ≥ L(q, θ)
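This decomposition is easy to check numerically. The sketch below uses a hypothetical two-state discrete model, where the joint p(X, z ∣ θ) reduces to two nonnegative numbers (the values 0.12 and 0.28 are arbitrary choices for illustration):

```python
import numpy as np

# Toy discrete model: hidden Z ∈ {0, 1}, observation X fixed,
# so the joint p(X, z | θ) is just two nonnegative numbers.
p_joint = np.array([0.12, 0.28])        # p(X, z = 0), p(X, z = 1)
log_px = np.log(p_joint.sum())          # evidence log p(X | θ)
posterior = p_joint / p_joint.sum()     # hidden posterior p(Z | X, θ)

def elbo(q):
    # L(q, θ) = E_q[log p(X, Z | θ)] + H(q)
    return np.sum(q * np.log(p_joint)) - np.sum(q * np.log(q))

def kl(q, p):
    # relative entropy D_KL(q || p)
    return np.sum(q * np.log(q / p))

q = np.array([0.7, 0.3])                # an arbitrary proxy distribution
assert np.isclose(log_px, kl(q, posterior) + elbo(q))  # the decomposition above
assert np.isclose(elbo(posterior), log_px)             # tight when q is the posterior
assert elbo(q) <= log_px                               # a lower bound in general
```

The same three assertions hold for any choice of q and any nonnegative joint, which is exactly the content of the decomposition.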
Selecting a Proxy Distribution
The quality of our lower bound L(q,θ) depends heavily on the choice of proxy distribution q(Z). We now show that the evidence lower bound is tight in the sense that equality holds when the proxy distribution q(Z) is chosen to be the hidden posterior p(Z∣X,θ). This will be useful later for proving that the Expectation Maximization algorithm converges.
Maximizing L(q, θ) with respect to q is equivalent to minimizing the relative entropy between q and the hidden posterior p(Z ∣ X, θ). Hence, the optimal choice for q is exactly the hidden posterior, for which D_KL(q ∥ p(Z ∣ X, θ)) = 0, and
log p(X ∣ θ) = E_q[log p(X, Z ∣ θ)] + H(q) = L(q, θ)
In cases where the hidden posterior is intractable to compute, we choose q from a tractable family of distributions and settle for the best bound available within that family; this is the idea behind variational inference, which we will discuss later.
Expectation Maximization
Recall that the maximum likelihood estimate of the parameters θ given observed data X in the presence of hidden variables Z is
θ̂_ML = argmax_θ ℓ(θ ∣ X) = argmax_θ log p(X ∣ θ)
Unfortunately, when reasoning about hidden variables, finding a global maximum is difficult. Instead, the Expectation Maximization algorithm is an iterative procedure for computing a local maximum of the likelihood function, under the assumption that the hidden posterior p(Z∣X,θ) is tractable. We will take advantage of the evidence lower bound
ℓ(θ∣X)≥L(q,θ)
on the data likelihood. Consider only proxy distributions of the form qϑ(Z)=p(Z∣X,ϑ), where ϑ is some fixed configuration of the variables Θ, possibly different from our estimate θ. The optimal value for ϑ, in the sense that L(qϑ,θ) is maximum, depends on the particular choice of θ. Similarly, the optimal value for θ depends on the choice of ϑ. This suggests an iterative scheme in which we alternate between maximizing with respect to ϑ and with respect to θ, gradually improving the log-likelihood.
Iterative Procedure
Suppose at time t we have an estimate θ_t of the parameters. To improve our estimate, we perform two steps of coordinate ascent on L(ϑ, θ) ≡ L(q_ϑ, θ), as described in [neal1998:em]:
E-Step
Compute a new lower bound on the observed log-likelihood, with
ϑ_{t+1} = argmax_ϑ L(ϑ, θ_t) = θ_t
M-Step
Estimate new parameters by optimizing over the lower bound,
θ_{t+1} = argmax_θ L(ϑ_{t+1}, θ) = argmax_θ E_q[log p(X, Z ∣ θ)]
In the M-Step, the expectation is taken with respect to q_{ϑ_{t+1}}.
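The alternating scheme can be written as a short generic loop. The sketch below is a minimal skeleton, not a definitive implementation; the callables `e_step` and `m_step` are hypothetical names for the two model-specific computations:

```python
import numpy as np

def expectation_maximization(e_step, m_step, theta0, n_iters=200, tol=1e-10):
    """Generic EM loop: coordinate ascent on L(ϑ, θ).

    e_step(theta) should return q_ϑ = p(Z | X, theta), which makes the
    bound tight at theta; m_step(q) should return
    argmax_θ E_q[log p(X, Z | θ)].
    """
    theta = np.asarray(theta0, dtype=float)
    for _ in range(n_iters):
        q = e_step(theta)                               # E-Step: ϑ_{t+1} = θ_t
        new_theta = np.asarray(m_step(q), dtype=float)  # M-Step
        if np.max(np.abs(new_theta - theta)) < tol:     # stopping criterion
            return new_theta
        theta = new_theta
    return theta
```

The coin-flip and Gaussian mixture examples below each amount to supplying a concrete `e_step` and `m_step` for this loop.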
Alternate Formulation
In the M-Step, the entropy term of the evidence lower bound L(ϑ_{t+1}, θ) does not depend on θ. The remaining term Q(θ_t, θ) = E_q[log p(X, Z ∣ θ)] is sometimes called the auxiliary function or Q-function. To us, this is the expected complete-data log-likelihood.
Proof of Convergence
To prove convergence of this algorithm, we show that the data likelihood ℓ(θ∣X) increases after each update.
After a single iteration of Expectation Maximization, the observed data likelihood of the estimated parameters has not decreased, that is,
ℓ(θ_t ∣ X) ≤ ℓ(θ_{t+1} ∣ X)
This result is a simple consequence of all the hard work we have put in so far:

ℓ(θ_t ∣ X) = L(ϑ_{t+1}, θ_t)        (E-Step makes the bound tight)
           ≤ L(ϑ_{t+1}, θ_{t+1})    (M-Step maximizes the bound over θ)
           ≤ ℓ(θ_{t+1} ∣ X)         (evidence lower bound)
It is also possible to show that Expectation-Maximization converges to something useful.
(Neal & Hinton 1998, Thm. 2) Every local maximum of the evidence lower bound L(q,θ) is a local maximum of the data likelihood ℓ(θ∣X).
Starting from an initial guess θ_0, we run this procedure until some stopping criterion is met and obtain a sequence {(ϑ_t, θ_t)}_{t=1}^T of parameter estimates.
Example: Coin Flips
Now that we have a good grasp on the theory behind Expectation Maximization, let's get some intuition by means of a simple example. As usual, the simplest possible example involves coin flips!
Probabilistic Model
Suppose we have two coins, each with a different probability of heads, θ_A and θ_B, unknown to us. We collect data from a series of N trials in order to estimate the bias of each coin. Each trial n consists of flipping the same randomly chosen coin Z_n a total of M times and recording only the total number X_n of heads.
This situation is best described by the following generative probabilistic model, which precisely describes our assumptions about how the data was generated. The corresponding graphical model and a set of sample data are shown in the accompanying figure.
θ = (θ_A, θ_B)                                        fixed coin biases
Z_n ∼ Uniform{A, B}            ∀n = 1, …, N           coin indicators
X_n ∣ Z_n, θ ∼ Bin[θ_{Z_n}, M]   ∀n = 1, …, N         head counts
Complete Data Log-Likelihood
The complete data log-likelihood for a single trial (xn,zn) is
log p(x_n, z_n ∣ θ) = log p(z_n) + log p(x_n ∣ z_n, θ)
In this model, P(z_n) = 1/2 is uniform. The remaining term is
log p(x_n ∣ z_n, θ) = log [ (M choose x_n) θ_{z_n}^{x_n} (1 − θ_{z_n})^{M − x_n} ]
                    = log (M choose x_n) + x_n log θ_{z_n} + (M − x_n) log(1 − θ_{z_n})
Expectation Maximization
Now that we have specified the probabilistic model and worked out all relevant probabilities, we are ready to derive an Expectation Maximization algorithm.
The E-Step is straightforward. The M-Step computes a new parameter estimate θ_{t+1} by optimizing over the lower bound found in the E-Step. Let ϑ = ϑ_{t+1} = θ_t. Then,
θ_{t+1} = argmax_θ L(q_ϑ, θ)
        = argmax_θ E_q[log p(X, Z ∣ θ)]
        = argmax_θ E_q[log p(X ∣ Z, θ) p(Z)]
        = argmax_θ E_q[log p(X ∣ Z, θ)] + E_q[log p(Z)]
        = argmax_θ E_q[log p(X ∣ Z, θ)]

where the last step follows because E_q[log p(Z)] is constant with respect to θ.
Now, because each trial is conditionally independent of the others, given the parameters,
E_q[log p(X ∣ Z, θ)]
  = E_q[ log ∏_{n=1}^N p(x_n ∣ Z_n, θ) ]
  = ∑_{n=1}^N E_q[ log p(x_n ∣ Z_n, θ) ]
  = ∑_{n=1}^N E_q[ x_n log θ_{Z_n} + (M − x_n) log(1 − θ_{Z_n}) ] + ∑_{n=1}^N log (M choose x_n)
  = ∑_{n=1}^N E_q[ x_n log θ_{Z_n} + (M − x_n) log(1 − θ_{Z_n}) ] + const. w.r.t. θ
  = ∑_{n=1}^N q_ϑ(z_n = A) [ x_n log θ_A + (M − x_n) log(1 − θ_A) ]
    + ∑_{n=1}^N q_ϑ(z_n = B) [ x_n log θ_B + (M − x_n) log(1 − θ_B) ] + const. w.r.t. θ
Let a_n = q_ϑ(z_n = A) and b_n = q_ϑ(z_n = B). Note that a_n + b_n = 1 for every n. To maximize the above expression with respect to the parameters, we take derivatives with respect to θ_A and θ_B and set them to zero:
∂/∂θ_A E_q[log p(X ∣ Z, θ)] = (1/θ_A) ∑_{n=1}^N a_n x_n − (1/(1 − θ_A)) ∑_{n=1}^N a_n (M − x_n) = 0
∂/∂θ_B E_q[log p(X ∣ Z, θ)] = (1/θ_B) ∑_{n=1}^N b_n x_n − (1/(1 − θ_B)) ∑_{n=1}^N b_n (M − x_n) = 0
Solving for θA and θB, we obtain
θ_A = ∑_{n=1}^N a_n x_n / (M ∑_{n=1}^N a_n)        θ_B = ∑_{n=1}^N b_n x_n / (M ∑_{n=1}^N b_n)
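These updates translate directly into code. The sketch below is a minimal implementation of the two-coin EM under the model above; the true biases 0.8 and 0.3, the trial counts, and the initialization are arbitrary choices for illustration:

```python
import numpy as np

def coin_em(x, M, theta0=(0.6, 0.5), n_iters=50):
    """EM for the two-coin model: x is the array of head counts per trial."""
    theta_A, theta_B = theta0
    for _ in range(n_iters):
        # E-Step: a_n = q(z_n = A). The binomial coefficient and the
        # uniform prior P(z_n) = 1/2 cancel when normalizing.
        like_A = theta_A ** x * (1 - theta_A) ** (M - x)
        like_B = theta_B ** x * (1 - theta_B) ** (M - x)
        a = like_A / (like_A + like_B)
        b = 1.0 - a
        # M-Step: the closed-form updates derived above.
        theta_A = np.sum(a * x) / (M * np.sum(a))
        theta_B = np.sum(b * x) / (M * np.sum(b))
    return theta_A, theta_B

# Simulated data: 200 trials of M = 20 flips with true biases 0.8 and 0.3.
rng = np.random.default_rng(0)
z = rng.integers(0, 2, size=200)
x = rng.binomial(20, np.where(z == 0, 0.8, 0.3))
theta_A, theta_B = coin_em(x, M=20)
```

On data like this, the estimates land close to the true biases, and since the E-Step makes the bound tight while the M-Step maximizes it, the observed-data likelihood never decreases across iterations, as the convergence proof guarantees.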
Example: Gaussian Mixture Model
Probabilistic Model
In a Gaussian Mixture Model, samples are drawn from a random cluster, each normally distributed with its own mean and variance. Our goal will be to estimate the following parameters:
π = (π_1, …, π_K)    mixing weights
μ = (μ_1, …, μ_K)    cluster centers
Σ = (Σ_1, …, Σ_K)    cluster covariances
The full model specification is below. A graphical model is shown in the accompanying figure.
θ = (π, μ, Σ)                          model parameters
z_n ∼ Cat[π]                           cluster indicators
x_n ∣ z_n, θ ∼ N(μ_{z_n}, Σ_{z_n})     base distribution
Complete Data Log-Likelihood
The complete data log-likelihood for a single datapoint (xn,zn) is
log p(x_n, z_n ∣ θ) = log ∏_{k=1}^K [ π_k N(x_n ∣ μ_k, Σ_k) ]^{I(z_n = k)}
                    = ∑_{k=1}^K I(z_n = k) log [ π_k N(x_n ∣ μ_k, Σ_k) ]
Similarly, the complete data log-likelihood over all points {(xn,zn)}n=1N is
log p(X, Z ∣ θ) = ∑_{n=1}^N log p(x_n, z_n ∣ θ) = ∑_{n=1}^N ∑_{k=1}^K I(z_n = k) log [ π_k N(x_n ∣ μ_k, Σ_k) ]
Hidden Posterior
The hidden posterior for a single point (xn,zn) can be found using Bayes' rule:
p(z_n = k ∣ x_n, θ) = P(z_n = k ∣ θ) p(x_n ∣ z_n = k, θ) / p(x_n ∣ θ)
                    = π_k N(x_n ∣ μ_k, Σ_k) / ∑_{k′=1}^K π_{k′} N(x_n ∣ μ_{k′}, Σ_{k′})
Expectation Maximization
Our derivation will follow that of [murphy:mlapp], adapted to our notation.
E-Step
Before the E-step, we have an estimate θt of the parameters, and seek to compute a new lower bound on the observed log-likelihood. Earlier, we showed that the optimal lower bound is
L(q_{θ_t}, θ) = E_q[log p(X, Z ∣ θ)] + const.
where q_{θ_t}(z) ≡ p(z ∣ X, θ_t) and the second term is constant with respect to θ. The E-Step requires us to derive an expression for the first term. Using the complete data log-likelihood derived above, the expected complete data log-likelihood is given by
Q(θ_t, θ) = E_q[log p(X, Z ∣ θ)]
  = ∑_{n=1}^N ∑_{k=1}^K E_q[ I(z_n = k) log [ π_k N(x_n ∣ μ_k, Σ_k) ] ]
  = ∑_{n=1}^N ∑_{k=1}^K E_q[I(z_n = k)] log [ π_k N(x_n ∣ μ_k, Σ_k) ]
  = ∑_{n=1}^N ∑_{k=1}^K p(z_n = k ∣ x_n, θ_t) log [ π_k N(x_n ∣ μ_k, Σ_k) ]
  = ∑_{n=1}^N ∑_{k=1}^K r_nk log π_k + ∑_{n=1}^N ∑_{k=1}^K r_nk log N(x_n ∣ μ_k, Σ_k)
where r_nk ≡ p(z_n = k ∣ x_n, θ_t) is the responsibility that cluster k takes for data point x_n after step t. During the E-Step, we compute these values explicitly using Bayes' rule, as derived above.
M-Step
During the M-Step, we optimize our lower bound with respect to the parameters θ=(π,μ,Σ). For the mixing weights π, we use Lagrange multipliers to maximize the ELBO subject to the constraint ∑k=1Kπk=1. The Lagrangian is
Λ(π, λ) = Q(θ_t, θ) + λ ( ∑_{k=1}^K π_k − 1 )
Carrying out the optimization, we find that λ=−N. The correct update for the mixing weights is
π_k = (1/N) ∑_{n=1}^N r_nk = r_k / N
where r_k ≡ ∑_{n=1}^N r_nk is the effective number of points assigned to cluster k. For the cluster centers μ and covariances Σ, you should verify that the correct updates are
μ_k = (1/r_k) ∑_{n=1}^N r_nk x_n        Σ_k = (1/r_k) ∑_{n=1}^N r_nk x_n x_nᵀ − μ_k μ_kᵀ
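Putting the E-Step and M-Step together gives a complete GMM fitting routine. The sketch below is a minimal implementation of the updates derived above, not a production clustering algorithm; the farthest-point initialization and the small diagonal regularizer are arbitrary choices for numerical stability, not part of EM itself:

```python
import numpy as np

def gmm_em(X, K, n_iters=100, seed=0):
    """EM for a Gaussian mixture; X has shape (N, D)."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    pi = np.full(K, 1.0 / K)
    # Farthest-point initialization of the centers (a simple heuristic).
    idx = [int(rng.integers(N))]
    for _ in range(K - 1):
        d = np.min([np.sum((X - X[i]) ** 2, axis=1) for i in idx], axis=0)
        idx.append(int(np.argmax(d)))
    mu = X[idx].copy()
    Sigma = np.array([np.cov(X.T) + 1e-6 * np.eye(D) for _ in range(K)])
    for _ in range(n_iters):
        # E-Step: responsibilities r_nk = p(z_n = k | x_n, θ_t), in log space.
        log_r = np.empty((N, K))
        for k in range(K):
            diff = X - mu[k]
            quad = np.einsum('ni,ij,nj->n', diff, np.linalg.inv(Sigma[k]), diff)
            logdet = np.linalg.slogdet(Sigma[k])[1]
            log_r[:, k] = np.log(pi[k]) - 0.5 * (quad + logdet + D * np.log(2 * np.pi))
        log_r -= log_r.max(axis=1, keepdims=True)   # stabilize before exponentiating
        r = np.exp(log_r)
        r /= r.sum(axis=1, keepdims=True)
        # M-Step: the closed-form updates derived above.
        rk = r.sum(axis=0)                          # effective points per cluster
        pi = rk / N
        mu = (r.T @ X) / rk[:, None]
        for k in range(K):
            Sigma[k] = ((r[:, k, None] * X).T @ X / rk[k]
                        - np.outer(mu[k], mu[k]) + 1e-6 * np.eye(D))
    return pi, mu, Sigma
```

On well-separated synthetic data, the recovered means and weights closely match the generating parameters, up to a permutation of the cluster labels.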
Advice for Deriving EM Algorithms
The previous two examples suggest a general approach for deriving a new algorithm.
Specify the probabilistic model. Identify the observed variables, hidden variables, and parameters. Draw the corresponding graphical model to help determine the underlying independence structure.
Identify the complete-data likelihood P(X,Z∣θ). For exponential family models, the complete-data log-likelihood will be concave and easy to optimize. In other models, more work may be required.
Identify the hidden posterior P(Z∣X,θ). If this distribution is not tractable, you may want to consider variational inference, which we will discuss later.
Derive the E-Step. Write down an expression for the expected complete-data log-likelihood E_q[log p(X, Z ∣ θ)], where q is the hidden posterior under the current parameter estimate.
Derive the M-Step. Try taking derivatives and setting to zero. If this doesn't work, you may need to resort to gradient-based methods or variational inference.