【最佳實踐】.pth還是.tar?

pytorch的官方教程裏提供了相關說明:

只保存模型用於以後的推斷的話使用.pth.pt,這樣可以直接加載模型

A common PyTorch convention is to save models using either a .pt or .pth file extension.

torch.save(model, "model.pth") # or .pt
model = torch.load("model.pth")

斷點保存的話則使用.tar,加載的時候模型需要使用load_state_dict()方法

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, "checkpoint.tar")

...

checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

其中部份人羣喜歡使用.pth.tar來表明這不是一個簡單的壓縮tar類型的文件

其實這個問題一直有人討論,因爲pth同時也是Python的一種格式,所以有人甚至提出要更改一種後綴來區分…不過暫時不太需要考慮這個問題…

但實際上閱讀save的源碼就會發現,torch只是調用了Python的pickle來完成,而且沒有做任何的後綴名判斷,因此無論保存成什麼後綴都是可以的…

在這裏插入圖片描述
源問題鏈接

發佈了137 篇原創文章 · 獲贊 85 · 訪問量 12萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章