您當前的位置:首頁 > 攝影

EM演算法之KL散度和Jensen不等式

作者:由 Cyber 發表于 攝影時間:2020-10-09

EM演算法的關鍵在於, 將難以直接求解的對數機率

lnP(X|\theta)

轉換為容易求解的對數機率

lnP(X,Z|\theta)

先來看對數機率的分解表示式:

EM演算法之KL散度和Jensen不等式

圖片來源於PRML

式(9。70)的推導過程如下:

\begin{aligned} lnP(X|\theta)&=\sum_z{q(z)lnP(X|\theta)}\\ &=\sum_z{q(z)[lnP(X,z|\theta)-lnP(z|X,\theta)]}......(1)\\ &=\sum_z{q(z)[lnP(X,z|\theta)-lnP(z|X,\theta)-q(z)]+q(z)}\\ &=\sum_z{q(z)ln\frac{P(X,z|\theta}{q(z)}}-\sum_z{q(z)ln\frac{P(z|X,\theta)}{q(z)}}\\ &=\mathcal{L}(q,\theta)+KL(q||p) \end{aligned}

上述推導中

q(z)

中的

z

使用的是小寫字母, 表示取具體的某一值。

q

p

分別表示隱變數的先驗機率和後驗機率, 在EM演算法的迭代過程中,

q

p

的完整表示為

q(Z|X, \theta^{old})

p(Z|X,\theta^{new})

,即先驗機率是引數更新前求得的機率, 後驗機率為引數更新後求得的機率。

如文章一開始所說, 之所以要進行這些繁瑣的變換的原因在於(1)式, 將難以求解的邊緣機率分佈轉換為容易求解的聯合機率分佈。聯合機率分佈之所以容易求解, 是因為聯合機率分佈中給出了隱變數的值, 一旦已知了隱變數, 觀測資料的機率就變得很容易求解(如混合高斯模型)。

式(9。71)中的

\mathcal{L}(q,\theta)

lnP(X|\theta)

的下界, 因為

KL(q||p)\geq0

EM演算法之KL散度和Jensen不等式

E步:

更新隱變數的分佈, 使用後驗機率分佈來替代先驗機率分佈:

q(Z) = P(Z|X,\theta^{i})

,其中

\theta^{i}

是上一次迭代中M步求出的

\theta^{new}

,如果是第一次迭代, 則為給定的初始值。

\begin{aligned} &q(Z) = P(Z|X,\theta^{i})\\ \Rightarrow &KL(q||p)=0\\ \Rightarrow &lnP(X|\theta)=\mathcal{L}(q,\theta) \end{aligned}

E步消除了隱變數的先驗機率和後驗機率之間的KL散度, 使得下界和原函式相等。

EM演算法之KL散度和Jensen不等式

M步:

固定

q

,透過更新引數

\theta

來增大下界

\mathcal{L}(q,\theta)

。由於在E步中下界與原始函式相等, 因此這裡增大原始函式的下界必然會增大原始函式

lnP(X|\theta)

EM演算法之KL散度和Jensen不等式

EM演算法之KL散度和Jensen不等式

圖1 M-step

M步一方面使得函式下界增大

\Delta2

, 另一方面由於引數改變, 隱變又有了新的後驗機率

p(Z|X, \theta^{new})

,因此

q\ne p\Rightarrow \Delta1 =KL(q||p)>0

M步透過求解使得下界取最大的引數

\theta

,使得下界增大

\Delta2

。同時,M步中引數

\theta

的更新還引起了值為

\Delta1

的KL散度,而這個大小為

\Delta1

的散度將會在下一步的E步中重新歸0。

也就是說, 引數更新引起的對數機率增大包含兩部分:

\begin{aligned} &lnP(X|\theta^{new})-lnP(X|\theta^{old})\\ &=\Delta1+\Delta2\\ &=KL(q||p)+\mathcal{L}(q,\theta^{new})-\mathcal{L}(q,\theta^{old})\\ &=KL(p(Z|X,\theta^{old})||p(Z|X,\theta^{new})+\mathcal{L}(q,\theta^{new})-\mathcal{L}(q,\theta^{old}) \end{aligned}

其中

\Delta1

是KL散度,KL散度又是如何與Jensen不等式聯絡起來的呢?

Jensen不等式

\begin{aligned} &lnP(X|\theta^{new})-lnP(X|\theta^{old})......(1)\\ &=ln(\sum_z{P(X|Z,\theta^{new})P(Z|\theta^{new}})-lnP(X|\theta^{old})\\ &=ln(\sum_z{P(Z|X,\theta^{old})\frac{P(X|Z,\theta^{new})P(Z|\theta^{new})}{P(Z|X,\theta^{old})}}-lnP(X|\theta^{old})\\ &\geq \sum_z{P(Z|X,\theta^{old})ln\frac{P(X|Z,\theta^{new})P(Z|\theta^{new})}{P(Z|X,\theta^{old})}}-lnP(X|\theta^{old})\\ &=\sum_z{P(Z|X,\theta^{old})ln\frac{P(X,Z|\theta^{new})}{P(Z|X,\theta^{old})}}-lnP(X|\theta^{old})\\ &=\sum_z{P(Z|X,\theta^{old})ln\frac{P(Z|X,\theta^{new})P(X|\theta^{new})}{P(Z|X,\theta^{old})}}-lnP(X|\theta^{old})\\ &=\sum_z{P(Z|X,\theta^{old})ln\frac{P(Z|X,\theta^{new})}{P(Z|X,\theta^{old})}}+\sum_z{P(Z|X,\theta^{old})lnP(X|\theta^{new})}-lnP(X|\theta^{old})\\ &=-KL(P(Z|X,\theta^{old})||P(Z|X,\theta^{new})+lnP(X|\theta^{new})-lnP(X|\theta^{old})......(2)\\  \end{aligned}

由於KL散度非負,很容易看出上式右邊(2)式小於左邊(1)式。當KL散度等於0時,等式兩端相等, 容易看出,推導過程中的Jesen不等式在

\theta^{new}=\theta^{old}

時取等號。

\theta^{new}\neq\theta^{old}

時, KL散度大於0, 同時Jensen不等式也取不到等號。如果推導過程中沒有引入Jesen不等式,那麼(1)式和(2)式完全相等,

KL散度剛好衡量了引入Jensen不等式帶來的誤差

,即:

\begin{aligned}  &ln\sum_z{P(Z|X,\theta^{old})\frac{P(X|Z,\theta^{new})P(Z|\theta^{new})}{P(Z|X,\theta^{old})}}-\sum_z{P(Z|X,\theta^{old})ln\frac{P(X|Z,\theta^{new})P(Z|\theta^{new})}{P(Z|X,\theta^{old})}}\\ &=KL(P(Z|X,\theta^{old})||P(Z|X,\theta^{new}) \end{aligned}

上式透過對左邊第二項進行變形湊出KL散度,不難證明等式成立。

[1] PRML

[2]統計學習方法。 李航

標簽: 散度  KL  機率  下界  不等式