Pytorch提取中間層的兩種方式:重寫模型和hook機制
深度學習領域,很多方法通常需要對網路的中間層進行處理,本文介紹兩種提取中間層feature的方法。
一、重寫模型:在需要提取feature的地方增加變數
talk is cheap, show me the code ,直接看下面程式碼即可理解
import
torch
class
Model
(
torch
。
nn
。
Module
):
def
__init__
(
self
):
super
()
。
__init__
()
self
。
feas
=
[]
self
。
layer1
=
。。。
self
。
layer2
=
。。。
def
forward
(
self
,
x
):
x
=
self
。
layer1
(
x
)
self
。
feas
。
append
(
x
)
# 在你想要獲取feature的地方儲存該feature即可
x
=
self
。
layer2
(
x
)
return
x
model
=
Model
()
pred
=
model
(
data
)
feas
=
model
。
feas
# 模型forward後,就可以這樣得到中間層的feature
上述方式看起來是比較簡單的,但是它有部分限制:
首先這種方式需要改變model定義時的程式碼結構; 現實中我們其實不太希望動這些程式碼;
只適合簡單網路的改寫。對於包裝較多的複雜網路(使用大量子模組,sequential)等寫法的改寫,是十分困難的。比如改寫torchvision中的resnet和mobilenet就比較難;
靈活性十分差:比如讓你提取10個網路中BN層前的特徵圖,那麼你需要改寫10個網路的定義程式碼,還要找出全部BN層位置;
文章出處: 格子不太方:Pytorch提取中間層的兩種方式:重寫模型和hook機制
二、hook機制
方式1雖然在某些情況下可以迅速實現提取中間層的功能,但對於複雜網路、多個網路情況時,你會發現並不是那麼輕鬆。這時候更建議學一下torch的hook機制。
初學者通常都不願碰hook機制,覺得這東西麻煩。但事實上,它只是起了一個看起來很“奇怪”的名字。說白了hook就是torch的某種機制:可以在不修改原始碼的情況下,掛上帶有新功能的程式碼分支。“hook 鉤子”這個名字也很形象。
此外hook也不麻煩,只是很多教程把它寫的難看懂了。
使用hook只需要做兩件事:
1。 定義一個對feature進行處理的函式,比如叫hook_fun
2。 註冊hook:告訴模型,我將在哪些層使用hook_fun來處理feature
talk is cheap, show me the code,直接看程式碼:
import
torch
from
torch。nn
import
Conv2d
,
Linear
,
AdaptiveAvgPool2d
class
Model
(
torch
。
nn
。
Module
):
def
__init__
(
self
):
super
()
。
__init__
()
self
。
conv1
=
Conv2d
(
in_channels
=
3
,
out_channels
=
32
,
kernel_size
=
3
)
self
。
layer1
=
Linear
(
in_features
=
32
,
out_features
=
64
)
self
。
avgpool
=
AdaptiveAvgPool2d
(
1
)
def
forward
(
self
,
x
):
x
=
self
。
conv1
(
x
)
x
=
self
。
avgpool
(
x
)
x
=
torch
。
flatten
(
x
,
1
)
x
=
self
。
layer1
(
x
)
return
x
model
=
Model
()
‘’‘
定義好模型後,假設我們提取 avgpool前的feature,即conv1後的feature:
’‘’
# 這裡定義了一個類,類有一個接收feature的函式hook_fun。定義類是為了方便提取多箇中間層。
class
HookTool
:
def
__init__
(
self
):
self
。
fea
=
None
def
hook_fun
(
self
,
module
,
fea_in
,
fea_out
):
‘’‘
注意用於處理feature的hook函式必須包含三個引數[module, fea_in, fea_out],引數的名字可以自己起,但其意義是
固定的,第一個引數表示torch裡的一個子module,比如Linear,Conv2d等,第二個引數是該module的輸入,其型別是
tuple;第三個引數是該module的輸出,其型別是tensor。注意輸入和輸出的型別是不一樣的,切記。
’‘’
self
。
fea
=
fea_out
def
get_feas_by_hook
(
model
):
“”“
提取Conv2d後的feature,我們需要遍歷模型的module,然後找到Conv2d,把hook函式註冊到這個module上;
這就相當於告訴模型,我要在Conv2d這一層,用hook_fun處理該層輸出的feature。
由於一個模型中可能有多個Conv2d,所以我們要用hook_feas儲存下來每一個Conv2d後的feature
”“”
fea_hooks
=
[]
for
n
,
m
in
model
。
named_modules
():
if
isinstance
(
m
,
torch
。
nn
。
Conv2d
):
cur_hook
=
HookTool
()
m
。
register_forward_hook
(
cur_hook
。
hook_fun
)
fea_hooks
。
append
(
cur_hook
)
return
fea_hooks
fea_hooks
=
get_feas_by_hook
(
model
)
# 呼叫函式,完成註冊即可
x
=
torch
。
randn
([
32
,
3
,
224
,
224
])
out
=
model
(
x
)
(
‘The number of hooks is:’
,
len
(
fea_hooks
)
(
‘The shape of the first Conv2D feature is:’
,
fea_hooks
[
0
]
。
fea
。
shape
)
hook的基本使用就像上面那樣,你會發現即便再有別的模型,也無需更改模型程式碼,可以直接提取某些層。當然了,也有一些注意事項:
如果某個model中定義了很多Conv2d,但你只想提取一部分。此時可能需要你寫更多的條件來篩選出想要的module;
hook不單單只是register_forward_hook,還有register_backward_hook等;
假設網路三個連續層分別是a——>b——>c,你想提取b的輸出,有兩種hook_fun寫法,一種是提取b層的fea_out,另一種是提取c層的fea_in。這是因為b的輸出是c的輸入。但要注意,fea_in和fea_out的型別不同。;
模型每次前傳後,fea_hooks中提取的feature都會隨之改變。
文章出處:格子不太方:Pytorch提取中間層的兩種方式:重寫模型和hook機制
收藏是自己,點贊是鼓勵~