您當前的位置:首頁 > 書法

Pytorch提取中間層的兩種方式:重寫模型和hook機制

作者:由 格子不太方 發表于 書法時間:2021-05-28

深度學習領域,很多方法通常需要對網路的中間層進行處理,本文介紹兩種提取中間層feature的方法。

一、重寫模型:在需要提取feature的地方增加變數

talk is cheap, show me the code ,直接看下面程式碼即可理解

import

torch

class

Model

torch

nn

Module

):

def

__init__

self

):

super

()

__init__

()

self

feas

=

[]

self

layer1

=

。。。

self

layer2

=

。。。

def

forward

self

x

):

x

=

self

layer1

x

self

feas

append

x

# 在你想要獲取feature的地方儲存該feature即可

x

=

self

layer2

x

return

x

model

=

Model

()

pred

=

model

data

feas

=

model

feas

# 模型forward後,就可以這樣得到中間層的feature

上述方式看起來是比較簡單的,但是它有部分限制:

首先這種方式需要改變model定義時的程式碼結構; 現實中我們其實不太希望動這些程式碼;

只適合簡單網路的改寫。對於包裝較多的複雜網路(使用大量子模組,sequential)等寫法的改寫,是十分困難的。比如改寫torchvision中的resnet和mobilenet就比較難;

靈活性十分差:比如讓你提取10個網路中BN層前的特徵圖,那麼你需要改寫10個網路的定義程式碼,還要找出全部BN層位置;

文章出處: 格子不太方:Pytorch提取中間層的兩種方式:重寫模型和hook機制

二、hook機制

方式1雖然在某些情況下可以迅速實現提取中間層的功能,但對於複雜網路、多個網路情況時,你會發現並不是那麼輕鬆。這時候更建議學一下torch的hook機制。

初學者通常都不願碰hook機制,覺得這東西麻煩。但事實上,它只是起了一個看起來很“奇怪”的名字。說白了hook就是torch的某種機制:可以在不修改原始碼的情況下,掛上帶有新功能的程式碼分支。“hook 鉤子”這個名字也很形象。

此外hook也不麻煩,只是很多教程把它寫的難看懂了。

使用hook只需要做兩件事:

1。 定義一個對feature進行處理的函式,比如叫hook_fun

2。 註冊hook:告訴模型,我將在哪些層使用hook_fun來處理feature

talk is cheap, show me the code,直接看程式碼:

import

torch

from

torch。nn

import

Conv2d

Linear

AdaptiveAvgPool2d

class

Model

torch

nn

Module

):

def

__init__

self

):

super

()

__init__

()

self

conv1

=

Conv2d

in_channels

=

3

out_channels

=

32

kernel_size

=

3

self

layer1

=

Linear

in_features

=

32

out_features

=

64

self

avgpool

=

AdaptiveAvgPool2d

1

def

forward

self

x

):

x

=

self

conv1

x

x

=

self

avgpool

x

x

=

torch

flatten

x

1

x

=

self

layer1

x

return

x

model

=

Model

()

‘’‘

定義好模型後,假設我們提取 avgpool前的feature,即conv1後的feature:

’‘’

# 這裡定義了一個類,類有一個接收feature的函式hook_fun。定義類是為了方便提取多箇中間層。

class

HookTool

def

__init__

self

):

self

fea

=

None

def

hook_fun

self

module

fea_in

fea_out

):

‘’‘

注意用於處理feature的hook函式必須包含三個引數[module, fea_in, fea_out],引數的名字可以自己起,但其意義是

固定的,第一個引數表示torch裡的一個子module,比如Linear,Conv2d等,第二個引數是該module的輸入,其型別是

tuple;第三個引數是該module的輸出,其型別是tensor。注意輸入和輸出的型別是不一樣的,切記。

’‘’

self

fea

=

fea_out

def

get_feas_by_hook

model

):

“”“

提取Conv2d後的feature,我們需要遍歷模型的module,然後找到Conv2d,把hook函式註冊到這個module上;

這就相當於告訴模型,我要在Conv2d這一層,用hook_fun處理該層輸出的feature。

由於一個模型中可能有多個Conv2d,所以我們要用hook_feas儲存下來每一個Conv2d後的feature

”“”

fea_hooks

=

[]

for

n

m

in

model

named_modules

():

if

isinstance

m

torch

nn

Conv2d

):

cur_hook

=

HookTool

()

m

register_forward_hook

cur_hook

hook_fun

fea_hooks

append

cur_hook

return

fea_hooks

fea_hooks

=

get_feas_by_hook

model

# 呼叫函式,完成註冊即可

x

=

torch

randn

([

32

3

224

224

])

out

=

model

x

print

‘The number of hooks is:’

len

fea_hooks

print

‘The shape of the first Conv2D feature is:’

fea_hooks

0

fea

shape

hook的基本使用就像上面那樣,你會發現即便再有別的模型,也無需更改模型程式碼,可以直接提取某些層。當然了,也有一些注意事項:

如果某個model中定義了很多Conv2d,但你只想提取一部分。此時可能需要你寫更多的條件來篩選出想要的module;

hook不單單只是register_forward_hook,還有register_backward_hook等;

假設網路三個連續層分別是a——>b——>c,你想提取b的輸出,有兩種hook_fun寫法,一種是提取b層的fea_out,另一種是提取c層的fea_in。這是因為b的輸出是c的輸入。但要注意,fea_in和fea_out的型別不同。;

模型每次前傳後,fea_hooks中提取的feature都會隨之改變。

文章出處:格子不太方:Pytorch提取中間層的兩種方式:重寫模型和hook機制

收藏是自己,點贊是鼓勵~

Pytorch提取中間層的兩種方式:重寫模型和hook機制

標簽: hook  self  fea  Feature  __