您當前的位置:首頁 > 攝影

詳解標準RNN、GRU和LSTM cell之間的區別

作者:由 Jermy·Lu 發表于 攝影時間:2020-09-15

標準RNN cell狀態更新公式如下所示:

h_t = f(w_h h_{t-1} + w_x x_t)

解釋:

RNN cell當前時刻的hidden state

h_t

由上一時刻的hidden state

h_{t-1}

和當前時刻的input vector

x_t

共同決定,RNN的訓練過程實際上是根據loss 調整

w_h

w_x

兩個引數矩陣的過程。

標準GRU cell狀態更新公式如下所示:

z_t = \sigma (w_{z, h} h_{t-1} + w_{z, x} x_t)

r_t = \sigma (w_{r, h} h_{t-1} + w_{r, x} x_t)

\tilde{h_t} = tanh(w_{h, x} x_t + r_t * w_{h, h} h_{t-1})

h_t = z_t * h_{t-1} + (1-z_t)*\tilde{h_t}

解釋:

1。

z_t 和 r_t

中的

\sigma

表示sigmoid函式,

\sigma

函式值域為

(0, 1)

,其中

z_t

記為update gate,

r_t

記為reset gate。

2。

*

表示

Hadamard Product

,也就是操作矩陣中對應的元素相乘,因此要求兩個相乘矩陣是同型的。

3。

z_t、r_t和h_t

向量維度保持一致,但

z_t和r_t

向量元素值均處於

(0, 1)

4。

\tilde{h_t}

為當前時刻的“中間-隱狀態”。當

r_t

為全1向量時,

\tilde{h_t}

與標準RNN cell的隱狀態更新(計算)方式保持一致。

In addition

, 當

z_t

為全0向量時,

h_t = \tilde{h_t} = tanh(w_{h,x} x_t + w_{h,h} h_{t-1})

。此時,

GRU cell退化成了標準的RNN cell。

5。 當

r_t

為全0向量時,

\tilde{h_t} = tanh(w_{h,x}x_t)

。此時, 當前時刻隱狀態

h_t

的更新結果是否包含上一個時刻的隱狀態

h_{t-1}

完全取決於update gate

z_t

的值。

6。 GRU的訓練過程實際上是根據loss 調整

w_{z,x}、w_{z,h}、w_{r,x}、w_{r, h}、w_{h,x}和w_{h,h}

引數矩陣的過程。

標準LSTM cell狀態更新公式如下所示:

i_t = \sigma(w_i x_t + u_i h_{t-1})

f_t = \sigma(w_f x_t + u_f h_{t-1})

o_t = \sigma(w_o x_t + u_o h_{t-1})

\tilde{c_t} = tanh(w_c x_t + u_c h_{t-1})

c_t = f_t * c_{t-1} + i_t * \tilde{c_t}

h_t = o_t * tanh(c_t)

解釋:

1。

i_t、f_t和o_t

中的

\sigma

表示sigmoid函式,

\sigma

函式值域為

(0, 1)

,其中

i_t

記為input gate,

f_t

記為forget gate,

o_t

記為output gate。

2。

*

表示

Hadamard Product

,也就是操作矩陣中對應的元素相乘,因此要求兩個相乘矩陣是同型的。

\tilde{c_t}

記為“中間-記憶體狀態”,

c_t

記為最終記憶體狀態。

3。 當前時刻的 最終記憶體狀態

c_t

由forget gate 向量、上一時刻的最終記憶體狀態、input gate向量 和 當前時刻的“中間-記憶體狀態”共同決定。input gate決定當前時刻的輸入

x_t

和上一時刻的隱狀態

h_{t-1}

的資訊的保留部分,forget gate決定上一時刻的最終記憶體狀態資訊的保留部分。即可理解為:

是否忘記過去,是否保留當前輸入。

4。 而

o_t

起到的作用可以理解為:

在經歷過遺忘 和 保留輸入之後,是否將現有的資訊在當前時刻輸出

。因此,

h_t = o_t * tanh(c_t)

。否則,直接

h_t = tanh(c_t)

即可。

5。 LSTM的訓練過程就是根據loss更新

w_i、w_f、w_o、w_c、u_i、u_f、u_o和u_c

八個引數矩陣的過程。

標簽: Gate  時刻  Cell  狀態  矩陣