Pytorch flatten 和 merge

 

 

 

1 Flatten

Flatten就是將2D的特徵圖壓扁爲1D的特徵向量,用於全連接層的輸入。

 

# Flatten繼承Module
class Flatten(nn.Module):
    # 構造函數,沒有什麼要做的
    def __init__(self):
        # 調用父類構造函數
        super(Flatten, self).__init__()

    # 實現forward函數
    def forward(self, input):
        # 保存batch維度,後面的維度全部壓平,例如輸入是28*28的特徵圖,壓平後爲784的向量
        return input.view(input.size(0), -1)

 

2 merge

    to be continue

 

3 torch.reshape()


   reshape()可以由torch.reshape(),也可由torch.Tensor.reshape()調用


4 view()


view()只可以由torch.Tensor.view()來調用
v

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章