darts框架使用

文|Seraph

高版本Pytorch問題

  1. 運行test.py報錯IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    解決:update函數的參數loss.data[0]prec1.data[0]prec5.data[0]等修改爲loss.item()prec1.item()prec5.item()
  2. 執行python visualize.py DARTS報錯:
    test failed: ExecutableNotFound: failed to execute ['dot', '-Tpng'],make sure the Graphviz executables are on your systems ' PATH
    解決:除了pip install graphviz,還需要apt install graphviz
    FileNotFoundError: [Errno 2] No such file or directory: 'xdg-open'
    解決:安裝apt install xdg-utils
  3. 使用train_search.py得到的weight.pt進行test.py測試報錯:RuntimeError: Error(s) in loading state_dict for NetworkCIFAR
  • Missing key(s) in state_dict:
    解決:修改utils.py模塊中的load函數如下,(Pytorch老版本兼容問題)
def load(model, model_path):
  model.load_state_dict(torch.load(model_path)False)
  • size mismatch for : copying a param of from checkpoint, where the shape is torch.Size in current model.
    解決:輸入運行命令時輸入與.pt文件中模型一樣的--init_channels 16--layers 8。由於test.py和train.py中這兩個參數的默認值時一樣的,而train_search.py是不一樣的,所以要統一參數值,才能使運行模型一致。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章