6 Module - Dissecting PyTorch Like Cook Ding Carving an Ox

The Module class holds the functionality shared by every module.

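Modules in PyTorch are very easy to use: just subclass Module and override two methods (normally __init__ and forward). So what does Module itself do behind the scenes?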
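A minimal sketch of those two overrides, assuming a working PyTorch install. It uses only long-standing APIs (nn.Module, nn.Linear and torch.nn.functional.relu); the class name TinyNet and the layer sizes are arbitrary choices made for this illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNet(nn.Module):
    def __init__(self, in_features=4, out_features=2):
        # The base constructor must run first: it creates the internal
        # dictionaries (_parameters, _modules, ...) that the rest relies on.
        super(TinyNet, self).__init__()
        # Assigning a Module as an attribute registers it as a child module.
        self.fc = nn.Linear(in_features, out_features)

    def forward(self, x):
        # forward only describes the computation.
        return F.relu(self.fc(x))


net = TinyNet()
# Call the module itself (not net.forward) so that any hooks also run.
out = net(torch.randn(3, 4))
print(out.shape)           # torch.Size([3, 2])
print(list(net._modules))  # ['fc']
```

Note that the child layer ends up in net._modules simply by being assigned as an attribute; the rest of this article looks at how Module makes that happen.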

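# Excerpt from torch/nn/modules/module.py (PyTorch 0.4.1)
from collections import OrderedDict

class Module(object):
    def __init__(self):
        # thnn_backend: module-level global holding the THNNFunctionBackend instance
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True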

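The constructor creates a set of ordered dictionaries that hold various kinds of state; we will set those aside for now. Start with the first member: self._backend points to a global THNNFunctionBackend() instance, which stores a table of function references. Its base class is FunctionBackend: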

class FunctionBackend(object):
    def __init__(self):
        self.function_classes = {}
    def register_function(self, name, function_class):
        self.function_classes[name] = function_class

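The keys of the function_classes dictionary are names and the values are the corresponding functions or Function classes; entries are added through register_function. Once registration is complete there are 118 entries (the PyTorch version used in this article is 0.4.1):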

RNN                                      <function RNN at 0x7f4330534378>
RNNTanhCell                              <function RNNTanhCell at 0x7f4330530d90>
RNNReLUCell                              <function RNNReLUCell at 0x7f43305309d8>
LSTMCell                                 <function LSTMCell at 0x7f4330530e18>
GRUCell                                  <function GRUCell at 0x7f4330530ea0>
Dropout                                  <class 'torch.nn._functions.dropout.Dropout'>
Dropout2d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
Dropout3d                                <class 'torch.nn._functions.dropout.FeatureDropout'>
MarginCriterion                          <class 'torch.nn._functions.thnn.auto.MarginCriterion'>
MarginCriterionBackward                  <class 'torch.nn._functions.thnn.auto.MarginCriterionBackward'>
GatedLinear                              <class 'torch.nn._functions.thnn.auto.GatedLinear'>
GatedLinearBackward                      <class 'torch.nn._functions.thnn.auto.GatedLinearBackward'>
SpatialFullConvolutionMap                <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMap'>
SpatialFullConvolutionMapBackward        <class 'torch.nn._functions.thnn.auto.SpatialFullConvolutionMapBackward'>
VolumetricFractionalMaxPooling           <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPooling'>
VolumetricFractionalMaxPoolingBackward   <class 'torch.nn._functions.thnn.auto.VolumetricFractionalMaxPoolingBackward'>
VolumetricFullDilatedConvolution         <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolution'>
VolumetricFullDilatedConvolutionBackward <class 'torch.nn._functions.thnn.auto.VolumetricFullDilatedConvolutionBackward'>
Col2Im                                   <class 'torch.nn._functions.thnn.auto.Col2Im'>
Col2ImBackward                           <class 'torch.nn._functions.thnn.auto.Col2ImBackward'>
DilatedConv2d                            <class 'torch.nn._functions.thnn.auto.DilatedConv2d'>
DilatedConv2dBackward                    <class 'torch.nn._functions.thnn.auto.DilatedConv2dBackward'>
SpatialConvolutionLocal                  <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocal'>
SpatialConvolutionLocalBackward          <class 'torch.nn._functions.thnn.auto.SpatialConvolutionLocalBackward'>
FeatureLPPooling                         <class 'torch.nn._functions.thnn.auto.FeatureLPPooling'>
FeatureLPPoolingBackward                 <class 'torch.nn._functions.thnn.auto.FeatureLPPoolingBackward'>
VolumetricGridSamplerBilinear            <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinear'>
VolumetricGridSamplerBilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricGridSamplerBilinearBackward'>
TemporalUpSamplingNearest                <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearest'>
TemporalUpSamplingNearestBackward        <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingNearestBackward'>
SpatialUpSamplingNearest                 <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearest'>
SpatialUpSamplingNearestBackward         <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingNearestBackward'>
ReflectionPad1d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad1d'>
ReflectionPad1dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad1dBackward'>
SpatialConvolutionMap                    <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMap'>
SpatialConvolutionMapBackward            <class 'torch.nn._functions.thnn.auto.SpatialConvolutionMapBackward'>
NLLLoss                                  <class 'torch.nn._functions.thnn.auto.NLLLoss'>
NLLLossBackward                          <class 'torch.nn._functions.thnn.auto.NLLLossBackward'>
Softplus                                 <class 'torch.nn._functions.thnn.auto.Softplus'>
SoftplusBackward                         <class 'torch.nn._functions.thnn.auto.SoftplusBackward'>
LogSigmoid                               <class 'torch.nn._functions.thnn.auto.LogSigmoid'>
LogSigmoidBackward                       <class 'torch.nn._functions.thnn.auto.LogSigmoidBackward'>
SpatialUpSamplingBilinear                <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinear'>
SpatialUpSamplingBilinearBackward        <class 'torch.nn._functions.thnn.auto.SpatialUpSamplingBilinearBackward'>
ReplicationPad3d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad3d'>
ReplicationPad3dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad3dBackward'>
MultiMarginLoss                          <class 'torch.nn._functions.thnn.auto.MultiMarginLoss'>
MultiMarginLossBackward                  <class 'torch.nn._functions.thnn.auto.MultiMarginLossBackward'>
ReplicationPad1d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad1d'>
ReplicationPad1dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad1dBackward'>
MultiLabelMarginLoss                     <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLoss'>
MultiLabelMarginLossBackward             <class 'torch.nn._functions.thnn.auto.MultiLabelMarginLossBackward'>
SpatialFullDilatedConvolution            <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolution'>
SpatialFullDilatedConvolutionBackward    <class 'torch.nn._functions.thnn.auto.SpatialFullDilatedConvolutionBackward'>
SoftMarginLoss                           <class 'torch.nn._functions.thnn.auto.SoftMarginLoss'>
SoftMarginLossBackward                   <class 'torch.nn._functions.thnn.auto.SoftMarginLossBackward'>
NLLLoss2d                                <class 'torch.nn._functions.thnn.auto.NLLLoss2d'>
NLLLoss2dBackward                        <class 'torch.nn._functions.thnn.auto.NLLLoss2dBackward'>
MSELoss                                  <class 'torch.nn._functions.thnn.auto.MSELoss'>
MSELossBackward                          <class 'torch.nn._functions.thnn.auto.MSELossBackward'>
Sigmoid                                  <class 'torch.nn._functions.thnn.auto.Sigmoid'>
SigmoidBackward                          <class 'torch.nn._functions.thnn.auto.SigmoidBackward'>
VolumetricUpSamplingTrilinear            <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinear'>
VolumetricUpSamplingTrilinearBackward    <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingTrilinearBackward'>
BCELoss                                  <class 'torch.nn._functions.thnn.auto.BCELoss'>
BCELossBackward                          <class 'torch.nn._functions.thnn.auto.BCELossBackward'>
Square                                   <class 'torch.nn._functions.thnn.auto.Square'>
SquareBackward                           <class 'torch.nn._functions.thnn.auto.SquareBackward'>
ReplicationPad2d                         <class 'torch.nn._functions.thnn.auto.ReplicationPad2d'>
ReplicationPad2dBackward                 <class 'torch.nn._functions.thnn.auto.ReplicationPad2dBackward'>
L1Loss                                   <class 'torch.nn._functions.thnn.auto.L1Loss'>
L1LossBackward                           <class 'torch.nn._functions.thnn.auto.L1LossBackward'>
SpatialGridSamplerBilinear               <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinear'>
SpatialGridSamplerBilinearBackward       <class 'torch.nn._functions.thnn.auto.SpatialGridSamplerBilinearBackward'>
Sqrt                                     <class 'torch.nn._functions.thnn.auto.Sqrt'>
SqrtBackward                             <class 'torch.nn._functions.thnn.auto.SqrtBackward'>
TemporalRowConvolution                   <class 'torch.nn._functions.thnn.auto.TemporalRowConvolution'>
TemporalRowConvolutionBackward           <class 'torch.nn._functions.thnn.auto.TemporalRowConvolutionBackward'>
SpatialFractionalMaxPooling              <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPooling'>
SpatialFractionalMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.SpatialFractionalMaxPoolingBackward'>
TemporalUpSamplingLinear                 <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinear'>
TemporalUpSamplingLinearBackward         <class 'torch.nn._functions.thnn.auto.TemporalUpSamplingLinearBackward'>
VolumetricDilatedMaxPooling              <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPooling'>
VolumetricDilatedMaxPoolingBackward      <class 'torch.nn._functions.thnn.auto.VolumetricDilatedMaxPoolingBackward'>
Threshold                                <class 'torch.nn._functions.thnn.auto.Threshold'>
ThresholdBackward                        <class 'torch.nn._functions.thnn.auto.ThresholdBackward'>
Abs                                      <class 'torch.nn._functions.thnn.auto.Abs'>
AbsBackward                              <class 'torch.nn._functions.thnn.auto.AbsBackward'>
Softshrink                               <class 'torch.nn._functions.thnn.auto.Softshrink'>
SoftshrinkBackward                       <class 'torch.nn._functions.thnn.auto.SoftshrinkBackward'>
LeakyReLU                                <class 'torch.nn._functions.thnn.auto.LeakyReLU'>
LeakyReLUBackward                        <class 'torch.nn._functions.thnn.auto.LeakyReLUBackward'>
VolumetricUpSamplingNearest              <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearest'>
VolumetricUpSamplingNearestBackward      <class 'torch.nn._functions.thnn.auto.VolumetricUpSamplingNearestBackward'>
VolumetricDilatedConvolution             <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolution'>
VolumetricDilatedConvolutionBackward     <class 'torch.nn._functions.thnn.auto.VolumetricDilatedConvolutionBackward'>
Tanh                                     <class 'torch.nn._functions.thnn.auto.Tanh'>
TanhBackward                             <class 'torch.nn._functions.thnn.auto.TanhBackward'>
TemporalSubSampling                      <class 'torch.nn._functions.thnn.auto.TemporalSubSampling'>
TemporalSubSamplingBackward              <class 'torch.nn._functions.thnn.auto.TemporalSubSamplingBackward'>
ELU                                      <class 'torch.nn._functions.thnn.auto.ELU'>
ELUBackward                              <class 'torch.nn._functions.thnn.auto.ELUBackward'>
Hardtanh                                 <class 'torch.nn._functions.thnn.auto.Hardtanh'>
HardtanhBackward                         <class 'torch.nn._functions.thnn.auto.HardtanhBackward'>
L1Cost                                   <class 'torch.nn._functions.thnn.auto.L1Cost'>
L1CostBackward                           <class 'torch.nn._functions.thnn.auto.L1CostBackward'>
SpatialSubSampling                       <class 'torch.nn._functions.thnn.auto.SpatialSubSampling'>
SpatialSubSamplingBackward               <class 'torch.nn._functions.thnn.auto.SpatialSubSamplingBackward'>
Im2Col                                   <class 'torch.nn._functions.thnn.auto.Im2Col'>
Im2ColBackward                           <class 'torch.nn._functions.thnn.auto.Im2ColBackward'>
KLDivLoss                                <class 'torch.nn._functions.thnn.auto.KLDivLoss'>
KLDivLossBackward                        <class 'torch.nn._functions.thnn.auto.KLDivLossBackward'>
SmoothL1Loss                             <class 'torch.nn._functions.thnn.auto.SmoothL1Loss'>
SmoothL1LossBackward                     <class 'torch.nn._functions.thnn.auto.SmoothL1LossBackward'>
ReflectionPad2d                          <class 'torch.nn._functions.thnn.auto.ReflectionPad2d'>
ReflectionPad2dBackward                  <class 'torch.nn._functions.thnn.auto.ReflectionPad2dBackward'>
CrossMapLRN2d                            <class 'torch.nn._functions.thnn.normalization.CrossMapLRN2d'>
EmbeddingBag                             <class 'torch.nn._functions.thnn.sparse.EmbeddingBag'>

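Without quite meaning to, we have just listed every predefined function registered with PyTorch's THNN backend. How these predefined operations are implemented will be covered later.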
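The registry idea behind FunctionBackend can be sketched in plain Python. This is a simplified stand-in, not the PyTorch source: the class name MiniBackend, the duplicate-name check and the attribute-style lookup through __getattr__ are assumptions made for this illustration, and the registered entries are toy functions.

```python
import math


class MiniBackend(object):
    """Hypothetical, simplified stand-in for FunctionBackend."""

    def __init__(self):
        # name -> function or Function-like class
        self.function_classes = {}

    def register_function(self, name, function_class):
        # Refuse silent overwrites so each name maps to exactly one entry.
        if name in self.function_classes:
            raise RuntimeError("function already registered: " + name)
        self.function_classes[name] = function_class

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails, so an access like
        # backend.Sigmoid falls through to the registry.
        try:
            return self.__dict__["function_classes"][name]
        except KeyError:
            raise AttributeError(name)


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


backend = MiniBackend()
backend.register_function("Sigmoid", sigmoid)
backend.register_function("Square", lambda x: x * x)

# Print "name  value" pairs in the same layout as the listing above.
for name, fn in backend.function_classes.items():
    print("{:<10} {}".format(name, fn))

print(backend.Square(3))  # 9
```

Keeping the table in a plain dictionary means a backend can be swapped or extended at import time just by registering more names.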

The other ordered dictionaries

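        self._parameters = OrderedDict()         # the module's own learnable parameters
        self._buffers = OrderedDict()            # persistent non-parameter state (e.g. BatchNorm running_mean), saved in state_dict but not trained
        self._backward_hooks = OrderedDict()     # backward hooks
        self._forward_hooks = OrderedDict()      # hooks run after forward
        self._forward_pre_hooks = OrderedDict()  # hooks run before forward
        self._modules = OrderedDict()            # child modules
        self.training = True                     # training mode (True) or evaluation mode (False)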
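How do these dictionaries get filled? Module overrides attribute assignment so that parameters and child modules land in the right dictionary. The sketch below imitates that routing in plain Python; MiniModule, MiniParameter and Layer are hypothetical stand-ins for nn.Module, nn.Parameter and a user subclass, and the real __setattr__ performs more checks than shown here.

```python
from collections import OrderedDict


class MiniParameter(object):
    """Hypothetical stand-in for nn.Parameter: just wraps a value."""

    def __init__(self, value):
        self.value = value


class MiniModule(object):
    def __init__(self):
        # object.__setattr__ bypasses our own routing below.
        object.__setattr__(self, "_parameters", OrderedDict())
        object.__setattr__(self, "_buffers", OrderedDict())
        object.__setattr__(self, "_modules", OrderedDict())
        object.__setattr__(self, "training", True)

    def register_buffer(self, name, value):
        # Buffers are state that is saved but not trained (e.g. running means).
        self._buffers[name] = value

    def __setattr__(self, name, value):
        # Route parameters and sub-modules into their ordered dictionaries.
        if isinstance(value, MiniParameter):
            self._parameters[name] = value
        elif isinstance(value, MiniModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: search the dictionaries.
        for store in ("_parameters", "_buffers", "_modules"):
            d = self.__dict__.get(store, {})
            if name in d:
                return d[name]
        raise AttributeError(name)


class Layer(MiniModule):
    def __init__(self):
        super(Layer, self).__init__()
        self.weight = MiniParameter(1.5)           # -> _parameters
        self.register_buffer("running_mean", 0.0)  # -> _buffers
        self.child = MiniModule()                  # -> _modules
        self.note = "plain attribute"              # -> ordinary __dict__


layer = Layer()
print(list(layer._parameters), list(layer._buffers), list(layer._modules))
# ['weight'] ['running_mean'] ['child']
print(layer.weight.value)  # 1.5
```

Because parameters never reach the instance __dict__, the matching __getattr__ is what makes layer.weight work; this pairing is why the method table also lists __getattr__, __setattr__ and __delattr__.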

Module methods

The purpose of most methods is clear from the name, so they are only listed here without further detail.

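Name                       Purpose
forward                    Forward computation; the "virtual" method every subclass must override
register_buffer            Register a buffer (persistent state that is not a parameter)
register_parameter         Register a parameter
add_module                 Add a child module
_apply                     Apply a function to every parameter and buffer, recursing into child modules
apply                      Apply a function to every submodule (and to the module itself)
cuda                       Move everything to the GPU
cpu                        Move everything to the CPU
type                       Cast all parameters and buffers to a given type
float                      Cast floating-point parameters and buffers to float
double                     Cast floating-point parameters and buffers to double
half                       Cast floating-point parameters and buffers to half precision (16-bit, 2-byte floats)
to                         One user-facing interface for changing dtype and moving between CPU and GPU; internally it still calls _
register_backward_hook     Register a backward hook
register_forward_pre_hook  Register a hook that runs before forward
register_forward_hook      Register a hook that runs after forward
_slow_forward              Slower forward path used while the JIT tracer is active
__call__                   Calling the module with arguments runs the forward pass (plus hooks)
__setstate__               Restore all dictionary state at once (used when unpickling)
__getattr__                Get an attribute (also searches _parameters, _buffers and _modules)
__setattr__                Set an attribute (routes parameters and modules into their dictionaries)
__delattr__                Delete an attribute
state_dict                 Return the current state as a dictionary
_load_from_state_dict      Internal worker that loads this module's entries from a state dict
load_state_dict            User-facing interface for loading a state dict
children                   Iterate over the immediate child modules
modules                    Iterate over all modules in the tree, including this one
train                      Switch to training mode
eval                       Switch to evaluation mode
zero_grad                  Zero the gradients of all parameters
share_memory               Move the underlying storage into shared memory
__repr__                   String representation (the printed module tree)
__dir__                    List the available attributes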
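Several methods in the table cooperate at call time. The plain-Python sketch below is loosely modelled on how __call__ runs forward pre-hooks, then forward, then forward hooks, and on how train/eval propagate the training flag through children. It is a simplified illustration: HookedModule, Double, the integer hook keys and the rule that hook return values are ignored are all assumptions for this example (the real methods also handle backward hooks and JIT tracing, and return removable handles rather than keys).

```python
from collections import OrderedDict


class HookedModule(object):
    """Hypothetical miniature of nn.Module's call path and mode switching."""

    def __init__(self):
        self._forward_pre_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

    def register_forward_pre_hook(self, hook):
        key = len(self._forward_pre_hooks)  # simple integer key for this sketch
        self._forward_pre_hooks[key] = hook
        return key

    def register_forward_hook(self, hook):
        key = len(self._forward_hooks)
        self._forward_hooks[key] = hook
        return key

    def forward(self, *inputs):
        raise NotImplementedError  # subclasses must override

    def __call__(self, *inputs):
        # 1. pre-hooks see the inputs before forward runs
        for hook in self._forward_pre_hooks.values():
            hook(self, inputs)
        result = self.forward(*inputs)
        # 2. forward hooks see inputs and output (return values ignored here)
        for hook in self._forward_hooks.values():
            hook(self, inputs, result)
        return result

    def children(self):
        return iter(self._modules.values())

    def train(self, mode=True):
        # Set the flag on this module, then recurse into every child.
        self.training = mode
        for child in self.children():
            child.train(mode)
        return self

    def eval(self):
        return self.train(False)


class Double(HookedModule):
    def forward(self, x):
        return 2 * x


log = []
net = Double()
net._modules["inner"] = Double()  # attach a child directly for brevity
net.register_forward_pre_hook(lambda m, inp: log.append(("pre", inp)))
net.register_forward_hook(lambda m, inp, out: log.append(("post", out)))

print(net(5))  # 10
print(log)     # [('pre', (5,)), ('post', 10)]
net.eval()
print(net.training, net._modules["inner"].training)  # False False
```

Routing every call through __call__ is what lets hooks observe a module without changing its forward, and the recursive train is why a single model.eval() is enough to switch a whole network.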