C++擴展pytorch,簡單加法例子

 首先編寫 our_add.cpp, 引入torch.h 定義加法

#include <torch/torch.h>
#include <vector>
#include <iostream>

torch::Tensor ADD(const torch::Tensor& a, const torch::Tensor& b){
    return a+b;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
    m.def("add", &ADD, "OUR_ADD add");
}

編寫setup.py

from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension

setup(
    name = 'our_add',
    version = '0.0.1',
    ext_modules = [CppExtension('our_add',sources=['add.cpp'])],
    cmdclass = {'build_ext': BuildExtension}
)

然後 cmd

python setup.py install

然後會進行編譯及安裝,完成後,就可以使用了.

import torch
import our_add

print(our_add.add)

x = torch.tensor(10)
y = torch.tensor(20)

z = our_add.add(x, y)

print(z)

-> <built-in method add of PyCapsule object at 0x7fa971e180f0>
-> tensor(30)

編譯完成後,會在python環境中進行安裝,如果需要移出來用,可以將包our_add-0.0.1-py3.7-linux-x86_64.egg 中的 .so與.py一起復製出來.

 

** 另 在調用時,先調用 torch ,再調用自己的包

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章