Module and the functions it stores
Modules in PyTorch are very easy to use: you just derive from Module and override two functions. So what does Module itself actually do?
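Before looking inside, here is what that two-function pattern looks like in practice — a minimal sketch, where the class name and layer sizes are arbitrary choices for the example:

```python
import torch
import torch.nn as nn

class TwoLayerNet(nn.Module):
    def __init__(self):
        # The parent constructor must run first, so that the ordered
        # dicts Module.__init__ creates (shown below) exist before
        # any submodule or parameter is assigned.
        super(TwoLayerNet, self).__init__()
        self.fc1 = nn.Linear(4, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        # forward is the second method to override
        return self.fc2(torch.relu(self.fc1(x)))

net = TwoLayerNet()
out = net(torch.randn(3, 4))  # calling the module dispatches to forward
print(out.shape)              # torch.Size([3, 2])
```

Everything else — parameter tracking, hooks, moving between devices — comes for free from the base class, which is what the rest of this article examines.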
```python
class Module(object):
    def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True
```
The constructor creates a set of ordered dicts that store various kinds of state; we will come back to them later. Let's start with the first field, self._backend: it is a global THNNFunctionBackend() instance that stores a table of function pointers. THNNFunctionBackend derives from FunctionBackend:
```python
class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        self.function_classes[name] = function_class
```
The keys of this class's function_classes dict are names and the values are function classes; entries are added via register_function. After registration there are roughly 118 entries. The PyTorch version used in this article is 0.4.1:
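The registry pattern itself is easy to demonstrate in isolation. The sketch below is a self-contained imitation, not a copy of the PyTorch source: the duplicate-registration check and the attribute-style lookup via `__getattr__` are assumptions modeled on how such a backend is typically used (e.g. `backend.Sigmoid(...)`):

```python
class FunctionBackend(object):
    """Minimal sketch of a name -> function-class registry."""

    def __init__(self):
        self.function_classes = {}

    def register_function(self, name, function_class):
        # Assumed safeguard: refuse to silently overwrite an entry.
        if name in self.function_classes:
            raise RuntimeError("function %r already registered" % name)
        self.function_classes[name] = function_class

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails, so
        # backend.Square resolves through the registry.
        fn = self.function_classes.get(name)
        if fn is None:
            raise NotImplementedError(name)
        return fn

backend = FunctionBackend()
backend.register_function('Square', lambda x: x * x)
print(backend.Square(5))  # 25
```

The payoff of this indirection is that Module code can call `self._backend.Square(...)` without knowing at import time which concrete implementation was registered under that name.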
RNN <function RNN at 0x7f4330534378>
RNNTanhCell <function RNNTanhCell at 0x7f4330530d90>
RNNReLUCell <function RNNReLUCell at 0x7f43305309d8>
LSTMCell <function LSTMCell at 0x7f4330530e18>
GRUCell <function GRUCell at 0x7f4330530ea0>
Dropout <class 'torch.nn._functions.dropout.Dropout'>
Dropout2d <class 'torch.nn._functions.dropout.FeatureDropout'>
Dropout3d <class 'torch.nn._functions.dropout.FeatureDropout'>
MarginCriterion <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
MarginCriterionBackward <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
GatedLinear <class 'torch.nn._functions.thnn.auto.GatedLinear'>
GatedLinearBackward <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
SpatialFullConvolutionMap <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
SpatialFullConvolutionMapBackward <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
VolumetricFractionalMaxPooling <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
VolumetricFractionalMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
VolumetricFullDilatedConvolution <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
Col2Im <class 'torch.nn._functions.thnn.auto.Col2Im'>
Col2ImBackward <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
DilatedConv2d <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
DilatedConv2dBackward <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
SpatialConvolutionLocal <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
SpatialConvolutionLocalBackward <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
FeatureLPPooling <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
FeatureLPPoolingBackward <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
VolumetricGridSamplerBilinear <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
VolumetricGridSamplerBilinearBackward <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
TemporalUpSamplingNearest <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
TemporalUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
SpatialUpSamplingNearest <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
SpatialUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
ReflectionPad1d <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
ReflectionPad1dBackward <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
SpatialConvolutionMap <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
SpatialConvolutionMapBackward <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
NLLLoss <class 'torch.nn._functions.thnn.auto.NLLLoss'>
NLLLossBackward <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
Softplus <class 'torch.nn._functions.thnn.auto.Softplus'>
SoftplusBackward <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
LogSigmoid <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
LogSigmoidBackward <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
SpatialUpSamplingBilinear <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
SpatialUpSamplingBilinearBackward <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
ReplicationPad3d <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
ReplicationPad3dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
MultiMarginLoss <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
MultiMarginLossBackward <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
ReplicationPad1d <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
ReplicationPad1dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
MultiLabelMarginLoss <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
MultiLabelMarginLossBackward <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
SpatialFullDilatedConvolution <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
SpatialFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
SoftMarginLoss <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
SoftMarginLossBackward <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
NLLLoss2d <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
NLLLoss2dBackward <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
MSELoss <class 'torch.nn._functions.thnn.auto.MSELoss'>
MSELossBackward <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
Sigmoid <class 'torch.nn._functions.thnn.auto.Sigmoid'>
SigmoidBackward <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
VolumetricUpSamplingTrilinear <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
VolumetricUpSamplingTrilinearBackward <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
BCELoss <class 'torch.nn._functions.thnn.auto.BCELoss'>
BCELossBackward <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
Square <class 'torch.nn._functions.thnn.auto.Square'>
SquareBackward <class 'torch.nn._functions.thnn.auto.SquareBackward'>
ReplicationPad2d <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
ReplicationPad2dBackward <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
L1Loss <class 'torch.nn._functions.thnn.auto.L1Loss'>
L1LossBackward <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
SpatialGridSamplerBilinear <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
SpatialGridSamplerBilinearBackward <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
Sqrt <class 'torch.nn._functions.thnn.auto.Sqrt'>
SqrtBackward <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
TemporalRowConvolution <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
TemporalRowConvolutionBackward <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
SpatialFractionalMaxPooling <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
SpatialFractionalMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
TemporalUpSamplingLinear <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
TemporalUpSamplingLinearBackward <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
VolumetricDilatedMaxPooling <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
VolumetricDilatedMaxPoolingBackward <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
Threshold <class 'torch.nn._functions.thnn.auto.Threshold'>
ThresholdBackward <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
Abs <class 'torch.nn._functions.thnn.auto.Abs'>
AbsBackward <class 'torch.nn._functions.thnn.auto.AbsBackward'>
Softshrink <class 'torch.nn._functions.thnn.auto.Softshrink'>
SoftshrinkBackward <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
LeakyReLU <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
LeakyReLUBackward <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
VolumetricUpSamplingNearest <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
VolumetricUpSamplingNearestBackward <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
VolumetricDilatedConvolution <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
VolumetricDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
Tanh <class 'torch.nn._functions.thnn.auto.Tanh'>
TanhBackward <class 'torch.nn._functions.thnn.auto.TanhBackward'>
TemporalSubSampling <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
TemporalSubSamplingBackward <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
ELU <class 'torch.nn._functions.thnn.auto.ELU'>
ELUBackward <class 'torch.nn._functions.thnn.auto.ELUBackward'>
Hardtanh <class 'torch.nn._functions.thnn.auto.Hardtanh'>
HardtanhBackward <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
L1Cost <class 'torch.nn._functions.thnn.auto.L1Cost'>
L1CostBackward <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
SpatialSubSampling <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
SpatialSubSamplingBackward <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
Im2Col <class 'torch.nn._functions.thnn.auto.Im2Col'>
Im2ColBackward <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
KLDivLoss <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
KLDivLossBackward <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
SmoothL1Loss <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
SmoothL1LossBackward <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
ReflectionPad2d <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
ReflectionPad2dBackward <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
CrossMapLRN2d <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
EmbeddingBag <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>
Without quite meaning to, we have just listed every predefined module that PyTorch ships with. Later in this article we will walk through how these predefined modules are implemented.
The other ordered dicts
```python
self._parameters = OrderedDict()        # the module's learnable parameters
self._buffers = OrderedDict()           # persistent state that is not a parameter (kept resident, not freed or swapped)
self._backward_hooks = OrderedDict()    # backward hook functions
self._forward_hooks = OrderedDict()     # forward hook functions
self._forward_pre_hooks = OrderedDict() # hooks that run before the forward call
self._modules = OrderedDict()           # child modules
self.training = True                    # training vs. evaluation mode
```
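These dicts are what make plain attribute assignment (`self.fc1 = nn.Linear(...)`) magically register a submodule: Module overrides `__setattr__` to route assignments into the right dict, and `__getattr__` to look them up again. The real implementation handles buffers, hooks, and many error cases; the stripped-down sketch below (class names `MiniModule`/`Parameter` are invented for illustration) shows only the routing idea:

```python
from collections import OrderedDict

class Parameter(object):
    """Stand-in for torch.nn.Parameter, used only to tag values."""
    def __init__(self, data):
        self.data = data

class MiniModule(object):
    """Simplified sketch of Module's __setattr__/__getattr__ routing."""

    def __init__(self):
        # Bypass our own __setattr__ while the dicts don't exist yet.
        object.__setattr__(self, '_parameters', OrderedDict())
        object.__setattr__(self, '_modules', OrderedDict())

    def __setattr__(self, name, value):
        if isinstance(value, Parameter):
            self._parameters[name] = value   # parameters are captured here
        elif isinstance(value, MiniModule):
            self._modules[name] = value      # child modules are captured here
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only called when normal lookup fails: check the dicts.
        for table_name in ('_parameters', '_modules'):
            table = object.__getattribute__(self, table_name)
            if name in table:
                return table[name]
        raise AttributeError(name)

m = MiniModule()
m.weight = Parameter([1.0, 2.0])
m.child = MiniModule()
print(list(m._parameters))  # ['weight']
print(list(m._modules))     # ['child']
```

This is also why forgetting to call the parent constructor in a subclass fails: the dicts that `__setattr__` writes into simply do not exist yet.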
Module methods
Each method's role can largely be inferred from its name; they are only listed here, not described in detail:
Name | Purpose |
---|---|
forward | forward computation; a virtual method meant to be overridden |
register_buffer | register a persistent buffer |
register_parameter | register a parameter |
add_module | add a child module |
_apply | apply an operation to all parameters |
apply | apply a function to all submodules |
cuda | move everything to the GPU |
cpu | move everything to the CPU |
type | cast all parameters to a given type |
float | cast everything to single-precision float |
double | cast everything to double-precision float |
half | cast everything to half-precision float (two bytes) |
to | user-facing interface for changing dtype and device (CPU/GPU); internally still calls `_apply` |
register_backward_hook | register a backward hook |
register_forward_pre_hook | register a hook that runs before forward |
register_forward_hook | register a forward hook |
_slow_forward | forward path without tracing acceleration |
`__call__` | runs the forward pass when the module is called with arguments |
`__setstate__` | restore all the dict state in one step (used when unpickling) |
`__getattr__` | get an attribute |
`__setattr__` | set an attribute |
`__delattr__` | delete an attribute |
state_dict | export the current state as a dict |
_load_from_state_dict | worker function that loads from a state dict |
load_state_dict | user-facing interface for loading state |
children | iterate over child modules |
modules | iterate over all modules |
train | switch to training mode |
eval | switch to evaluation mode |
zero_grad | zero the gradients of all parameters |
share_memory | move to shared memory |
`__repr__` | string representation of the module |
`__dir__` | list the module's attributes |
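The hook-related entries above are worth a concrete illustration. A forward hook receives `(module, input, output)` and is invoked by `__call__` after forward runs; the shape-recording hook below is just an example use, not anything from the PyTorch source:

```python
import torch
import torch.nn as nn

shapes = []

def record_shape(module, inputs, output):
    # Forward hooks receive (module, input tuple, output).
    shapes.append(tuple(output.shape))

layer = nn.Linear(4, 2)
handle = layer.register_forward_hook(record_shape)

layer(torch.randn(5, 4))  # __call__ runs forward, then the hook fires
handle.remove()           # the returned handle detaches the hook again
print(shapes)             # [(5, 2)]
```

Because hooks live in the `_forward_hooks` ordered dict from the constructor, they fire in registration order, and removing the handle deletes the entry from that dict.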