libtorch 中 Module::forward 的輸入參數類型爲 std::vector<torch::jit::IValue>。當網絡有多個 Tensor 輸入時,按照 Python 端 forward 定義中參數的順序,依次用 push_back 將這些 Tensor 放入該 vector 即可。
舉例說明:
class CAB(nn.Module):
    """Channel-attention style block that fuses two feature maps.

    ``x1`` and ``x2`` are concatenated along the channel axis, reduced to one
    value per channel by global average pooling, passed through a
    1x1 conv -> ReLU -> 1x1 conv -> sigmoid bottleneck, and the resulting
    per-channel weights rescale ``x2`` before ``x1`` is added back.

    Args:
        in_channels: channel count of ``torch.cat([x1, x2], dim=1)``, i.e.
            ``x1.shape[1] + x2.shape[1]`` (128 in the trace example).
        out_channels: number of attention weights produced; must be
            broadcast-compatible with the channel counts of ``x2`` (for the
            multiplication) and ``x1`` (for the residual addition). In the
            trace example all of them are 64.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # output_size=1 collapses the spatial dims to 1x1 (one value per channel).
        self.global_pooling = nn.AdaptiveAvgPool2d(output_size=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
        # NOTE: "sigmod" is a misspelling of "sigmoid", kept as-is because the
        # attribute name is part of this module's (and its traced form's) interface.
        self.sigmod = nn.Sigmoid()

    def forward(self, x1, x2):
        """Reweight ``x2`` by channel attention computed from both inputs, then add ``x1``.

        Args:
            x1: feature map added back as the residual (the original author
                labels it the "high" input). Assumed NCHW, as in the trace example.
            x2: feature map that gets reweighted (labelled the "low" input).

        Returns:
            ``sigmoid(conv2(relu(conv1(avgpool(cat[x1, x2]))))) * x2 + x1``.
        """
        x = torch.cat([x1, x2], dim=1)
        x = self.global_pooling(x)  # (N, in_channels, 1, 1) for 4-D input
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.sigmod(x)  # per-channel weights in (0, 1)
        x2 = x * x2  # broadcast the 1x1 weights over H and W
        res = x2 + x1
        return res
以上 CAB 是在 PyTorch 中定義的網絡結構,其 forward 接受 x1 和 x2 兩個輸入。
用 torch.jit.trace 導出 TorchScript 模型的代碼如下(示例輸入以元組形式、按 forward 參數順序傳入):
import os

# Build the model with 128 = 64 + 64 concatenated input channels and 64 outputs.
CAB_model = CAB(128, 64)
# Trace in inference mode (recommended before tracing; matters for models
# with dropout/batch-norm even though CAB has neither).
CAB_model.eval()
X1 = torch.rand([1, 64, 48, 48])
X2 = torch.rand([1, 64, 48, 48])
Y = CAB_model(X1, X2)  # eager sanity run before tracing
# Multiple inputs are passed as a tuple in forward(x1, x2) order; the C++
# side must push its tensors into the IValue vector in this same order.
CAB_traced_script_module = torch.jit.trace(CAB_model, (X1, X2))
# save() fails if the target directory is missing, so create it first.
os.makedirs("./traced", exist_ok=True)
CAB_traced_script_module.save("./traced/traced_CAB.pt")
libtorch調用代碼如下:
torch::jit::script::Module CAB_model;
try {
// Deserialize the ScriptModule from a file using torch::jit::load().
CAB_model = torch::jit::load("./traced/traced_CAB.pt");
}
catch (const c10::Error& e) {
std::cerr << "error loading the twoStream model\n";
return -1;
}
at::Tensor X1_tensor = torch::zeros({ 1, 64, 48, 48 });
at::Tensor X2_tensor = torch::zeros({ 1, 64, 48, 48 });
std::vector<torch::jit::IValue> inputs;
inputs.push_back(X1_tensor);
inputs.push_back(X2_tensor);
auto CAB_output = CAB_model.forward(inputs).toTensor();
c10::IntList size;
size = CAB_output.sizes();
std::cout << size.at(0) << "," << size.at(1) << "," << size.at(2) << "," << size.at(3) << std::endl;
參考鏈接: