pytorch转onnx模型多输入问题(如:Bert)

举个例子:
Bert模型有三个输入,因此就要创建三个dummy_input,然后利用一个tuple,传入函数中。

dummy_input0 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input1 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input2 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
torch.onnx.export(model. (dummy_input0, dummy_input1, dummy_input2), filepath)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章