
Implementing (Multi-Head) Self-Attention in PyTorch

Author: 休提前事 · Published 2022-03-23

We have already introduced the principle of the attention mechanism and the steps for implementing it.

Because the attention here is based on the dot product (there are other formulations, such as additive attention, which you can look up on your own), Q and K must share the same dimension. The PyTorch implementation is therefore straightforward; here is the code.

Self-attention implementation

from math import sqrt

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    dim_in: int
    dim_k: int
    dim_v: int

    def __init__(self, dim_in, dim_k, dim_v):
        super(SelfAttention, self).__init__()
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)  # Q and K share the same dimension
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k)  # scaling factor that keeps the variance of Q @ K^T in a reasonable range

    def forward(self, x):
        # x: (batch, n, dim_in) -> (batch size, sequence length, feature dimension)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        q = self.linear_q(x)  # batch, n, dim_k
        k = self.linear_k(x)  # batch, n, dim_k
        v = self.linear_v(x)  # batch, n, dim_v

        dist = torch.bmm(q, k.transpose(1, 2)) * self._norm_fact  # batch, n, n
        dist = torch.softmax(dist, dim=-1)                        # batch, n, n

        att = torch.bmm(dist, v)  # batch, n, dim_v
        return att
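As a quick sanity check, the module above can be run on a random input tensor. The concrete sizes below (batch 2, sequence length 5, dim_in 16, dim_k 32, dim_v 64) are illustrative assumptions, not values from the article:

import torch

x = torch.rand(2, 5, 16)                             # (batch, n, dim_in)
attn = SelfAttention(dim_in=16, dim_k=32, dim_v=64)
out = attn(x)
print(out.shape)                                     # torch.Size([2, 5, 64]) -> (batch, n, dim_v)

The output keeps the batch size and sequence length of the input, while the feature dimension becomes dim_v.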

Multi-head self-attention implementation

We have already introduced that multi-head attention computes several heads and, at the end, concatenates them and multiplies by a weight matrix to obtain the final result. In practice, to speed up the computation, we can compute Q, K and V for all heads with one large matrix multiplication, then use reshaping and dimension transposition to place the Q, K and V of every head into the same batch dimension, carry out exactly the same computation as single-head attention, and finally concatenate the attention vectors of all heads to obtain the final value.
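Before the full module, here is a minimal sketch of just the reshape-and-transpose step described above; the concrete sizes (batch 2, sequence length 5, dim_k 32, 4 heads) are assumptions made up for illustration:

import torch

batch, n, dim_k, nh = 2, 5, 32, 4
dk = dim_k // nh                   # per-head key/query dimension
q = torch.rand(batch, n, dim_k)    # stand-in for the output of the shared Q projection
q = q.reshape(batch, n, nh, dk)    # split the feature dimension into (heads, per-head dim)
q = q.transpose(1, 2)              # (batch, nh, n, dk): each head now acts like an extra batch entry
print(q.shape)                     # torch.Size([2, 4, 5, 8])

The same manipulation is applied to K and V, after which the scaled dot-product computation of the single-head case can be reused unchanged.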

from math import sqrt

import torch
import torch.nn as nn


class MultiHeadSelfAttention(nn.Module):
    dim_in: int     # input dimension
    dim_k: int      # key and query dimension
    dim_v: int      # value dimension
    num_heads: int  # number of heads, for each head, dim_* = dim_* // num_heads

    def __init__(self, dim_in, dim_k, dim_v, num_heads=8):
        super(MultiHeadSelfAttention, self).__init__()
        assert dim_k % num_heads == 0 and dim_v % num_heads == 0, \
            "dim_k and dim_v must be multiple of num_heads"
        self.dim_in = dim_in
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.num_heads = num_heads
        self.linear_q = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_k = nn.Linear(dim_in, dim_k, bias=False)
        self.linear_v = nn.Linear(dim_in, dim_v, bias=False)
        self._norm_fact = 1 / sqrt(dim_k // num_heads)

    def forward(self, x):
        # x: tensor of shape (batch, n, dim_in)
        batch, n, dim_in = x.shape
        assert dim_in == self.dim_in

        nh = self.num_heads
        dk = self.dim_k // nh  # dim_k of each head
        dv = self.dim_v // nh  # dim_v of each head

        q = self.linear_q(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        k = self.linear_k(x).reshape(batch, n, nh, dk).transpose(1, 2)  # (batch, nh, n, dk)
        v = self.linear_v(x).reshape(batch, n, nh, dv).transpose(1, 2)  # (batch, nh, n, dv)

        dist = torch.matmul(q, k.transpose(2, 3)) * self._norm_fact  # batch, nh, n, n
        dist = torch.softmax(dist, dim=-1)                           # batch, nh, n, n

        att = torch.matmul(dist, v)                               # batch, nh, n, dv
        att = att.transpose(1, 2).reshape(batch, n, self.dim_v)   # batch, n, dim_v
        return att
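As with the single-head module, a short usage check can confirm the shapes. The sizes below are illustrative assumptions; the only real constraint is that dim_k and dim_v are divisible by num_heads:

import torch

x = torch.rand(2, 5, 16)                                                   # (batch, n, dim_in)
mhsa = MultiHeadSelfAttention(dim_in=16, dim_k=32, dim_v=64, num_heads=8)
out = mhsa(x)
print(out.shape)                                                           # torch.Size([2, 5, 64]) -> (batch, n, dim_v)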
