pytorch轉onnx模型多輸入問題(如:Bert)

舉個例子:
Bert模型有三個輸入,因此就要創建三個dummy_input,然後利用一個tuple,傳入函數中。

dummy_input0 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input1 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
dummy_input2 = torch.LongTensor(Batch_size, seg_length).to(torch.device("cuda"))
torch.onnx.export(model. (dummy_input0, dummy_input1, dummy_input2), filepath)
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章