Practical Tips | A Collection of Useful PyTorch Code Snippets

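PyTorch, the open-source deep learning framework from Facebook, keeps growing in popularity and is easy to pick up. This article lists commonly used, practical PyTorch code snippets from a GitHub repository for study and reference.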

GitHub link

https://github.com/ptrblck/pytorch_misc

Code overview

  • accumulate_gradients - Comparison of accumulated gradients/losses to a vanilla batch update. Link: https://github.com/ptrblck/pytorch_misc/blob/master/accumulate_gradients.py
  • adaptive_batchnorm - Adaptive BN implementation using two additional parameters: out = a * x + b * bn(x). Link: https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_batchnorm.py
  • adaptive_pooling_torchvision - Example of using adaptive pooling layers in pretrained models so that they accept different spatial input shapes. Link: https://github.com/ptrblck/pytorch_misc/blob/master/adaptive_pooling_torchvision.py
  • batch_norm_manual - Comparison of PyTorch BatchNorm layers and a manual calculation. Link: https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py
  • change_crop_in_dataset - Change the image crop size on the fly using a Dataset. Link: https://github.com/ptrblck/pytorch_misc/blob/master/change_crop_in_dataset.py
  • conv_rnn - Combines a 3DCNN with an RNN; uses windowed frames as inputs. Link: https://github.com/ptrblck/pytorch_misc/blob/master/conv_rnn.py
  • csv_chunk_read - Provide data chunks from a continuous .csv file. Link: https://github.com/ptrblck/pytorch_misc/blob/master/csv_chunk_read.py
  • densenet_forwardhook - Use forward hooks to get intermediate activations from densenet121. Uses separate modules to process these activations further. Link: https://github.com/ptrblck/pytorch_misc/blob/master/densenet_forwardhook.py
  • edge_weighting_segmentation - Apply weighting to edges for a segmentation task. Link: https://github.com/ptrblck/pytorch_misc/blob/master/edge_weighting_segmentation.py
  • image_rotation_with_matrix - Rotate an image by a given angle using 1) a nested loop and 2) a rotation matrix and mesh grid. Link: https://github.com/ptrblck/pytorch_misc/blob/master/image_rotation_with_matrix.py
  • LocallyConnected2d - Implementation of a locally connected 2d layer. Link: https://github.com/ptrblck/pytorch_misc/blob/master/LocallyConnected2d.py
  • mnist_autoencoder - Simple autoencoder for MNIST data. Includes visualizations of output images, intermediate activations and conv kernels. Link: https://github.com/ptrblck/pytorch_misc/blob/master/mnist_autoencoder.py
  • mnist_permuted - MNIST training using permuted pixel locations. Link: https://github.com/ptrblck/pytorch_misc/blob/master/mnist_permuted.py
  • model_sharding_data_parallel - Model sharding with DataParallel using 2 pairs of 2 GPUs. Link: https://github.com/ptrblck/pytorch_misc/blob/master/model_sharding_data_parallel.py
  • momentum_update_nograd - Script to see how parameters are updated when an optimizer is used with momentum/running estimates, even if gradients are zero. Link: https://github.com/ptrblck/pytorch_misc/blob/master/momentum_update_nograd.py
  • shared_array - Script to demonstrate the usage of shared arrays using multiple workers. Link: https://github.com/ptrblck/pytorch_misc/blob/master/shared_array.py
  • unet_demo - Simple UNet demo. Link: https://github.com/ptrblck/pytorch_misc/blob/master/unet_demo.py
  • weighted_sampling - Usage of WeightedRandomSampler on an imbalanced dataset with a class ratio of 99 to 1. Link: https://github.com/ptrblck/pytorch_misc/blob/master/weighted_sampling.py
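
To give a feel for the first item in the list, here is a minimal sketch of the idea behind accumulate_gradients. It is not the repository's script: the toy linear model, the random data, and the names `accum_steps`, `full_grad`, and `accum_grad` are assumptions made for illustration. It checks that summing the gradients of equally sized chunks, with each chunk loss divided by the number of chunks, reproduces the gradient of one full-batch update when the loss uses mean reduction.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy data and model; the shapes are arbitrary choices for illustration.
x = torch.randn(8, 4)
y = torch.randn(8, 1)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()  # default reduction is "mean"

# 1) Vanilla update: a single backward pass over the full batch.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# 2) Accumulated gradients: split the batch into equal chunks and
#    scale each chunk loss so the sum matches the full-batch mean.
model.zero_grad()
accum_steps = 4  # hypothetical number of accumulation steps
for x_chunk, y_chunk in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(x_chunk), y_chunk) / accum_steps
    loss.backward()  # .grad is summed across backward calls
accum_grad = model.weight.grad.clone()

# Up to floating-point error, both gradients should agree.
grads_match = torch.allclose(full_grad, accum_grad, atol=1e-6)
print(grads_match)
```

In a real training loop, `optimizer.step()` and `optimizer.zero_grad()` would be called only once every `accum_steps` iterations. The equivalence shown here assumes equally sized chunks and no batch-dependent layers such as BatchNorm.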