darts框架使用

文|Seraph

高版本Pytorch问题

  1. 运行test.py报错IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    解决:update函数的参数loss.data[0]prec1.data[0]prec5.data[0]等修改为loss.item()prec1.item()prec5.item()
  2. 执行python visualize.py DARTS报错:
    test failed: ExecutableNotFound: failed to execute ['dot', '-Tpng'],make sure the Graphviz executables are on your systems ' PATH
    解决:除了pip install graphviz,还需要apt install graphviz
    FileNotFoundError: [Errno 2] No such file or directory: 'xdg-open'
    解决:安装apt install xdg-utils
  3. 使用train_search.py得到的weight.pt进行test.py测试报错:RuntimeError: Error(s) in loading state_dict for NetworkCIFAR
  • Missing key(s) in state_dict:
    解决:修改utils.py模块中的load函数如下,(Pytorch老版本兼容问题)
def load(model, model_path):
  model.load_state_dict(torch.load(model_path)False)
  • size mismatch for : copying a param of from checkpoint, where the shape is torch.Size in current model.
    解决:输入运行命令时输入与.pt文件中模型一样的--init_channels 16--layers 8。由于test.py和train.py中这两个参数的默认值时一样的,而train_search.py是不一样的,所以要统一参数值,才能使运行模型一致。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章