pytorch的官方教程裏提供了相關說明:
只保存模型用於以後的推斷的話使用.pth
或.pt
,這樣可以直接加載模型
A common PyTorch convention is to save models using either a .pt or .pth file extension.
torch.save(model, "model.pth") # or .pt
model = torch.load("model.pth")
斷點保存的話則使用.tar
,加載的時候模型需要使用load_state_dict()
方法
To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
...
}, "checkpoint.tar")
...
checkpoint = torch.load("checkpoint.tar")
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
其中部份人羣喜歡使用.pth.tar
來表明這不是一個簡單的壓縮tar類型的文件
其實這個問題一直有人討論,因爲pth
同時也是Python的一種格式,所以有人甚至提出要更改一種後綴來區分…不過暫時不太需要考慮這個問題…
但實際上閱讀save的源碼就會發現,torch只是調用了Python的pickle來完成,而且沒有做任何的後綴名判斷,因此無論保存成什麼後綴都是可以的…