怎樣用libtorch跑多輸入節點的網絡呢?

libtorch的forward函數輸入參數格式爲std::vector<IValue>,當網絡輸入有多個Tensor時,把這些Tensor依次pushback進這個vector即可。

舉例說明:

class CAB(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CAB, self).__init__()
        self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmod = nn.Sigmoid()

    def forward(self, x1, x2):
        #x1, x2 = x  # high, low
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)
        x2 = x * x2
        res = x2 + x1
        return res

以上CAB爲pytorch中定義的網絡結構,有x1和x2兩個輸入

trace代碼如下:

CAB_model = CAB(128, 64)
X1 = torch.rand([1, 64, 48, 48])
X2 = torch.rand([1, 64, 48, 48])
Y = CAB_model(X1, X2)
CAB_traced_script_module = torch.jit.trace(CAB_model, (X1, X2))
CAB_traced_script_module.save("./traced/traced_CAB.pt")

libtorch調用代碼如下:

torch::jit::script::Module CAB_model;
try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    CAB_model = torch::jit::load("./traced/traced_CAB.pt");
}
catch (const c10::Error& e) {
    std::cerr << "error loading the twoStream model\n";
    return -1;
}

at::Tensor X1_tensor = torch::zeros({ 1, 64, 48, 48 });
at::Tensor X2_tensor = torch::zeros({ 1, 64, 48, 48 });
std::vector<torch::jit::IValue> inputs;
inputs.push_back(X1_tensor);
inputs.push_back(X2_tensor);

auto CAB_output = CAB_model.forward(inputs).toTensor();
c10::IntList size;
size = CAB_output.sizes();
std::cout << size.at(0) << "," << size.at(1) << "," << size.at(2) << "," << size.at(3) << std::endl;

參考鏈接:

https://github.com/pytorch/pytorch/issues/15523

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章