您當前的位置:首頁 > 文化

seq2seq之tensorflow原始碼解析

作者:由 邱震宇 發表于 文化時間:2018-12-19

很直接的說,這篇文章就是對tensorflow的seq2seq大禮包的原始碼做了一定程度的解析。個人一直覺得作為一個演算法工程師,要經常學習好的開源框架裡面的工程程式碼,這樣不僅能夠在實現自定義模型時好下手,也能提升自己的工程能力。本文用到的tensorflow框架版本為1。4。目前比較新的版本中的相關程式碼主體改動有一些,不是很大,所以可以適配著看。

之前的文章中,我已經對seq2seq+attention模型有過講解,當時以為自己弄懂了。但是最近透過一些比賽,實際去實現這個模型的時候,才發覺自己並沒有真正吃透這個模型,因此下面重新對這個帶attention機制的seq2seq模型進行回顧,然後再從tensorflow的seq2seq大禮包原始碼入手,講解其中的步驟。為了更好描述,這裡以最近參加的標題生成的比賽作為例子。(成績不太好,就不寫總結了)

假設當前的輸入資料,已經對out-of-vocab詞用代替,且對decoder的輸入資料增加了標記,對輸出資料末尾增加了標記,最後已經padding成等長資料。模型的主要結構為encoder-decoder。規定輸入和輸出共享一個詞典和embedding層。

Encoder

首先是encoder。主要作用是讀取並學習document text中的資訊。常使用RNN或者帶門機制的RNN(GRU或者LSTM)。一般會先對輸入進行embedding_lookup,獲取對應的embedding_input後,將其輸入到RNN的中進行編碼。一般的encoder會直接輸出最後

t_n

時刻的狀態

h_{t_n}

,作為decoder的初始

t_0

狀態。這裡有個改進點,可以提升整個模型的效能:

之所以可以將encoder的

h_{t_n}

作為decoder的初始狀態,在於經過RNN的學習與傳遞後,最後一個時刻的狀態h已經整合了當前網路對輸入文字的分析和總結,因此直接將其作為decoder的初始狀態,相對於隨機初始化操作,可以加快網路的訓練速度,提升模型的效能。那麼,我們可以透過其他方式,更好得獲得encoder的學習成果,比如將所有t時刻的狀態

h_{t_i}

進行一個average pooling,max pooling,甚至進行一次self attention,最後得到一個綜合的輸出,作為decoder的初始狀態,這樣應該能比baseline方法更好。在實踐過程中,也確實是這樣,其中average pooling的效果最好,而self attention之所以效果較差的原因可能是模型中後面decoder的時候也使用了attention機制,因此這裡使用self attention機制對模型效果提升沒有很大的效果,相反,還會導致過擬合。且average pooling的效率最高,因此最後我選用了這種方式,來輸出decoder的初始狀態。

當前encoder模組有兩個比較重要的輸出會被後續的decoder使用:

1、RNN中的每個時刻的狀態

h_{t_i}

2、RNN中的每個時刻的輸出

o_{t_i}

其中,狀態

h_{t_i}

會作為decoder的初始狀態被使用到,而輸出

o_{t_i}

則會在attention機制計算時使用到。這就是該模型的核心之處。下面用圖表示,會比較清楚。

Decoder

seq2seq之tensorflow原始碼解析

decoder的主體也是使用的RNN,原始輸入經過embedding_lookup(

和encoder共享的

)後,輸出到RNN中,每個t時刻的狀態為

s_t

,其中

s_0

的初始化已經在前面講過,用encoder的相關資訊來初始化。然後就是模型的attention機制的計算和使用。

這裡以Bahdanau Attention計算方式為例。這個attention具體的內容就不講了,有關attention的內容之前文章有描述,這裡直接放上公式:

e^t_{i}=v^Ttanh(W_oo_i+W_ss_t+b_{attn})\\ a^t = softmax(e^t)\\

其中,

v,W_o,W_s,b_attn

