Non-Local Neural Network程式碼閱讀
這是挺老的一篇文章了,以前看的時候,因為source code是caffe2的就稍微瞄了一眼,沒有仔細讀,這兩天整理mmdetection的時候發現了有pytorch的實現方法,就決定仔細讀一下原文。所謂的non-local主要argue的點是cnn裡的卷積,主要是一個區域性的weighted sum,並且和全連線層不同的是,fc裡的weight是學出來的,這邊是基於不同位置的relationship計算出來的。仔細看一下這個函式:
g是對於
做的一個變換,f是計算
和
之間的relation,實際上前面這玩意學習的是一個weight,後面是一個value,我們考慮一個最簡單的情況,這個f是一個歐式距離,然後g是一個恆等對映,大概就是其餘周圍畫素對於i的影響為,距離越遠,影響越小,然後本身的值越大,影響越大。C是一個norm項。
文章中考慮了一個最簡單的g就是一個線性embedding
是需要學習的一個weighted matrix,利用一個1x1卷積即可。
考慮了幾種不同的f:
Gaussian
Embedded Gaussian
Dot product
Concat
然後如圖所示
首先,x進來,輸入三個位置,學習這個
以及這個
,然後兩個乘到一塊去,進入一個softmax,就是學完的那個f函式,然後和學習好的g函式的結果,相乘得到最終的結果,經過一個1x1的卷積和原來的輸入相加得到最終的輸出。
下面來看一下mmdetection程式碼:
self
。
g
=
ConvModule
(
in_channels
,
inter_channels
,
kernel_size
=
1
,
activation
=
None
)
self
。
theta
=
ConvModule
(
in_channels
,
inter_channels
,
kernel_size
=
1
,
activation
=
None
)
self
。
phi
=
ConvModule
(
in_channels
,
inter_channels
,
kernel_size
=
1
,
activation
=
None
)
self
。
conv_out
=
ConvModule
(
inter_channels
,
in_channels
,
kernel_size
=
1
,
conv_cfg
,
norm_cfg
,
activation
=
None
)
g,theta,phi通常是nn。Conv2d,定義完這些之後,下面是定義不同的pairwise_weight,這邊主要給了兩種,一種是embedded_gaussian,一種是dot_product
def
embedded_gaussian
(
self
,
theta_x
,
phi_x
):
pairwise_weight
=
torch
。
matmul
(
theta_x
,
phi_x
)
if
self
。
use_scale
:
pairwise_weight
/=
theta_x
。
shape
[
-
1
]
**
0。5
pairwise_weight
=
pairwise_weight
。
softmax
(
dim
=-
1
)
return
pairwise_weight
def
dot_product
(
self
,
theta_x
phi_x
):
pairwise_weight
=
torch
。
matmul
(
theta_x
,
phi_x
)
pairwise_weight
/=
pairwise_weight
。
shape
[
-
1
]
return
pairwise_weight
然後看一下forward:
n
,
_
,
h
,
w
=
x
。
shape
#[N, HxW, C]
g_x
=
self
。
g
(
x
)
。
view
(
n
,
self
。
inter_channels
,
-
1
)
g_x
=
g_x
。
permute
(
0
,
2
,
1
)
# theta_x: [N, HxW, C]
theta_x
=
self
。
theta
(
x
)
。
view
(
n
,
self
。
inter_channels
,
-
1
)
theta_x
=
theta_x
。
permute
(
0
,
2
,
1
)
#phi_x [N,C,HxW]
phi_x
=
self
。
phi
(
x
)
。
view
(
n
,
self
。
inter_channels
,
-
1
)
pairwise_weight
=
pairwise_func
(
theta_x
,
phi_x
)
#y: [N, HxW, C]
y
=
torch
。
matmul
(
pairwise_weight
,
g_x
)
y
=
y
。
permute
(
0
,
2
,
1
)
。
reshape
(
n
,
self
。
inter_channels
,
h
,
w
)
output
=
x
+
self
。
conv_out
(
y
)
return
output
self-attention裡比較好的工作,但是由於theta,phi的計算,其實很吃視訊記憶體,我以前在我的網路里加過,很容易out of memory,尤其是低版本的pytorch