Pytorch的C++使用

利用Pytorch的C++前端(libtorch)讀取預訓練權重並進行預測

https://blog.csdn.net/iamoldpan/article/details/85057238

https://blog.csdn.net/a819411321/article/details/97372177

 

 

pytorch 參數寫入二進制文件

    data = []
    for name, param in model.state_dict().items():
        print("name:",name, " param:", param.size())
        if param.size() != torch.Size([]):
            pa = param.reshape(-1).to(device)
            numpy_param = pa.detach().numpy()
            print("param:", len(numpy_param))
            data += list(numpy_param)
    print("len data:", len(data))
    f = open('weights/yolov3.bin','wb+')
    data_struct = struct.pack(('%df' % len(data)), *data)
    f.write(data_struct)
    f.close()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章