均為網路權重引數。

o_i

為參與attention計算的memory,即encoder的output,而

s_t

為decoder在decoder t時刻的狀態。

這裡注意,我使用的是Global attention,相對的還有Local attention,由於Local attention計算方面還要考慮上下文windows,為了便於講解,使用了簡單的Global attention機制,且效果也挺好的。

這個公式表明,我每次要計算輸出當前時刻的decoder output logits時,需要將當前時刻的decoder 狀態整合encoder的output,一起輸入到attention公式中計算,最後得到alignment,這個alignment只與當前時刻t的計算有關,一般是不用儲存的。當然後續應用coverage機制的時候,需要將這個alignment的歷史資訊都先儲存下來。

得到alignment後,接下來就是應用

a^t

到encoder的output中,進行加權求和,得到最終的attention context。這個是融合了encoder,decoder輸入資訊的attention學習成果。

o^*_t = \sum_ia^t_io_i\\

這個學習成果需要和decoder的中間狀態

s_t

整合在一起(一般使用concat,也有其他整合方法),然後輸入到全連線層中,並最終輸出logits,或者輸出一個softmax機率,即詞庫中每個詞在當前位置的機率分佈。最後計算當前時刻的

loss_t

當然,所有時刻的

loss_t

計算完後,需要將loss加起來,並做sequence_mask,將pad位置的loss過濾掉,然後求平均。

Tensorflow seq2seq

上述過程,看著挺簡單的,但是實現起來,其實還是挺麻煩的,因此tensorflow提供了一個方便的大禮包,包含了seq2seq的各個階段的計算介面,只要按照規則正確呼叫,就能搭建出模型。但是這個大禮包,雖然很方便,但是要進行進一步的自定義修改,還是挺難的,需要對裡面的程式碼進行長時間的閱讀和理解,且tensorflow的文件可以說是比較少了,所以這裡把我這段時間閱讀程式碼的總結寫下來,方便大家參考。

大禮包的主要核心部分在decoder側,encoder側,還是使用我們熟悉的RNN的介面,一開始我使用的是CudnnGRU,但是由於其不支援dynamic_rnn,無法對pad位置的資訊特殊處理,因此實際使用效果雖然速度快,但是效果卻沒有原始的gru介面好,因此最後還是用了tf。contrib。rnn。GRUCell(貌似tf。contrib這個包會在後續的版本中逐漸捨棄,所以建議還是用另一個介面)。最後說明一下encoder的輸出為兩部分:

1、encoder_output,shape=[batch_size,time_steps,rnn_dims],為rnn的輸出

2、encoder_state,shape=[batch_size,dims],如果為雙向rnn,需要將兩個方向的state先拼接,然後再對所有時刻的state沿time_steps做average pooling。最後可能還需要對encoder_state做一層全連線層,使得它與decoder的RNN的dimension一致。

decoder的編寫流程

下面我先按照正常的流程使用大禮包編寫decoder的流程。然後按照程式碼的執行順序描述程式碼的執行機制。inference階段我使用beamSearch來舉例。

attention_mechanism

=

tf

contrib

seq2seq

BahdanauAttention

decoder_dim

encoder_output

memory_sequence_length

=

sum_len

normalize

=

True

decoder_cell

=

tf

contrib

seq2seq

AttentionWrapper

decoder_cell

attention_mechanism

attention_layer_size

=

atten_size

initial_state

=

decoder_cell

zero_state

dtype

=

tf

float32

batch_size

=

batch_size

initial_state

=

initial_state

clone

cell_state

=

encoder_state

#train

helper

=

tf

contrib

seq2seq

TrainingHelper

decoder_inputs_embedding

title_len

time_major

=

False

decoder

=

tf

contrib

seq2seq

BasicDecoder

decoder_cell

helper

initial_state

output_layer

=

projection_layer

outputs

self

train_dec_last_state

_

=

tf

contrib

seq2seq

