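A model trained in Python and exported to ONNX is easy to run for inference from Python code, but how do you do the same from Java?
For ONNX inference you can use onnxruntime. Add the Maven dependency: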
<dependency>
    <groupId>com.microsoft.onnxruntime</groupId>
    <artifactId>onnxruntime</artifactId>
    <version>1.15.1</version>
</dependency>
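The trained model also needs a BERT tokenizer to turn words and characters into token ids. There is one on GitHub at https://github.com/ankiteciitkgp/bertTokenizer; we made a few small changes to that library so that it fits the input of a BERT ONNX model. The modified code is at https://github.com/jadepeng/bertTokenizer
The main addition is the tokenizeOnnxTensor method, which returns ONNX tensors in the form the BERT model expects as input.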
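The repository's tokenizeOnnxTensor is not reproduced here. As a rough sketch of what such a method has to produce: BERT models exported to ONNX commonly take three int64 inputs of shape [batch, seq_len], usually named input_ids, attention_mask and token_type_ids. The hypothetical BertBatch class below pads already-tokenized id sequences to a common length using plain arrays. The input names, the pad id 0 and the sample token ids are assumptions based on common BERT exports, not taken from the library. Each resulting long[][] could then be wrapped with OnnxTensor.createTensor(env, array).

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper for illustration; not part of bertTokenizer or onnxruntime.
public class BertBatch {
    // Assumed pad token id; [PAD] is id 0 in the standard BERT vocab files.
    static final long PAD_ID = 0L;

    /**
     * Pads token-id sequences to the longest one in the batch and builds
     * the three arrays a typical BERT ONNX export expects.
     */
    public static Map<String, long[][]> build(List<long[]> tokenIds) {
        int maxLen = 0;
        for (long[] ids : tokenIds) {
            maxLen = Math.max(maxLen, ids.length);
        }
        int batch = tokenIds.size();
        long[][] inputIds = new long[batch][maxLen];
        long[][] attentionMask = new long[batch][maxLen];
        // Left as all zeros: every sequence is treated as a single segment.
        long[][] tokenTypeIds = new long[batch][maxLen];

        for (int i = 0; i < batch; i++) {
            long[] ids = tokenIds.get(i);
            for (int j = 0; j < maxLen; j++) {
                boolean real = j < ids.length;
                inputIds[i][j] = real ? ids[j] : PAD_ID;
                attentionMask[i][j] = real ? 1L : 0L; // 1 = real token, 0 = padding
            }
        }

        // Assumed input names; check them against session.getInputNames().
        Map<String, long[][]> inputs = new LinkedHashMap<>();
        inputs.put("input_ids", inputIds);
        inputs.put("attention_mask", attentionMask);
        inputs.put("token_type_ids", tokenTypeIds);
        return inputs;
    }

    public static void main(String[] args) {
        // Illustrative ids only; real ids come from the tokenizer and vocab.txt.
        Map<String, long[][]> inputs = build(List.of(
                new long[]{101, 7592, 2088, 102},
                new long[]{101, 102}));
        System.out.println(Arrays.deepToString(inputs.get("input_ids")));
        System.out.println(Arrays.deepToString(inputs.get("attention_mask")));
    }
}
```

If the exported model uses different input names or does not take token_type_ids, the map keys need to be adjusted to match what the session reports.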
Complete demo code:
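import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.alibaba.fastjson.JSON; // assumed: JSON.toJSONString comes from fastjson
import java.io.IOException;
import java.util.Arrays;
// BertTokenizer comes from https://github.com/jadepeng/bertTokenizer; import it from that library's package.

public class OnnxTests {
    public static void main(String[] args) throws IOException, OrtException {
        // Load the vocabulary that was used when the model was trained.
        BertTokenizer bertTokenizer = new BertTokenizer("D:\\model\\vocab.txt");
        var env = OrtEnvironment.getEnvironment();
        // The session options and the session hold native resources, so close them when done.
        try (var options = new OrtSession.SessionOptions();
             var session = env.createSession("D:\\model\\output\\onnx\\fp16_model.onnx", options)) {
            // Tokenize the batch into the ONNX input tensors the BERT model expects.
            var inputMap = bertTokenizer.tokenizeOnnxTensor(Arrays.asList("hello world 你好", "腫瘤治療未來發展趨勢"));
            try (var results = session.run(inputMap)) {
                System.out.println(results);
                // The cast assumes the first output has shape [batch, hidden].
                var embeddings = (float[][]) results.get(0).getValue();
                for (var embedding : embeddings) {
                    System.out.println(JSON.toJSONString(embedding));
                }
            }
        }
    }
}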
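If the model's first output really is one vector per sentence, as the cast to float[][] in the demo assumes, a typical next step is to compare two sentences by the cosine similarity of their embeddings. The following is a minimal sketch; EmbeddingUtil is a hypothetical helper, not part of onnxruntime or the tokenizer library, and the vectors in main are made-up stand-ins for real model output.

```java
// Hypothetical helper for illustration; not part of onnxruntime or bertTokenizer.
public class EmbeddingUtil {
    /** Cosine similarity of two equal-length vectors; returns 0 if either vector has zero norm. */
    public static double cosine(float[] a, float[] b) {
        if (a.length != b.length) {
            throw new IllegalArgumentException("vectors must have the same length");
        }
        double dot = 0.0;
        double normA = 0.0;
        double normB = 0.0;
        for (int i = 0; i < a.length; i++) {
            dot += (double) a[i] * b[i];
            normA += (double) a[i] * a[i];
            normB += (double) b[i] * b[i];
        }
        if (normA == 0.0 || normB == 0.0) {
            return 0.0; // similarity is undefined for a zero vector; 0 is a convention here
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    public static void main(String[] args) {
        // Stand-in vectors; in the demo these would be embeddings[0] and embeddings[1].
        float[] first = {0.5f, 0.1f, -0.3f};
        float[] second = {0.4f, 0.2f, -0.1f};
        System.out.println(cosine(first, second));
    }
}
```

Values close to 1 mean the two embeddings point in nearly the same direction, and values near 0 mean they are roughly orthogonal.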