pytorch->onnx->trt 踩坑記錄

ERROR: ModelImporter.cpp:296 In function importModel: [5] Assertion failed: tensors.count(input_name)

問題原因:
pytorch版本爲1.4.0,pytorch版本過高引起的onnx解析問題.(據悉這個解析問題會發生在trt5.0-6.0,trt7.0不會出現,詳見trt6轉torch1.2以上版本的onnx)

問題解決:

  1. 上面的鏈接中,可以直接重新編譯onnx2trt,只需要載入build好的engine的同學可以直接去重新編譯onnx2trt,然後用這個轉onnx爲trt引擎。
  2. 降torch的版本到1.1.0

[E] [TRT] Network must have at least one output

問題原因:
網絡的輸出爲兩個,轉onnx的時候要明確這一點
問題解決:

error改之前:(我訓練出來的網絡應該有兩個輸出,以下代碼設置的是torch.onnx.export的參數)

input_names=["input"]
output_names=["output_mask"]

改之後

input_names=["input"]
output_names=["output_conv10,""output_mask"]

將torch.onnx.export的output_names參數的輸出name個數,設置的和網絡模型實際輸出個數相同後,就好了。

Only tuples, lists and Variables supported as JIT inputs, but got dict

問題原因:
pytorch下降到1.1.0,不支持字典的輸出。
問題解決:
即原來的網絡我是這麼輸出的


    def forward(self, x):
		...
        x5 = self.layer5(x4)
        return {"seg":x5,"mask":x4}

要改成這樣


    def forward(self, x):
		...
        x5 = self.layer5(x4)
        
        return x4, x5

然後,把調用的後處理改一下,就可以了。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章