dynamic_decode

decoder

output_time_major

=

False

#logits

decoder_output

=

outputs

rnn_output

上述是訓練過程,下面我們自底向上的順序來說明這個原始碼。

原始碼執行流程

由於下面內容比較繁瑣,因此先po出流程圖,不喜歡看文字的同志可以看這幅圖瞭解一下大概呼叫流程。該圖只是一個整體的流程,很多細節在後面文字描述。

seq2seq之tensorflow原始碼解析

tensorflow seq2seq大禮包模組簡易流程

1、呼叫棧的最底層就是我們的

tf.contrib.seq2seq.BahdanauAttention

,它提供了具體的attention計算方法,原始碼位於

tf.contrib.seq2seq.AttentionWrapper

中。

seq2seq之tensorflow原始碼解析

可以看到它繼承了一個基類,這個基類的主要作用是在_

init

_方法中對memory做了全連線層的計算(如果定義了dense layer的話),得到attention計算中的key和value。而子類則定義了query_layer和memory_layer,以及最後歸一化的方法(一般是softmax)

seq2seq之tensorflow原始碼解析

接下來就是該類的核心方法:

seq2seq之tensorflow原始碼解析

可以看到該方法是一個__call__()方法,熟悉python的同志肯定知道,一個類實現了這個方法,可以讓類變得可以向函式方法一樣被直接呼叫。其中,主要核心就是呼叫了_bahdanau_score方法來進行attention計算。這個方法具體就不貼出來了,就是按照論文中的計算方法。

備註:這裡有一個_probability_fn方法,需要輸入歷史的alignment,這個就是之前說的,有一些attention方法需要使用歷史的alignment資料,因此所有attention方法都預設會呼叫這個_probability_fn方法,只不過在BahdanauAttention中它是不對歷史的alignment做任何操作,直接對score做softmax。

2、往上面一層就是

AttentionWrapper

,這個類負責呼叫整個Attention機制的流程。首先看一下它的類定義:

seq2seq之tensorflow原始碼解析

可以看出來,它繼承了RNNCell,說明這也是一個具有RNN特性的實現類,且在RNNCell上,封裝了Attention的特性。

另外,繼承了RNNCell的所有子類,需要實現call()方法,使得在構建網路計算圖(build())時,能夠呼叫相關邏輯

seq2seq之tensorflow原始碼解析

分析它的__init__方法,可以看到它接收decoder中的RNNCell,之前定義的Attention機制,另外關注一下

alignment_history

,這個屬性就是能控制是否儲存歷史alignment資訊的開關。Attention_layer_size屬性則定義了Attention操作後是否需要連線一個全連線層到指定的dimension。

該類有兩個重要的方法,zero_state和call。前者用於初始狀態s的封裝和生成,後者主要是用於attention計算的主流程。

下面看一下zero_state方法, 其實該方法最重要的就是返回一個封裝過的可用於Attention計算的AttentionState。

seq2seq之tensorflow原始碼解析

下面看一下這個封裝類的結構:

seq2seq之tensorflow原始碼解析

可以看到,其實該類就是一個namedtuple的資料結構封裝。

cell_state儲存的是AttentionWrapper包裹的RnnCell在t-1時刻的狀態

attention儲存的是t-1時刻輸出的context

time儲存的當前時刻t

alignments儲存的是t-1時刻輸出的alignment

alignment_history儲存的是所有時刻的歷史alignment資訊

AttentionWrapperState中還有一個clone方法,在我們的模型圖中也有呼叫的地方:

initial_state = initial_state。clone(cell_state=encoder_state)

其實就是對我們初始化的AttentionWrapperState物件,將cell_state的屬性值對替換為從encoder輸出的state(經過average pooling)。

下面是AttentionWrapper類的核心方法:call,該方法定義了attention操作的主流程。

seq2seq之tensorflow原始碼解析

該方法的引數為inputs:即decoder中的當前時刻t的輸入,而state則是封裝過的AttentionWrapperState。下面對關鍵程式碼進行註釋:

