您當前的位置:首頁 > 詩詞

從GAN到W-GAN的“硬核拆解”(五):使用梯度懲罰的W-GAN-GP

作者:由 鐵心核桃 發表于 詩詞時間:2021-03-31

在上篇文章中,我們線性規劃的對偶問題說起,費了九牛二虎之力,終於得出了W-GAN的最佳化模型,二話不說我們再貼一遍先

G^*=\mathop{\arg\min}_{G}\mathop{\max}_{f_w \in 1-Lipschitz} \mathbb{E}_{x \sim p_r(x)}[f_w(x)]-\mathbb{E}_{z \sim p(z)}\Big[f_w(G(z))\Big]

其中的

\mathop{\max}

步有即構造兩個分佈

p_g

p_r

之間的W距離,而構造這一距離時需要函式

f_w(x)

滿足1-利普希茨連續條件(實際我們也說了,是不是1都無所謂,只要函式別“跑飛”了就行)。理論上很圓滿,但是W-GAN在實現時卻選擇了一個略顯粗糙的方法——

權重裁剪

(weight clipping),按照預先設定的截斷範圍,透過人為硬性的權重截斷方式,以限制函式引數的方式來限制

f_w(x)

的取值範圍。

希望瞭解W-GAN最佳化模型是怎麼得到的童鞋,請戳下面的連結:

從動機角度看,權重裁剪的想法是極其樸素的,樸素到一行Python程式碼基本就搞定了,樸素到和W-GAN一路推導的艱辛相比都不比例。從效果角度看,權重裁剪也是存在大問題的。正是因為這個原因,才有了本篇文章介紹具有

梯度懲罰

(Gradient Penalty, GP)策略的W-GAN,即W-GAN-GP。

權重裁剪的問題

權重裁剪這種粗糙的操作存在兩個問題:

問題1——引數二值化

:W-GAN的最佳化目標是兩個數學期望的差:第一個數學期望

\mathbb{E}_{x \sim p_r(x)}[f_w(x)]

表示將真實資料帶入

f_w(x)

之後得到的平均值;第二個數學期望

\mathbb{E}_{x \sim p_g(x)}[f_w(x)]

表示將生成資料帶入

f_w(x)

之後得到的平均值。兩個期望之差自然表示真假兩路資料被

f_w

作用後平均值之間的差值(簡單說就是函式要在兩類資料上“扯開”)。試想我們在權重裁剪中給定常數

c

(如

c=0.01

)後,要求上面的差值越大越好,且要求引數不能小於

-c

且不能大於

c

,這樣的操作勢必導致引數的向著

-c

c

聚集,形成兩極分化,

f_w(x)

即退化為一個二值神經網路,表現能力將異常低下。下圖為引數二值化示意。

從GAN到W-GAN的“硬核拆解”(五):使用梯度懲罰的W-GAN-GP

引數二值化示意

問題2——訓練難調節

:權重截斷直接在引數上“做手術”,硬性將權重設定在

[-c,c]

範圍,這存在將本來大的引數截小和將本來小的引數截大的可能。因此截斷值

c

的設定至關重要,否則要麼梯度消失(梯度為0),要麼梯度爆炸(梯度為無窮大)。這個道理我們可以這麼簡單的理解:假設神經絡

f_w(x)

及對應的梯度下降法可以表達如下形式(為了簡單起見,不考慮偏置)

\begin{aligned} w_n\sigma_{n−1} (…\sigma_2 (w_2 \sigma_1 (w_1 x))…)\\  w^{(n+1)}=w^{(n)}-\eta \frac{\partial f_w(x)}{\partial w} \end{aligned}

f_w(x)

對第

i

層引數

w_i

梯度為

