tensor、numpy、vector轉換
python中:
**numpy -> tensor**: `torch.from_numpy(ndarray)`
**tensor -> numpy**: `tensor.numpy()`
1
2
3
c++中:
**array -> tensor**: `torch::tensor(at::ArrayRef<float>({3.1, 3.2, 3.3, ...}));` 一般數組
**cv::Mat -> tensor**: `torch::tensor(at::ArrayRef<uint8_t>(img.data, img.rows * img.cols * 3)).view({img.rows, img.cols, 3}); `指針
1
2
3
**tensor->vector**: ` vector<float> xxx(tensor.data<float>(), tensor.data<float>()+tensor.numel())`; 指針直接初始化vector
**tensor -> cv::Mat**: `Mat m(tensor.size(0), tensor.size(1), CV_8UC1, (void*) tensor.data<uint8_t>());`
1
2
3
cv::Mat 與tensor的其他轉換方法:
// from_blob不申請新空間,只是image.data的view。可以用.clone深拷貝。
at::Tensor tensor_image = torch::from_blob(image.data, {1, 3, image.rows, image.cols}, at::kByte); //at::kByte與image的dypte一直;其它如at::kFloat等。
// 之後再轉其他dtype
tensor_image = tensor_image.to(at::kFloat);
tensor_image.data<float>(); //這是一個float* 指針。
// an example
at::Tensor compute(at::Tensor x, at::Tensor w) {
cv::Mat input(x.size(0), x.size(1), CV_32FC1, x.data<float>());
cv::Mat warp(3, 3, CV_32FC1, w.data<float>());
cv::Mat output;
cv::warpPerspective(input, output, warp, {64, 64});
return torch::from_blob(output.ptr<float>(), {64, 64}).clone();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
c++ model forward 返回值
返回類型爲 torch::jit::IValue
torch::jit::IValue result = module->forward(inputs);
1
如果只有一個返回值,可以直接轉tensor:
auto outputs = module->forward(inputs).toTensor();
1
如果有多個返回值,需要先轉tuple:
auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();
1
2
3
使用gpu
把model和inputs都放到gpu上:
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[2]);
//put to cuda
module->to(at::kCUDA);
// 注意是把tensor放到gpu,而不是vector<torch::jit::IValue>
std::vector<torch::jit::IValue> inputs;
image_tensor.to(at::kCUDA)
inputs.push_back(image_tensor)
1
2
3
4
5
6
7
8
可以指定 GPU id: to(torch::Device(torch::kCUDA, id))
————————————————
版權聲明:本文爲CSDN博主「西伯利亞的藍眼睛」的原創文章,遵循CC 4.0 BY-SA版權協議,轉載請附上原文出處鏈接及本聲明。
原文鏈接:https://blog.csdn.net/qq_14975217/article/details/90512374