seq2seq之tensorflow原始碼解析

上述程式碼的主要操作是將當前時刻input和前一時刻的context拼接後,輸入到decoder中的RNN層中做處理。最後輸出output,以及下一個時刻的中間狀態。

seq2seq之tensorflow原始碼解析

主要核心是方法_compute_attention方法,該方法是attention計算的入口。

seq2seq之tensorflow原始碼解析

得到attention的context和alignment資訊後,就是返回需要的資訊,其中比較重要的是返回的當前時刻的中間狀態為AttentionWrapperState,這個中間狀態會被下一時刻t+1的計算使用。

seq2seq之tensorflow原始碼解析

需要注意的是最後返回的時候,有一個flag,_output_attention,這個是控制當前是否返回attention的資訊還是rnn的output資訊,對於BahdanauAttention style來說,是返回cell_output。

3、再往上一層,找到

tf.contrib.seq2seq.BasicDecoder

,這個類主要作用是將上述的所有操作流程按照decoder的序列長度依次按順序執行。

seq2seq之tensorflow原始碼解析

看到它繼承了Decoder這個類。它的核心方法為step方法,下面就basicDecoder的step方法具體描述:

seq2seq之tensorflow原始碼解析

它的引數為time:當前時刻,inputs:decoder的輸入,state:前一時刻傳遞而來的狀態。下面是具體的程式碼流程:

seq2seq之tensorflow原始碼解析

首先這裡的_cell,是我們之前定義的AttentioWrapper,因為它是繼承了RnnCell,因此具有RnnCell的特性。這裡相當於是呼叫做了前兩個大模組的操作。返回了當前時刻的output,state。

那麼在訓練階段,模型是如何推動上面的計算步驟一步一步到最後的呢?下面要用到另一個有用的大禮包的類:

tf.contrib.seq2seq.TrainingHelper

seq2seq之tensorflow原始碼解析

上圖中的helper就是要用的幫助訓練的類(顧名思義)。這裡呼叫了兩個方法,第一個是sample,即根據當前時刻的output,獲取當前時刻詞分佈中機率最大的那個詞id。

seq2seq之tensorflow原始碼解析

第二個是next_inputs方法,主要是根據當前處理的時刻t,讀取下一個時刻的輸入,用於下一時刻的計算,並返回序列處理結束標誌。

seq2seq之tensorflow原始碼解析

step的最後一個步驟就是返回處理的結果,這裡它又封裝了一個特殊的資料型別:BasicDecoderOutput

seq2seq之tensorflow原始碼解析

BasicDecoderOutput定義如下:

seq2seq之tensorflow原始碼解析

其實根AttentioWrapperState相似,也是一個封裝了的namedtuple,主要儲存的是rnn_output,以及最後得到的詞的id。這樣封裝的好處是,rnn_output可以用於計算loss時直接使用,而sample_id則是在inference階段可以用來輸出結果。

4、最高一層是大禮包中最重要的一個部分,上述basicDecoder的step方法,如果沒有任何上層介面驅動,也是無法完成。因此

tf.contrib.seq2seq.dynamic_decode

就是用於完成這項工作的。

與其他介面不同,這個是一個可直接呼叫的方法,其方法定義如下:

seq2seq之tensorflow原始碼解析

其中decoder就是之前定義的basicDecoder。

impute_finished

屬性表示模型在梯度傳遞的時候會忽略最後標記為finished的位置。這個一般設為True,能夠保證梯度正確傳遞。而maximum_iterations為我們自定義的decoding最大長度,可以比設定的title_len大或者小,主要看調參。swap_memory表示在執行while迴圈是否啟用GPU-CPU記憶體交換。

下面只列出該放裡面的核心步驟:

seq2seq之tensorflow原始碼解析

首先這是tensorflow中的迴圈操作。它的迴圈條件condition為:

