實例化了一個對象後直接把這個對象當作一個函數調用 直接調用forward() 解決一個問題

原問題 主要我剛剛註冊那個 還沒法回覆 所以貼過來解一下

Python 語法問題,實例化了一個對象後直接把這個對象當作一個函數調用,這是什麼語法?算是函數式編程嗎?
~~~~

最近正在看 PyTorch,在 PyTorch 裏看到了一種寫法感覺很酷,但有點疑惑,遂來請教,感謝各位大佬答疑

在創建 PyTorch 的模型時會定義一個 class,這裏是我抄來的代碼:

class Encoder(nn.Module):

    def __init__(
            self,
            n_src_vocab, len_max_seq, d_word_vec,
            n_layers, n_head, d_k, d_v,
            d_model, d_inner, dropout=0.1):

        super().__init__()
        # 具體的函數邏輯我就刪掉了,應該沒有什麼關係


    def forward(self, src_seq, src_pos, return_attns=False):

        # 具體的函數邏輯我就刪掉了,應該沒有什麼關係

        return enc_output

然後在另外的函數裏初始化了這個類,給對象傳入參數,代碼如下:

encoder = Encoder(參數)
encoder(input_data)

按照我的理解這裏在調用 encoder 時就相當於調用了 forward 函數,但是不需要用encoder.forward()這樣的語法,請問這個叫什麼?我該在搜什麼關鍵詞能找到這個語法?

解答參考

> pyTorch之前向傳播函數自動調用forward

主要思路是:
調用forward方法的具體流程

以一個Module爲例:
1. 調用module的call方法
2. module的call裏面調用module的forward方法
3. forward裏面如果碰到Module的子類,回到第1步,如果碰到的是Function的子類,繼續往下
4. 調用Function的call方法
5. Function的call方法調用了Function的forward方法。
6. Function的forward返回值
7. module的forward返回值
8. 在module的call進行forward_hook操作,然後返回值 上述中“調用module的call方法”是指nn.Module 的__call__方法。定義__call__方法的類可以當作函數調用,具體參考Python的面向對象編程。也就是說,當把定義的網絡模型model當作函數調用的時候就自動調用定義的網絡模型的forward方法。

nn.Module 的__call__方法部分源碼如下所示:

 
def __call__(self, *input, **kwargs):
 
	result = self.forward(*input, **kwargs)
 
	for hook in self._forward_hooks.values():
 
	#將註冊的hook拿出來用
 
		hook_result = hook(self, input, result)
	 
		...
 
	return result

可以看到,當執行model(x)的時候,底層自動調用forward方法計算結果。具體示例如下:

class Function:
	def __init__(self):
	...
	
	def forward(self, inputs):
	...
	return outputs
	
	def backward(self, grad_outs):
	...
	return grad_ins
	
	def _backward(self, grad_outs):
	 
	hooked_grad_outs = grad_outs
	 
	for hook in hook_in_outputs:
	 
		hooked_grad_outs = hook(hooked_grad_outs)
		 
		grad_ins = self.backward(hooked_grad_outs)
		 
		hooked_grad_ins = grad_ins
	 
	for hook in hooks_in_module:
	 	hooked_grad_ins = hook(hooked_grad_ins)
	return hooked_grad_ins
	
model = LeNet()
y = model(x)

如上則調用網絡模型定義的forward方法。

我後面看得暈暈乎乎 大概知道內部函數在你調用model的時候調用了forward()
ta給的源碼也沒看懂。。。更多細節求教!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章