您當前的位置:首頁 > 繪畫

Non-Local Neural Network程式碼閱讀

作者:由 ppgod 發表于 繪畫時間:2020-02-07

這是挺老的一篇文章了,以前看的時候,因為source code是caffe2的就稍微瞄了一眼,沒有仔細讀,這兩天整理mmdetection的時候發現了有pytorch的實現方法,就決定仔細讀一下原文。所謂的non-local主要argue的點是cnn裡的卷積,主要是一個區域性的weighted sum,並且和全連線層不同的是,fc裡的weight是學出來的,這邊是基於不同位置的relationship計算出來的。仔細看一下這個函式:

y_i = \frac{1}{C(x)}\sum_{\forall i}f(x_i,x_j)g(x_j)\\

g是對於

x_j

做的一個變換,f是計算

x_i

x_j

之間的relation,實際上前面這玩意學習的是一個weight,後面是一個value,我們考慮一個最簡單的情況,這個f是一個歐式距離,然後g是一個恆等對映,大概就是其餘周圍畫素對於i的影響為,距離越遠,影響越小,然後本身的值越大,影響越大。C是一個norm項。

文章中考慮了一個最簡單的g就是一個線性embedding

g(x_j) = W_g x_j\\

W_g

是需要學習的一個weighted matrix,利用一個1x1卷積即可。

考慮了幾種不同的f:

Gaussian

f(x_i, x_j) = e^{x_i^Tx_j}\\

Embedded Gaussian

f(x_i, x_j) = e^{\theta{(x_i^T)}\theta(x_j)}\\  \theta(x_i) = W_{\theta}x_i\\

Dot product

f(x_i, x_j) = \theta{(x_i^T)}\theta(x_j)\\

Concat

f(x_i, x_j) = ReLU(w_f^T[\theta{(x_i^T)},\theta(x_j)])\\

然後如圖所示

Non-Local Neural Network程式碼閱讀

首先,x進來,輸入三個位置,學習這個

\theta

以及這個

\phi

,然後兩個乘到一塊去,進入一個softmax,就是學完的那個f函式,然後和學習好的g函式的結果,相乘得到最終的結果,經過一個1x1的卷積和原來的輸入相加得到最終的輸出。

下面來看一下mmdetection程式碼:

self

g

=

ConvModule

in_channels

inter_channels

kernel_size

=

1

activation

=

None

self

theta

=

ConvModule

in_channels

inter_channels

kernel_size

=

1

activation

=

None

self

phi

=

ConvModule

in_channels

inter_channels

kernel_size

=

1

activation

=

None

self

conv_out

=

ConvModule

inter_channels

in_channels

kernel_size

=

1

conv_cfg

norm_cfg

activation

=

None

g,theta,phi通常是nn。Conv2d,定義完這些之後,下面是定義不同的pairwise_weight,這邊主要給了兩種,一種是embedded_gaussian,一種是dot_product

def

embedded_gaussian

self

theta_x

phi_x

):

pairwise_weight

=

torch

matmul

theta_x

phi_x

if

self

use_scale

pairwise_weight

/=

theta_x

shape

-

1

**

0。5

pairwise_weight

=

pairwise_weight

softmax

dim

=-

1

return

pairwise_weight

def

dot_product

self

theta_x

phi_x

):

pairwise_weight

=

torch

matmul

theta_x

phi_x

pairwise_weight

/=

pairwise_weight

shape

-

1

return

pairwise_weight

然後看一下forward:

n

_

h

w

=

x

shape

#[N, HxW, C]

g_x

=

self

g

x

view

n

self

inter_channels

-

1

g_x

=

g_x

permute

0

2

1

# theta_x: [N, HxW, C]

theta_x

=

self

theta

x

view

n

self

inter_channels

-

1

theta_x

=

theta_x

permute

0

2

1

#phi_x [N,C,HxW]

phi_x

=

self

phi

x

view

n

self

inter_channels

-

1

pairwise_weight

=

pairwise_func

theta_x

phi_x

#y: [N, HxW, C]

y

=

torch

matmul

pairwise_weight

g_x

y

=

y

permute

0

2

1

reshape

n

self

inter_channels

h

w

output

=

x

+

self

conv_out

y

return

output

self-attention裡比較好的工作,但是由於theta,phi的計算,其實很吃視訊記憶體,我以前在我的網路里加過,很容易out of memory,尤其是低版本的pytorch

標簽: self  Theta  weight  channels  phi