\frac{\partial f_w(x)}{\partial w_i}=(w_n w_{n-1}  \cdots w_{i+1})\cdot (\sigma

可以看到鏈式法則下,靠前層權重的更新量與其後所有層權重乘積成正比。這就意味著:當

c

設定的比較小時,這一乘積將可能越乘越小,從而可能出現梯度消失;反過來,當

c

設定的比較大時,這一乘積將可能越乘越大,從而可能出現梯度爆炸。

梯度懲罰及實現

面對上面權重裁剪存在的問題,W-GAN-GP的解決方案為在目標函式上新增一個梯度懲罰項(Gradient Penalty, GP),“罰出”一個利普希茨條件。

版本1

的目標函式變為

w^*=\mathop{\arg\min}_w\mathbb{E}_{x \sim p_r(x)}[f_w(x)]-\mathbb{E}_{x \sim p_g(x)}\Big[f_w(x)\Big]+\lambda \mathop{\max}(\|\nabla_xf_w(x)\|,1)

其中

\lambda

為梯度懲罰項的常係數。上面式子可以看出兩件事:

\|\nabla_xf_w(x)\| \le 1

時,懲罰項就剩下一個常數

\lambda

,求

\mathop{\arg\min}

的時候可以忽略(目標函式退化為W-GAN的目標函式);

\|\nabla_xf_w(x)\| \gt 1

時,懲罰項為

\lambda \cdot\|\nabla_xf_w(x)\|

,這就意味著在求

\mathop{\arg\min}

的時候梯度越大則懲罰越大。

上面兩種情況意味著梯度超過1就做懲罰,即將梯度儘可能的限制在1之內。除此之外,W-GAN-GP還提出了第二個版本的目標函式(

版本2

w^*=\mathop{\arg\min}_w\mathbb{E}_{x \sim p_r(x)}[f_w(x)]-\mathbb{E}_{x \sim p_g(x)}\Big[f_w(x)\Big]+\lambda (\|\nabla_xf_w(x)\|-1)^2

版本2的梯度懲罰項來的更直接:當

\|\nabla_xf_w(x)\| \ne 1

時就做懲罰,且距離1越遠則懲罰越大,即將梯度儘可能的等於1。

對於上面兩個版本的梯度懲罰,W-GAN中的利普希茨條件只是要求函式

f_w

的梯度有界,但是具體是多少其實無所謂。儘管版本1更符合1−利普希茨條件的要求,但是其懲罰項中的“

\mathop{\max}

”不可微,所以版本2更加合理,所以我們就僅針對版本2的目標函式進行討論。

帶梯度懲罰的目標函式有了,那具體怎麼操作呢?利普希茨連續條件可是要求函式

f(x)

任意

位置的梯度長度不能超過一個定值

K

,即要求

處處

成立,上面懲罰項中抽象的

x

就表示任意的位置。但是,所謂的“任意”是無法操作的,所以還是需要帶入具體的點,我們用

\hat{x}

表示這些“任意且具體的點”,則有

w^*=\mathop{\arg\min}_w\mathbb{E}_{x \sim p_r(x)}[f_w(x)]-\mathbb{E}_{x \sim p_g(x)}\Big[f_w(x)\Big]+\lambda \mathbb{E}_{\hat{x} \sim p(\hat{x})}(\|\nabla_{\hat{x}}f_w(\hat{x})\|-1)^2

在上式中,由於“任意”點位也是隨機變數,所以我們加上了數學期望。對上面的“任意”點

\hat{x}

,我們希望最好能滿足如下兩點要求:

\hat{x}

最好是樣本空間的所有點,這樣才最“任意”。但是很遺憾,這要求無窮多點,但只有有限的樣本點,根本做不到;

\hat{x}

其次是全部的真實資料

x^r

和生成資料

x^g

,這樣才比較“任意”。但是也很遺憾,一次訓練不可能一次性讀入全部

x^r

,並生成同樣多的

x^g

,還是做不到。

W-GAN-GP的解決方案是一個

折中方案

\hat{x}

取自某一個訓練小批次(mini-batch)中的全部

x^r

x^g

,以及他們的隨機“混合”部分

\hat{x}=\epsilon x^r+(1-\epsilon)x^g,\ \epsilon \sim \mathcal{U}(0,1)

上面作用在“混合”樣本點的梯度約束可以用如下圖示表示

從GAN到W-GAN的“硬核拆解”(五):使用梯度懲罰的W-GAN-GP

“混合”樣本點的梯度約束示意(1D vs。 2D)

上面梯度懲罰方案可以用Python程式碼簡單實現為

# 梯度懲罰項的常係數

lambda

=

10

# 獲得真假樣本

G_sample

=

generator

z

D_real

=

discriminator

X

D_fake

=

discriminator

G_sample

# “混合”樣本的梯度約束

# 從均勻分佈U(0,1)中隨機獲取一個eps

eps

=

tf

random_uniform

([

mb_size

1

],

minval

=

0。

maxval

=

1。

# 用eps算一個隨機“混合”樣本出來

X_inter

=

eps

*

X

+

1。

-

eps

*

G_sample

# 表達梯度並做-1的二範數平方

grad

=

tf

gradients

discriminator

X_inter

),

X_inter

])[

0

grad_norm

=

tf

sqrt

tf

reduce_sum

((

grad

**

2

axis

=

1

))

grad_pen

=

lambda

*

tf

reduce_mean

((

grad_norm

-

1

**

2

# max min損失函式

D_loss

=

tf

reduce_mean

D_fake

-

tf

reduce_mean

D_real

+

grad_pen

G_loss

=

-

tf

reduce_mean

D_fake

上面的操作方式是否管用呢?咋說呢,實際W-GAN-GP的作者也說得比較牽強

“Given that enforcing the unit gradient norm constraint everywhere is intractable, enforcing it only along these straight lines

seems sufficient

and experimentally results in good performance” ——《Improved Training of Wasserstein GANs》

結束語

這篇文章中,我們用了比較少的篇幅介紹了W-GAN的一個改進版本W-GAN-GP,其本質只不過是換了一種方式使得函式滿足利普希茨連續條件而已,比起原始W-GAN中的直接對梯度做手術要高明一些。我們介紹W-GAN-GP並不是因為它多優秀,二是因為W-GAN挖了個坑,W-GAN-GP指示眾多填坑者中的一個比較有代表性的工作而已。畢竟,作為W-GAN的專題的後續,我們一點填坑的工作都不介紹,顯得不夠圓滿。

標簽: 梯度  GAN  懲罰  GP  函式