FaceBook發佈深度學習工具包PyTorch Hub,讓論文復現變得更容易

近日,PyTorch社區發佈了一個深度學習工具包PyTorchHub, 幫助機器學習工作者更快實現重要論文的復現工作。PyTorchHub由一個預訓練模型倉庫組成,專門用於提高研究工作的復現性以及新的研究。同時它還內置了對Google Colab的支持,並與Papers With Code集成。目前PyTorchHub包括了一系列與圖像分類、分割、生成以及轉換相關的模型。

可復現性是許多研究領域的基本要求,這其中當然包括基於機器學習技術的研究領域。然而, 許多機器學習相關論文要麼無法復現,要麼難以重現。隨着論文數量的持續增長,包括目前在 arXiv上預印刷的數萬份論文以及提交給會議的論文,研究工作的可復現性變得越來越重要。雖然其中許多論文都附有代碼以及訓練好的模型,但這種幫助顯然非常有限,復現過程中仍有大量需要讀者自己摸索的步驟。下面讓我們來看一下如何通過PyTorch Hub這一利器完成快速的模型發佈與工作復現。

image

如何快速發佈模型

這部分主要介紹了對於模型發佈者來說如何快速高效的將自己的模型加入PyTorch Hub庫。PyTorch Hub支持通過添加簡單的hubconf.py文件將預先訓練的模型(模型定義和預先訓練重)發佈到GitHub存儲庫。這提供了模型列表以及其依賴庫列表。一些示例可以在torchvisionhuggingface-bertgan-model-zoo存儲庫中找到。

Pytoch社區給出了torchvision的hubconf.py文件的示例:

# Optional list of dependencies required by the package
dependencies = ['torch']

from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.segmentation import fcn_resnet101, deeplabv3_resnet101
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenet import mobilenet_v2

在torchvision中,模型有以下特性:

  • 每個模型文件可以被獨立執行或實現某個功能
  • 不需要除了PyTorch之外的任何軟件包(在hubconf.py中編碼爲dependencies[‘torch’])
  • 他們不需要單獨的入口點,因爲模型在創建時可以無縫地開箱即用。

PyTroch社區認爲最小化包依賴性可減少用戶加載模型時遇到的困難。這裏他們給出了一個更爲複雜的例子——HuggingFace’s BERT模型,它的hubconf.py如下:

dependencies = ['torch', 'tqdm', 'boto3', 'requests', 'regex']

from hubconfs.bert_hubconf import (
    bertTokenizer,
    bertModel,
    bertForNextSentencePrediction,
    bertForPreTraining,
    bertForMaskedLM,
    bertForSequenceClassification,
    bertForMultipleChoice,
    bertForQuestionAnswering,
    bertForTokenClassification
)

此外,對於每個模型,PyTorch官方提到都需要爲其創建一個入口點。下面是一個用於指定bertForMaskedLM模型的入口點的代碼片段,這部分代碼完成的功能是返回加載了預訓練參數的模型。

def bertForMaskedLM(*args, **kwargs):
    """
    BertForMaskedLM includes the BertModel Transformer followed by the
    pre-trained masked language modeling head.
    Example:
      ...
    """
    model = BertForMaskedLM.from_pretrained(*args, **kwargs)
    return model

這些入口點可以看成是複雜的模型結構的一種封裝形式。它們可以在提供簡潔高效的幫助文檔的同時完成下載預訓練權重的功能(例如,通過pretrained = True),也可以集成其他特定功能,例如可視化。

通過hubconf.py,模型發佈者可以在Github上基於template提交他們的合併請求。PyTorch社區希望通過PyTorch Hub創建一系列高質量、易復現且效果好的模型以提高研究工作的復現性。因此,PyTorch會通過與模型發佈者合作的方式以完善請求,並有可能會在某些情況下拒絕發佈一些低質量的模型。一旦PyTorch社區接受了模型發佈者的請求,這些新的模型將會很快出現在PyTorch Hub的網頁上以供用戶瀏覽。

用戶工作流

對於想使用PyTorch Hub對別人的工作進行復現的用戶,PyTorch Hub提供了以下幾個步驟:1)瀏覽可用的模型;2)加載模型;3)探索已加載的模型。下面讓我們來瀏覽幾個例子。