seq2seq之tensorflow原始碼解析

他會接受basicDecoder返回的finished標誌,並判斷當前是否已經處理結束。

然後是迴圈的body部分,也只放上核心部分:

seq2seq之tensorflow原始碼解析

即呼叫basicDecoder的step方法來執行decoding,這樣就與之前講的聯絡上了。

Inference階段

其實train階段和inference的不同點很簡單,在於inference階段沒有decoder的input,因此每個時刻的state計算都需要輸入前一個時刻的計算結果。這裡以BeamSearch舉例。

tiled_encoder_output

=

tf

contrib

seq2seq

tile_batch

self

encoder_output

multiplier

=

self

cfg

beam_width

))

tiled_encoder_final_state

=

tf

contrib

seq2seq

tile_batch

encoder_state

multiplier

=

self

cfg

beam_width

tiled_seq_len

=

tf

contrib

seq2seq

tile_batch

self

sum_len

multiplier

=

self

beam_width

attention_mechanism

=

tf

contrib

seq2seq

BahdanauAttention

self

cfg

lstm_units

tiled_encoder_output

memory_sequence_length

=

tiled_seq_len

normalize

=

True

decoder_cell

=

tf

contrib

seq2seq

AttentionWrapper

decoder_cell

attention_mechanism

attention_layer_size

=

self

cfg

lstm_units

*

2

initial_state

=

decoder_cell

zero_state

dtype

=

tf

float32

batch_size

=

self

batch_size

*

self

cfg

beam_width

initial_state

=

initial_state

clone

cell_state

=

tiled_encoder_final_state

decoder

=

tf

contrib

seq2seq

BeamSearchDecoder

cell

=

decoder_cell

embedding

=

self

embedding_init

start_tokens

=

tf

fill

([

self

batch_size

],

tf

constant

2

)),

end_token

=

tf

constant

3

),

initial_state

=

initial_state

beam_width

=

self

beam_width

output_layer

=

self

projection_layer

outputs

_

_

=

outputs

_

_

=

tf

contrib

seq2seq

dynamic_decode

decoder

output_time_major

=

false

maximum_iterations

=

self

decoder_max_iter

scope

=

decoder_scope

self

prediction

=

outputs

predicted_ids

整個inference流程與train類似,唯一不同的地方在於beamSearch演算法本身,需要將所有的輸入和中間狀態複製beam_size份,用於beam的搜尋。而主要的區分點在於使用的decoder不同,這裡我就著重講一下

tf.contrib.seq2seq.BeamSearchDecoder。

主要還是貼出其step方法中的核心過程:

beamSearch的方法中,會存在很多merge_beam,split_beam等改變tensor的shape操作,方便一些操作的計算,這裡就不仔細講了,只要記住merge一般和split應是成對出現

首先當然是呼叫AttentionWrapper,來計算輸出當前時刻的cell_output,以及下一個時刻的state。

seq2seq之tensorflow原始碼解析

然後就是另一個核心方法呼叫_beam_search_step:

seq2seq之tensorflow原始碼解析

這個方法主要是執行Beam搜尋的流程,核心的模組流程如下:

先計算當前時刻為止的所有候選序列計算機率值之和。

seq2seq之tensorflow原始碼解析

然後計算每個beam的分數:

seq2seq之tensorflow原始碼解析

然後是根據指定的beam_size,使用top_k運算,得到最合適的beam_size候選。

seq2seq之tensorflow原始碼解析

最後就是返回一些封裝的結果,就不具體列出了。

總結

其實還有很多原始碼並沒放到上面分析,鑑於篇幅問題,寫的太多,可能看得也越困難。總體來說,透過這次的比賽實踐,還是對seq2seq模型有一定的深入理解,無論是理論上還是工程實現上,tensorflow的大禮包的實現確實挺漂亮,一環扣一環,希望能吃透它的工程思維,融入到自己的實踐中。

標簽: decoder  Attention  STATE  encoder  seq2seq