您當前的位置:首頁 > 文化

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

作者:由 Anish Hui 發表于 文化時間:2021-10-10

Transformer模型結構圖

在Transformer中Decoder會先經過一個masked self-attention層

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

使用Masked Self-Attention層可以解決下文提到的訓練階段和預測階段Decoder可能遇到的所有問題。

什麼是Masked Self-attention層

你只需要記住:masked self-attention層就是下面的網路連線(如果實現這樣的神經元連線,你只要記住一個sequence mask,讓右側的注意力係數

\alpha_{ij}=0

,那麼就可以達到這個效果)

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

訓練階段:

訓練時,你有decoder的target句子,你會直接輸入到masked self-attention層。(對於同一個src句子,只有這一個target輸入序列)

測試階段:

預測時,你會有很多的序列(對於同一個src句子,會有多個target輸入序列,慢慢生成的)

你會先把

作為序列,輸入到masked self-attention層,預測結果是y1

然後把

y1

作為序列,輸入到masked self-attention層(和訓練時一樣,都會用到mask矩陣來實現masked self-attention層的神經元連線方式),預測結果是y1, y2(由於可能有dropout,這個y1可能與第一步的y1稍微有點不同)

y1 y2

作為序列,輸入到masked self-attention層,每個位置上的預測結果是y1, y2, y3

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

可以看到預測階段我們希望是增量更新的,對於重複的單詞,我們希望預測的結果是一樣的,而且

y_i

永遠只是用到它和它左側的decoder輸入資訊,不會用到右側的。

greedy_decoder程式碼預測的輸入序列是一個個生成的

def

greedy_decoder

model

enc_input

start_symbol

):

“”“貪心編碼

For simplicity, a Greedy Decoder is Beam search when K=1。 This is necessary for inference as we don‘t know the

target sequence input。 Therefore we try to generate the target input word by word, then feed it into the transformer。

Starting Reference: http://nlp。seas。harvard。edu/2018/04/03/attention。html#greedy-decoding

:param model: Transformer Model

:param enc_input: The encoder input

:param start_symbol: The start symbol。 In this example it is ’S‘ which corresponds to index 4

:return: The target input

”“”

enc_outputs

enc_self_attns

=

model

encoder

enc_input

dec_input

=

torch

zeros

1

0

type_as

enc_input

data

terminal

=

False

next_symbol

=

start_symbol

while

not

terminal

# 預測階段:dec_input序列會一點點變長(每次新增一個新預測出來的單詞)

dec_input

=

torch

cat

([

dec_input

to

device

),

torch

tensor

([[

next_symbol

]],

dtype

=

enc_input

dtype

to

device

)],

-

1

dec_outputs

_

_

=

model

decoder

dec_input

enc_input

enc_outputs

projected

=

model

projection

dec_outputs

prob

=

projected

squeeze

0

max

dim

=-

1

keepdim

=

False

)[

1

next_word

=

prob

data

-

1

# 拿出當前預測的單詞(數字)

next_symbol

=

next_word

if

next_symbol

==

tgt_vocab

“E”

]:

terminal

=

True

# print(next_word)

# greedy_dec_predict = torch。cat(

# [dec_input。to(device), torch。tensor([[next_symbol]], dtype=enc_input。dtype)。to(device)],

# -1)

greedy_dec_predict

=

dec_input

[:,

1

:]

return

greedy_dec_predict

按照上面的說法,如果我在程式碼中關掉了dropout,那麼當預測序列是

x

時的輸出結果,應該是和預測序列時

x

的前3個位置結果是一樣的(增量更新)

驗證:關掉位置編碼中dropout後,你會發現之前的輸入x’1,x‘2,x’3經過decoder網路的結果果然是不變的。

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

為什麼需要Masked Self-attention層

對於這個疑問,我在知乎上看到別人也有這樣的困惑,

在測試或者預測時,Transformer裡decoder為什麼還需要seq mask?

Transformer原論文中是這樣解釋的:

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

反正我知道他說的意思,但是又好像沒完全懂。

直到後來看到一個部落格

深入理解transformer原始碼

,才理解透徹了這個問題。

這個問題我們分兩個角度來看:

訓練階段為什麼要用masked?

這個比較好理解,因為你訓練的時候算loss,是用當前decoder輸入所有單詞對應位置的輸出

y_1,y_2,...y_t

與真實的翻譯結果ground truth去分別算cross entropy loss,然後把t個loss加起來的,如果你用的是self-attention,那麼

y_1

這個輸出裡面是包含了

x

右側的單詞資訊的(特別是包含了

x

這個你要預測的下一個單詞的資訊),這是用到了未來的資訊,模型是在作弊,屬於資訊洩露。在實際推理過程中,我們顯然是不可能提前知道未來資訊的。

我們可以看下面Transformer訓練過程的程式碼:

訓練是用

y_1,y_2,...y_t

與真實的翻譯結果ground truth去分別算cross entropy loss,然後把t個loss加起來,得到loss的(那顯然是需要用masked self attention的,否則

y_1

是包含右側的資訊的,那你還把他加到loss裡面,它都作弊了,顯然這個位置算出來的loss是會比較低,但有什麼意義呢,真正推理的時候你哪來未來的資訊啊)

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

預測階段為什麼也要用masked?

很多部落格說是Transformer使用sequence masked是在模擬預測時的情況,因為預測結果是迭代生成的,這是為了不讓模型偷看到未來的內容,這樣解釋也沒有錯,但是沒有講清楚預測階段為什麼要用masked。

預測階段還使用Masked的原因:

原因1:預測階段要保持重複的單詞預測結果是一樣的,這樣不僅合理,而且可以增量更新(我們在預測是會選擇性忽略重複的預測的詞,只摘取最新預測的單詞拼接到輸入序列中)

如果我在程式碼中關掉了dropout,那麼當預測序列是

x

時的輸出結果,應該是和預測序列時

x

的前3個位置結果是一樣的(增量更新)

原因2:恰好也可以與訓練時的模型架構保持一致,前向傳播的方式是一致的

附:Transformer預測階段為什麼要保持重複的單詞預測結果是一樣的?

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

摘自:

blog

簡單來說就是,你如果用self-attention,那麼

x

的輸出是

y_1,y_2

,而

x

的輸出是

z_1,z_2,z_3

(這裡面

y_1

會包含

x

的資訊,

z_1

更離譜會包含

x

的資訊,這tm屬於資訊洩露了)。而且我們希望的是增量更新,前面的同一個單詞預測結果,我們希望是一樣的。

你想想我們在作機器翻譯的時候,是一個個把輸出結果加到最終結果裡的。但我們是不會因為

z_1

預測更準(作弊了),就用這個單詞去替換之前已經預測出來的單詞

y_1

實戰:Transformer預測階段輸入序列

x

,我們是會得到

y_1,y_2,y_3

的,但是因為有masked機制的存在,我們基本可以保證此時的

y_1,y_2

和之前的輸入序列

x

的預測結果

y_1,y_2

是相同的(程式碼中關掉dropout的情況下)

所以一般預測的時候我們是直接用當前最後一個位置

x

對應的機率向量去預測下一個單詞

y_t

從訓練和預測角度來理解Transformer中Masked Self-Attention的原理

標簽: input  預測  masked  Attention  DEC