seq2seq之tensorflow原始碼解析
很直接的說,這篇文章就是對tensorflow的seq2seq大禮包的原始碼做了一定程度的解析。個人一直覺得作為一個演算法工程師,要經常學習好的開源框架裡面的工程程式碼,這樣不僅能夠在實現自定義模型時好下手,也能提升自己的工程能力。本文用到的tensorflow框架版本為1。4。目前比較新的版本中的相關程式碼主體改動有一些,不是很大,所以可以適配著看。
之前的文章中,我已經對seq2seq+attention模型有過講解,當時以為自己弄懂了。但是最近透過一些比賽,實際去實現這個模型的時候,才發覺自己並沒有真正吃透這個模型,因此下面重新對這個帶attention機制的seq2seq模型進行回顧,然後再從tensorflow的seq2seq大禮包原始碼入手,講解其中的步驟。為了更好描述,這裡以最近參加的標題生成的比賽作為例子。(成績不太好,就不寫總結了)
假設當前的輸入資料,已經對out-of-vocab詞用
Encoder
首先是encoder。主要作用是讀取並學習document text中的資訊。常使用RNN或者帶門機制的RNN(GRU或者LSTM)。一般會先對輸入進行embedding_lookup,獲取對應的embedding_input後,將其輸入到RNN的中進行編碼。一般的encoder會直接輸出最後
時刻的狀態
,作為decoder的初始
狀態。這裡有個改進點,可以提升整個模型的效能:
之所以可以將encoder的
作為decoder的初始狀態,在於經過RNN的學習與傳遞後,最後一個時刻的狀態h已經整合了當前網路對輸入文字的分析和總結,因此直接將其作為decoder的初始狀態,相對於隨機初始化操作,可以加快網路的訓練速度,提升模型的效能。那麼,我們可以透過其他方式,更好得獲得encoder的學習成果,比如將所有t時刻的狀態
進行一個average pooling,max pooling,甚至進行一次self attention,最後得到一個綜合的輸出,作為decoder的初始狀態,這樣應該能比baseline方法更好。在實踐過程中,也確實是這樣,其中average pooling的效果最好,而self attention之所以效果較差的原因可能是模型中後面decoder的時候也使用了attention機制,因此這裡使用self attention機制對模型效果提升沒有很大的效果,相反,還會導致過擬合。且average pooling的效率最高,因此最後我選用了這種方式,來輸出decoder的初始狀態。
當前encoder模組有兩個比較重要的輸出會被後續的decoder使用:
1、RNN中的每個時刻的狀態
2、RNN中的每個時刻的輸出
其中,狀態
會作為decoder的初始狀態被使用到,而輸出
則會在attention機制計算時使用到。這就是該模型的核心之處。下面用圖表示,會比較清楚。
Decoder
decoder的主體也是使用的RNN,原始輸入經過embedding_lookup(
和encoder共享的
)後,輸出到RNN中,每個t時刻的狀態為
,其中
的初始化已經在前面講過,用encoder的相關資訊來初始化。然後就是模型的attention機制的計算和使用。
這裡以Bahdanau Attention計算方式為例。這個attention具體的內容就不講了,有關attention的內容之前文章有描述,這裡直接放上公式:
其中,
均為網路權重引數。
為參與attention計算的memory,即encoder的output,而
為decoder在decoder t時刻的狀態。
這裡注意,我使用的是Global attention,相對的還有Local attention,由於Local attention計算方面還要考慮上下文windows,為了便於講解,使用了簡單的Global attention機制,且效果也挺好的。
這個公式表明,我每次要計算輸出當前時刻的decoder output logits時,需要將當前時刻的decoder 狀態整合encoder的output,一起輸入到attention公式中計算,最後得到alignment,這個alignment只與當前時刻t的計算有關,一般是不用儲存的。當然後續應用coverage機制的時候,需要將這個alignment的歷史資訊都先儲存下來。
得到alignment後,接下來就是應用
到encoder的output中,進行加權求和,得到最終的attention context。這個是融合了encoder,decoder輸入資訊的attention學習成果。
這個學習成果需要和decoder的中間狀態
整合在一起(一般使用concat,也有其他整合方法),然後輸入到全連線層中,並最終輸出logits,或者輸出一個softmax機率,即詞庫中每個詞在當前位置的機率分佈。最後計算當前時刻的
。
當然,所有時刻的
計算完後,需要將loss加起來,並做sequence_mask,將pad位置的loss過濾掉,然後求平均。
Tensorflow seq2seq
上述過程,看著挺簡單的,但是實現起來,其實還是挺麻煩的,因此tensorflow提供了一個方便的大禮包,包含了seq2seq的各個階段的計算介面,只要按照規則正確呼叫,就能搭建出模型。但是這個大禮包,雖然很方便,但是要進行進一步的自定義修改,還是挺難的,需要對裡面的程式碼進行長時間的閱讀和理解,且tensorflow的文件可以說是比較少了,所以這裡把我這段時間閱讀程式碼的總結寫下來,方便大家參考。
大禮包的主要核心部分在decoder側,encoder側,還是使用我們熟悉的RNN的介面,一開始我使用的是CudnnGRU,但是由於其不支援dynamic_rnn,無法對pad位置的資訊特殊處理,因此實際使用效果雖然速度快,但是效果卻沒有原始的gru介面好,因此最後還是用了tf。contrib。rnn。GRUCell(貌似tf。contrib這個包會在後續的版本中逐漸捨棄,所以建議還是用另一個介面)。最後說明一下encoder的輸出為兩部分:
1、encoder_output,shape=[batch_size,time_steps,rnn_dims],為rnn的輸出
2、encoder_state,shape=[batch_size,dims],如果為雙向rnn,需要將兩個方向的state先拼接,然後再對所有時刻的state沿time_steps做average pooling。最後可能還需要對encoder_state做一層全連線層,使得它與decoder的RNN的dimension一致。
decoder的編寫流程
下面我先按照正常的流程使用大禮包編寫decoder的流程。然後按照程式碼的執行順序描述程式碼的執行機制。inference階段我使用beamSearch來舉例。
attention_mechanism
=
tf
。
contrib
。
seq2seq
。
BahdanauAttention
(
decoder_dim
,
encoder_output
,
memory_sequence_length
=
sum_len
,
normalize
=
True
)
decoder_cell
=
tf
。
contrib
。
seq2seq
。
AttentionWrapper
(
decoder_cell
,
attention_mechanism
,
attention_layer_size
=
atten_size
)
initial_state
=
decoder_cell
。
zero_state
(
dtype
=
tf
。
float32
,
batch_size
=
batch_size
)
initial_state
=
initial_state
。
clone
(
cell_state
=
encoder_state
)
#train
helper
=
tf
。
contrib
。
seq2seq
。
TrainingHelper
(
decoder_inputs_embedding
,
title_len
,
time_major
=
False
)
decoder
=
tf
。
contrib
。
seq2seq
。
BasicDecoder
(
decoder_cell
,
helper
,
initial_state
,
output_layer
=
projection_layer
)
outputs
,
self
。
train_dec_last_state
,
_
=
tf
。
contrib
。
seq2seq
。
dynamic_decode
(
decoder
,
output_time_major
=
False
)
#logits
decoder_output
=
outputs
。
rnn_output
上述是訓練過程,下面我們自底向上的順序來說明這個原始碼。
原始碼執行流程
由於下面內容比較繁瑣,因此先po出流程圖,不喜歡看文字的同志可以看這幅圖瞭解一下大概呼叫流程。該圖只是一個整體的流程,很多細節在後面文字描述。
tensorflow seq2seq大禮包模組簡易流程
1、呼叫棧的最底層就是我們的
tf.contrib.seq2seq.BahdanauAttention
,它提供了具體的attention計算方法,原始碼位於
tf.contrib.seq2seq.AttentionWrapper
中。
可以看到它繼承了一個基類,這個基類的主要作用是在_
init
_方法中對memory做了全連線層的計算(如果定義了dense layer的話),得到attention計算中的key和value。而子類則定義了query_layer和memory_layer,以及最後歸一化的方法(一般是softmax)
接下來就是該類的核心方法:
可以看到該方法是一個__call__()方法,熟悉python的同志肯定知道,一個類實現了這個方法,可以讓類變得可以向函式方法一樣被直接呼叫。其中,主要核心就是呼叫了_bahdanau_score方法來進行attention計算。這個方法具體就不貼出來了,就是按照論文中的計算方法。
備註:這裡有一個_probability_fn方法,需要輸入歷史的alignment,這個就是之前說的,有一些attention方法需要使用歷史的alignment資料,因此所有attention方法都預設會呼叫這個_probability_fn方法,只不過在BahdanauAttention中它是不對歷史的alignment做任何操作,直接對score做softmax。
2、往上面一層就是
AttentionWrapper
,這個類負責呼叫整個Attention機制的流程。首先看一下它的類定義:
可以看出來,它繼承了RNNCell,說明這也是一個具有RNN特性的實現類,且在RNNCell上,封裝了Attention的特性。
另外,繼承了RNNCell的所有子類,需要實現call()方法,使得在構建網路計算圖(build())時,能夠呼叫相關邏輯
。
分析它的__init__方法,可以看到它接收decoder中的RNNCell,之前定義的Attention機制,另外關注一下
alignment_history
,這個屬性就是能控制是否儲存歷史alignment資訊的開關。Attention_layer_size屬性則定義了Attention操作後是否需要連線一個全連線層到指定的dimension。
該類有兩個重要的方法,zero_state和call。前者用於初始狀態s的封裝和生成,後者主要是用於attention計算的主流程。
下面看一下zero_state方法, 其實該方法最重要的就是返回一個封裝過的可用於Attention計算的AttentionState。
下面看一下這個封裝類的結構:
可以看到,其實該類就是一個namedtuple的資料結構封裝。
cell_state儲存的是AttentionWrapper包裹的RnnCell在t-1時刻的狀態
attention儲存的是t-1時刻輸出的context
time儲存的當前時刻t
alignments儲存的是t-1時刻輸出的alignment
alignment_history儲存的是所有時刻的歷史alignment資訊
AttentionWrapperState中還有一個clone方法,在我們的模型圖中也有呼叫的地方:
initial_state = initial_state。clone(cell_state=encoder_state)
其實就是對我們初始化的AttentionWrapperState物件,將cell_state的屬性值對替換為從encoder輸出的state(經過average pooling)。
下面是AttentionWrapper類的核心方法:call,該方法定義了attention操作的主流程。
該方法的引數為inputs:即decoder中的當前時刻t的輸入,而state則是封裝過的AttentionWrapperState。下面對關鍵程式碼進行註釋:
上述程式碼的主要操作是將當前時刻input和前一時刻的context拼接後,輸入到decoder中的RNN層中做處理。最後輸出output,以及下一個時刻的中間狀態。
主要核心是方法_compute_attention方法,該方法是attention計算的入口。
得到attention的context和alignment資訊後,就是返回需要的資訊,其中比較重要的是返回的當前時刻的中間狀態為AttentionWrapperState,這個中間狀態會被下一時刻t+1的計算使用。
需要注意的是最後返回的時候,有一個flag,_output_attention,這個是控制當前是否返回attention的資訊還是rnn的output資訊,對於BahdanauAttention style來說,是返回cell_output。
3、再往上一層,找到
tf.contrib.seq2seq.BasicDecoder
,這個類主要作用是將上述的所有操作流程按照decoder的序列長度依次按順序執行。
看到它繼承了Decoder這個類。它的核心方法為step方法,下面就basicDecoder的step方法具體描述:
它的引數為time:當前時刻,inputs:decoder的輸入,state:前一時刻傳遞而來的狀態。下面是具體的程式碼流程:
首先這裡的_cell,是我們之前定義的AttentioWrapper,因為它是繼承了RnnCell,因此具有RnnCell的特性。這裡相當於是呼叫做了前兩個大模組的操作。返回了當前時刻的output,state。
那麼在訓練階段,模型是如何推動上面的計算步驟一步一步到最後的呢?下面要用到另一個有用的大禮包的類:
tf.contrib.seq2seq.TrainingHelper
上圖中的helper就是要用的幫助訓練的類(顧名思義)。這裡呼叫了兩個方法,第一個是sample,即根據當前時刻的output,獲取當前時刻詞分佈中機率最大的那個詞id。
第二個是next_inputs方法,主要是根據當前處理的時刻t,讀取下一個時刻的輸入,用於下一時刻的計算,並返回序列處理結束標誌。
step的最後一個步驟就是返回處理的結果,這裡它又封裝了一個特殊的資料型別:BasicDecoderOutput
BasicDecoderOutput定義如下:
其實根AttentioWrapperState相似,也是一個封裝了的namedtuple,主要儲存的是rnn_output,以及最後得到的詞的id。這樣封裝的好處是,rnn_output可以用於計算loss時直接使用,而sample_id則是在inference階段可以用來輸出結果。
4、最高一層是大禮包中最重要的一個部分,上述basicDecoder的step方法,如果沒有任何上層介面驅動,也是無法完成。因此
tf.contrib.seq2seq.dynamic_decode
就是用於完成這項工作的。
與其他介面不同,這個是一個可直接呼叫的方法,其方法定義如下:
其中decoder就是之前定義的basicDecoder。
impute_finished
屬性表示模型在梯度傳遞的時候會忽略最後標記為finished的位置。這個一般設為True,能夠保證梯度正確傳遞。而maximum_iterations為我們自定義的decoding最大長度,可以比設定的title_len大或者小,主要看調參。swap_memory表示在執行while迴圈是否啟用GPU-CPU記憶體交換。
下面只列出該放裡面的核心步驟:
首先這是tensorflow中的迴圈操作。它的迴圈條件condition為:
他會接受basicDecoder返回的finished標誌,並判斷當前是否已經處理結束。
然後是迴圈的body部分,也只放上核心部分:
即呼叫basicDecoder的step方法來執行decoding,這樣就與之前講的聯絡上了。
Inference階段
其實train階段和inference的不同點很簡單,在於inference階段沒有decoder的input,因此每個時刻的state計算都需要輸入前一個時刻的計算結果。這裡以BeamSearch舉例。
tiled_encoder_output
=
tf
。
contrib
。
seq2seq
。
tile_batch
(
self
。
encoder_output
,
multiplier
=
self
。
cfg
。
beam_width
))
tiled_encoder_final_state
=
tf
。
contrib
。
seq2seq
。
tile_batch
(
encoder_state
,
multiplier
=
self
。
cfg
。
beam_width
)
tiled_seq_len
=
tf
。
contrib
。
seq2seq
。
tile_batch
(
self
。
sum_len
,
multiplier
=
self
。
beam_width
)
attention_mechanism
=
tf
。
contrib
。
seq2seq
。
BahdanauAttention
(
self
。
cfg
。
lstm_units
,
tiled_encoder_output
,
memory_sequence_length
=
tiled_seq_len
,
normalize
=
True
)
decoder_cell
=
tf
。
contrib
。
seq2seq
。
AttentionWrapper
(
decoder_cell
,
attention_mechanism
,
attention_layer_size
=
self
。
cfg
。
lstm_units
*
2
)
initial_state
=
decoder_cell
。
zero_state
(
dtype
=
tf
。
float32
,
batch_size
=
self
。
batch_size
*
self
。
cfg
。
beam_width
)
initial_state
=
initial_state
。
clone
(
cell_state
=
tiled_encoder_final_state
)
decoder
=
tf
。
contrib
。
seq2seq
。
BeamSearchDecoder
(
cell
=
decoder_cell
,
embedding
=
self
。
embedding_init
,
start_tokens
=
tf
。
fill
([
self
。
batch_size
],
tf
。
constant
(
2
)),
end_token
=
tf
。
constant
(
3
),
initial_state
=
initial_state
,
beam_width
=
self
。
beam_width
,
output_layer
=
self
。
projection_layer
)
outputs
,
_
,
_
=
outputs
,
_
,
_
=
tf
。
contrib
。
seq2seq
。
dynamic_decode
(
decoder
,
output_time_major
=
false
,
maximum_iterations
=
self
。
decoder_max_iter
,
scope
=
decoder_scope
)
self
。
prediction
=
outputs
。
predicted_ids
整個inference流程與train類似,唯一不同的地方在於beamSearch演算法本身,需要將所有的輸入和中間狀態複製beam_size份,用於beam的搜尋。而主要的區分點在於使用的decoder不同,這裡我就著重講一下
tf.contrib.seq2seq.BeamSearchDecoder。
主要還是貼出其step方法中的核心過程:
beamSearch的方法中,會存在很多merge_beam,split_beam等改變tensor的shape操作,方便一些操作的計算,這裡就不仔細講了,只要記住merge一般和split應是成對出現
。
首先當然是呼叫AttentionWrapper,來計算輸出當前時刻的cell_output,以及下一個時刻的state。
然後就是另一個核心方法呼叫_beam_search_step:
這個方法主要是執行Beam搜尋的流程,核心的模組流程如下:
先計算當前時刻為止的所有候選序列計算機率值之和。
然後計算每個beam的分數:
然後是根據指定的beam_size,使用top_k運算,得到最合適的beam_size候選。
最後就是返回一些封裝的結果,就不具體列出了。
總結
其實還有很多原始碼並沒放到上面分析,鑑於篇幅問題,寫的太多,可能看得也越困難。總體來說,透過這次的比賽實踐,還是對seq2seq模型有一定的深入理解,無論是理論上還是工程實現上,tensorflow的大禮包的實現確實挺漂亮,一環扣一環,希望能吃透它的工程思維,融入到自己的實踐中。
上一篇:明悟之戀~鳳毛麟角