Pytorch_hook機制的理解及利用register_forward_hook(hook)中間層輸出

參考文獻:

【1】梯度計算問題含公式:參考鏈接1.

【2】pytorch改動和.data和.detch()問題:https://blog.csdn.net/dss_dssssd/article/details/83818181

【3】hook技術介紹:https://www.cnblogs.com/hellcat/p/8512090.html

【4】hook應用->中間層的輸出:https://blog.csdn.net/qq_40303258/article/details/106884317

【5】hook函數介紹:參考鏈接2

需要了解的基本點:

(1)backward()是Pytorch中用來求梯度的方法。

(2)Variable是對tensor的封裝,包含了三部分:

  •  .data:tensor本身
  • .grad:對應tensor的梯度
  • .grad_fn:該Variable是通過什麼方式獲得的

(3)pytorch 0.4版本後將tensor和Variable合併在了一起。

x = Variable(torch.randn(2, 1), requires_grad=True) # 利用Variable封裝tensor
##等效 x = torch.rand(2,1,requires_grad=True)
x = torch.rand(2,1) # 不等效

(4)hook種類分爲兩種

Tensor級別  register_hook(hook) ->爲Tensor註冊一個backward hook,用來獲取變量的梯度;hook必須遵循如下的格式:hook(grad) -> Tensor or None

nn.Module對象 register_forward_hook(hook)register_backward_hook(hook)兩種方法,分別對應前向傳播和反向傳播的hook函數。

(5)hook作用:獲取某些變量的中間結果的。Pytorch會自動捨棄圖計算的中間結果,所以想要獲取這些數值就需要使用hook函數。hook函數在使用後應及時刪除,以避免每次都運行鉤子增加運行負載。

舉例說明 Tensor級別  :

例子1(借鑑參考文獻1和3)

import torch 
from torch.autograd import Variable 


def print_grad(grad):
    print('grad is \n',grad)
 
x = Variable(torch.randn(2, 1), requires_grad=True)
## x = torch.rand(2,1,requires_grad=True) #  等效
print('x value is \n',x)
y = x+3
print('y value is \n',y)
z = torch.mean(torch.pow(y, 1/2))
lr = 1e-3

y.register_hook(print_grad) 
z.backward() # 梯度求解
x.data -= lr*x.grad.data
print('new x is\n',x)
output:
x value is 
 tensor([[ 2.5474],
        [-1.1597]], requires_grad=True)
y value is 
 tensor([[5.5474],
        [1.8403]], grad_fn=<AddBackward0>)
grad is 
 tensor([[0.1061],
        [0.1843]])
new x is
 tensor([[ 2.5473],
        [-1.1599]], requires_grad=True)

分析:

對於z來說,求梯度最終求解的是對x的梯度(導數,偏導),因此y是一箇中間變量。因此可以用register_hook()來獲取其作爲中間值的導數,否則z對於y的偏導是獲取不到的。x的偏導和y的偏導實際上是相同值,推導如下圖。

不用register_hook()的例子。

#y.register_hook(print_grad) 

z.backward() # 梯度求解
print('y\'s grad is ',y.grad)
print('x\'s grad is \n',x.grad)
x.data -= lr*x.grad.data
print('new x is\n',x)

output:
y's grad is  None
x's grad is 
 tensor([[0.1544],
        [0.1099]])
new x is
 tensor([[-0.3801],
        [ 2.1755]], requires_grad=True)

可以看出,z對於x的grad是存在的,但是z對於中間變量y的grad是不存在的。也就驗證了Pytorch會自動捨棄圖計算的中間結果這句話。

舉例說明 Module級別 

【1】register_forward_hook(hook)

在網絡執行forward()之後,執行hook函數,需要具有如下的形式:

hook(module, input, output) -> None or modified output

hook可以修改input和output,但是不會影響forward的結果。最常用的場景是需要提取模型的某一層(不是最後一層)的輸出特徵,但又不希望修改其原有的模型定義文件,這時就可以利用forward_hook函數。

import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16*5*5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))     #1 
        out = F.max_pool2d(out, 2)      #2
        out = F.relu(self.conv2(out))   #3
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

features = []
def hook(module, input, output): 
    # module: model.conv2 
    # input :in forward function  [#2]
    # output:is  [#3 self.conv2(out)]
    features.append(output.clone().detach())
    # output is saved  in a list 


net = LeNet() ## 模型實例化 
x = torch.randn(2, 3, 32, 32) ## input 
handle = net.conv2.register_forward_hook(hook) ## 獲取整個Lenet模型 conv2的中間結果
y = net(x)  ## 獲取的是 關於 input x 的 conv2 結果 

print(features[0].size()) # 即 [#3 self.conv2(out)]
handle.remove() ## hook刪除 

以上文字和代碼示例,均來自參考文獻5中的示例,由於示例對於register_forward_hook(hook)沒有過多註解,因此我加了一些註解。

個人理解:register_forward_hook(hook) 作用就是(假設想要conv2層),那麼就是根據 model(該層),該層input,該層output,可以將 output獲取。

register_forward_hook(hook)  最大的作用也就是當訓練好某個model,想要展示某一層對最終目標的影響效果。

例子:【借鑑參考文獻4】

class LayerActivations:
    features = None
    def __init__(self, model, layer_num):
        self.hook = model[layer_num].register_forward_hook(self.hook_fn)
        # 獲取model.features中某一層的output
    
    def hook_fn(self, module, input, output):
        self.features = output.cpu()
 
    def remove(self): ## 刪除hook
        self.hook.remove()


''' 類似於以下格式
class CNNnet1(torch.nn.Module): ## wangluo jiegou  
    def __init__(self):
        super(CNNnet1,self).__init__()
        self.features = nn.Sequential(
            nn.Conv1d(),  
            torch.nn.ReLU(),
            torch.nn.Conv1d(),
            torch.nn.ReLU(),
            torch.nn.Conv1d(),
            torch.nn.BatchNorm1d(),
            torch.nn.MaxPool1d()
            torch.nn.ReLU()
        ) 
'''     
#### model= CNN()
#### train(model,train_loader,learning_rate,batch_size,epochs)
#### 
model.eval() 
test_dataset = DataSet(test_features, test_labels) 
test_loader = DataLoader(test_dataset,batch_size=1,shuffle=True)
        
img = next(iter(test_loader))[0] # gain a input 

for i in range(len(model.features)): # model.features is a nn.Sequential()
    conv_out = LayerActivations(model.features,i) # 實例化,獲取每一層
    ouput = model(img)
    act = conv_out.features # gain the ith output
    conv_out.remove # delete the hook

    plt.imshow(act[0].detach().numpy(),cmap='hot') # output is showed using 熱力圖 
    plt.colorbar(shrink=0.4) # 句柄大小
    plt.show() 

大概畫完了就是這個樣子[每一層都有一個圖,不做過多展示]:

其中 plt.imshow()是熱力圖畫法,詳情點擊鏈接。可以把參考文獻4中是將所有的中間層畫到了一張畫布上,因爲卷積層尺寸不同,我就沒放在一起。

[2]register_backward_hook(hook)

因爲暫時沒有用到,不做詳細講解,具體可參考參考文獻5。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章