瀏覽可用的入口點

用戶可以使用torch.hub.list() API列出倉庫中的所有可用入口點。

>>> torch.hub.list('pytorch/vision')
>>>
['alexnet',
'deeplabv3_resnet101',
'densenet121',
...
'vgg16',
'vgg16_bn',
'vgg19',
 'vgg19_bn']

注意,PyTorch Hub還允許輔助入口點(除了預訓練模型),例如,用於BERT模型預處理的bertTokenizer,它可以使用戶工作流程更加順暢。

加載模型

對於PyTroch Hub中可用的模型,用戶可以使用torch.hub.load() API加載模型入口點。此外,torch.hub.help() API可以提供有關如何實例化模型的有用信息。

print(torch.hub.help('pytorch/vision', 'deeplabv3_resnet101'))
model = torch.hub.load('pytorch/vision', 'deeplabv3_resnet101', pretrained=True)

由於倉庫的持有者會不斷添加錯誤修復以及性能改進,PyTorch Hub允許用戶通過調用以下內容簡單地獲取最新更新:

model = torch.hub.load(..., force_reload=True)

這一舉措可以有效地減輕倉庫持有者重複發佈模型的負擔,從而使他們能夠更專注於自己的研究工作。同時,也確保了用戶可以獲得最新版本的模型。

此外,對於用戶來說,穩定性也是一個重要問題。因此,某些模型所有者會從特徵的分支或標籤爲他們提供服務,以確保代碼的穩定性。例如,pytorch_GAN_zoo會從hub分支爲他們提供服務:

model = torch.hub.load('facebookresearch/pytorch_GAN_zoo:hub', 'DCGAN', pretrained=True, useGPU=False)

這裏,傳遞給hub.load() 的 * args,** kwargs用於實例化模型。在上面的示例中,pretrained = True和useGPU = False被傳遞給模型的入口點。

探索已加載的模型

從PyTorch Hub加載模型後,用戶可以使用以下工作流查看已加載模型的可用方法,並更好地瞭解運行它所需的參數。

其中,dir(model)可以查看模型中可用的方法。下面是bertForMaskedLM的一些方法:

>>> dir(model)
>>>
['forward'
...
'to'
'state_dict',
]

help(model.forward)則會提供使已加載的模型運行時所需參數的視圖:

>>> help(model.forward)
>>>
Help on method forward in module pytorch_pretrained_bert.modeling:
forward(input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None)
...

更多細節可以查看BERTDeepLabV3頁面:

其他探索方式與相關資源

PyTorch Hub中提供的模型也支持Colab,並且會直接鏈接在Papers With Code上,用戶只需單擊鏈接即可開始使用:

image

PyTorch提供了一些相關資源幫助用戶快速上手PyTorch Hub:

FAQ

問:如果我們想貢獻一個Hub中已經有了的模型,但也許我的模型具有更高的準確性,我還應該貢獻嗎?
答:是的,請提交您的模型,Hub的下一步是開發投票系統以展示最佳模型。

問:誰負責保管PyTorch Hub的模型權重?
答:作爲貢獻者,您負責保管模型權重。您可以在您喜歡的雲存儲中託管您的模型,或者如果它符合限制,則可以在GitHub上託管您的模型。 如果您無法保管權重,請通過Hub倉庫中提交問題的方式與我們聯繫。

問:如果我的模型使用了私有化數據進行訓練怎麼辦?我還應該貢獻這個模型嗎?
答:請不要提交您的模型!PyTorch Hub以開源研究爲中心,並擴展到使用公開數據集來訓練這些模型。如果提交了私有模型的合併請求,我們將懇請您重新提交使用公開數據進行訓練後的模型。

問:我下載的模型保存在哪裏?
答:我們遵循XDG基本目錄規範,並遵循緩存文件和目錄的通用標準。這些位置按以下順序使用:

  • 調用hub.set_dir(<PATH_TO_HUB_DIR>)
  • 如果環境變量了TORCH_HOME,則爲$TORCH_HOME/hub。
  • 如果設置了環境變量XDG_CACHE_HOME,則爲$ XDG_CACHE_HOME / torch / hub。
  • ~/.cache/torch/hub

相關推薦:

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章