tensor、numpy、vector转换
python中:
**numpy -> tensor**: `torch.from_numpy(ndarray)`
**tensor -> numpy**: `tensor.numpy()`
1
2
3
c++中:
**array -> tensor**: `torch::tensor(at::ArrayRef<float>({3.1, 3.2, 3.3, ...}));` 一般数组
**cv::Mat -> tensor**: `torch::tensor(at::ArrayRef<uint8_t>(img.data, img.rows * img.cols * 3)).view({img.rows, img.cols, 3}); `指针
1
2
3
**tensor->vector**: ` vector<float> xxx(tensor.data<float>(), tensor.data<float>()+tensor.numel())`; 指针直接初始化vector
**tensor -> cv::Mat**: `Mat m(tensor.size(0), tensor.size(1), CV_8UC1, (void*) tensor.data<uint8_t>());`
1
2
3
cv::Mat 与tensor的其他转换方法:
// from_blob不申请新空间,只是image.data的view。可以用.clone深拷贝。
at::Tensor tensor_image = torch::from_blob(image.data, {1, 3, image.rows, image.cols}, at::kByte); //at::kByte与image的dypte一直;其它如at::kFloat等。
// 之后再转其他dtype
tensor_image = tensor_image.to(at::kFloat);
tensor_image.data<float>(); //这是一个float* 指针。
// an example
at::Tensor compute(at::Tensor x, at::Tensor w) {
cv::Mat input(x.size(0), x.size(1), CV_32FC1, x.data<float>());
cv::Mat warp(3, 3, CV_32FC1, w.data<float>());
cv::Mat output;
cv::warpPerspective(input, output, warp, {64, 64});
return torch::from_blob(output.ptr<float>(), {64, 64}).clone();
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
c++ model forward 返回值
返回类型为 torch::jit::IValue
torch::jit::IValue result = module->forward(inputs);
1
如果只有一个返回值,可以直接转tensor:
auto outputs = module->forward(inputs).toTensor();
1
如果有多个返回值,需要先转tuple:
auto outputs = module->forward(inputs).toTuple();
torch::Tensor out1 = outputs->elements()[0].toTensor();
torch::Tensor out2 = outputs->elements()[1].toTensor();
1
2
3
使用gpu
把model和inputs都放到gpu上:
std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[2]);
//put to cuda
module->to(at::kCUDA);
// 注意是把tensor放到gpu,而不是vector<torch::jit::IValue>
std::vector<torch::jit::IValue> inputs;
image_tensor.to(at::kCUDA)
inputs.push_back(image_tensor)
1
2
3
4
5
6
7
8
可以指定 GPU id: to(torch::Device(torch::kCUDA, id))
————————————————
版权声明:本文为CSDN博主「西伯利亚的蓝眼睛」的原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接及本声明。
原文链接:https://blog.csdn.net/qq_14975217/article/details/90512374