
【Paper Analysis 2】The Self-Attention Mechanism

By PurJoy · 2019-07-08

This post is based on the paper A Structured Self-Attentive Sentence Embedding. These are just my own study notes; discussion and corrections are welcome.

The Self-Attention mechanism:

1. The core is two linear transformations:

A = softmax(W_s2 · tanh(W_s1 · H^T))

2. The two weight matrices:

W_s1: [d_a, 2*hiddenSize], mapping the 2*hiddenSize-dimensional BiLSTM states down to d_a dimensions

W_s2: [r, d_a], mapping the d_a-dimensional intermediate representation to r attention scores

(The same matrices are shared across the whole batch, so they carry no batch dimension.)

Explanation: r is the number of attention distributions (hops); the paper recommends using at least 2. d_a is an intermediate hyperparameter and can be any size.

The heart of the paper is W_s1 and W_s2; once you understand these two matrices, the rest follows.
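To make the shapes concrete, here is a minimal sketch of the two maps applied to the BiLSTM output H. All sizes (batch_size=4, seq_len=20, hidden_size=50, d_a=10, r=5) are arbitrary example values, not values prescribed by the paper:

import torch
import torch.nn as nn

batch_size, seq_len, hidden_size = 4, 20, 50   # example sizes only
d_a, r = 10, 5

H = torch.randn(batch_size, seq_len, 2 * hidden_size)   # stands in for the BiLSTM output
W_s1 = nn.Linear(2 * hidden_size, d_a, bias=False)       # first linear map
W_s2 = nn.Linear(d_a, r, bias=False)                     # second linear map

# softmax over the sequence dimension: each of the r hops is a distribution over tokens
A = torch.softmax(W_s2(torch.tanh(W_s1(H))), dim=1)
print(A.shape)          # torch.Size([4, 20, 5])
print(A.sum(dim=1)[0])  # each of the r columns sums to 1

Transposing A to [batch_size, r, seq_len] and multiplying it with H gives M = A·H of shape [batch_size, r, 2*hidden_size], which is exactly what the SelfAttention module further down computes.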

import torch
import torch.nn as nn
import torch.nn.functional as F


# Input encoder: a plain bidirectional LSTM, simple enough that it needs no extra explanation
class EncoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, gpu):
        super(EncoderRNN, self).__init__()
        self.gpu = gpu
        self.hidden_size = hidden_size
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.lstm = nn.LSTM(self.embed_size, self.hidden_size,
                            batch_first=True, bidirectional=True)

    def init_hidden(self, batch_size):
        # 2 = num_layers * num_directions for a single-layer BiLSTM
        h0 = torch.zeros(2, batch_size, self.hidden_size)
        c0 = torch.zeros(2, batch_size, self.hidden_size)
        if self.gpu:
            h0 = h0.cuda()
            c0 = c0.cuda()
        return h0, c0

    def forward(self, sentences):
        batch_size = sentences.size()[0]
        h0, c0 = self.init_hidden(batch_size)
        embed = self.embed(sentences)
        # output: [batch_size, seq_len, 2 * hidden_size]
        output, (hn, cn) = self.lstm(embed, (h0, c0))
        return output
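A quick usage sketch for the encoder; the vocabulary size, batch size, and sequence length below are made-up example values:

# Check the output shape of the encoder
encoder = EncoderRNN(embed_size=100, hidden_size=50, vocab_size=1000, gpu=False)
sentences = torch.randint(0, 1000, (4, 20))   # [batch_size, seq_len] of token ids
H = encoder(sentences)
print(H.shape)   # torch.Size([4, 20, 100]) = [batch_size, seq_len, 2 * hidden_size]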

# The self-attention code below is my own implementation, for reference only
# Unless noted otherwise, the parameters and comments in the code follow the paper
class SelfAttention(nn.Module):
    def __init__(self, hidden_size, num_class):
        super(SelfAttention, self).__init__()
        self.labels = num_class
        self.hidden_size = hidden_size
        self.attention = nn.Sequential(
            # Corresponds to the paper's weight matrix W_s1; 10 here is d_a
            nn.Linear(2 * self.hidden_size, 10),
            nn.Tanh(),
            # Corresponds to the paper's weight matrix W_s2; 5 here is r
            nn.Linear(10, 5)
        )
        self.output = nn.Linear(self.hidden_size * 2, self.labels)

    def forward(self, encode_output):
        # Self-attention weight matrix A: atte_weight = A = [batch_size, r, seq_len]
        # softmax is taken along the sequence dimension, so each of the r hops
        # is a probability distribution over the tokens
        atte_weight = F.softmax(self.attention(encode_output), dim=1).permute(0, 2, 1)
        # Weighted sum of the hidden states, M:
        # torch.bmm([batch_size, r, seq_len], [batch_size, seq_len, 2H])
        # => [batch_size, r, 2H] => sum over r => [batch_size, 2H] = output
        output = torch.sum(torch.bmm(atte_weight, encode_output), dim=1)
        # The final fully connected layer and softmax; strictly speaking this is
        # no longer part of the self-attention framework itself
        result = F.softmax(self.output(output), dim=1)
        return result
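Finally, an end-to-end sketch chaining the two modules; the sizes and num_class=3 are again just example values:

encoder = EncoderRNN(embed_size=100, hidden_size=50, vocab_size=1000, gpu=False)
attention = SelfAttention(hidden_size=50, num_class=3)

sentences = torch.randint(0, 1000, (4, 20))   # [batch_size, seq_len]
H = encoder(sentences)                        # [4, 20, 100]
probs = attention(H)                          # [4, 3], each row sums to 1
print(probs.shape)

Note that this implementation sums the r attention hops into a single 2*hidden_size vector before the classifier, whereas the paper keeps the sentence embedding M as an r × 2u matrix (flattened for the downstream task); both are ways of collapsing the multiple hops.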

