【手撕 - 深度學習】TF Lite 魔改:添加自定義 op

作者:LogM

本文原載於 https://segmentfault.com/u/logm/articles ,不允許轉載~


1. 前言

Tensorflow Lite 是 Tensorflow 移動端的版本。

有關於 Tensorflow 怎麼添加自定義 op,網上有很多博客都講到了,我就不介紹了。而 Tensorflow Lite 因爲相對小衆一些,所以網上關於添加自定義 op 的教程很少。

剛好最近因爲項目需要,我在 Tensorflow Lite 中添加了幾個自定義 op。我把我的思考過程以及修改步驟記錄下來,方便有相同需求的同學參考。

我花了大篇幅記錄思考過程和源碼閱讀過程,是希望給其他小夥伴一些啓發,以後遇到類似的深度學習框架魔改的問題,可以不依賴網上教程。

不關心思考過程和源碼閱讀的小夥伴,可以直接跳到文章的最後,我把修改的步驟做了總結。

2. 源碼來源

我使用源碼是 Tensorflow v1.13.2

Tensorflow Lite 位於 tensorflow/lite 目錄下。

3. 官方教程

官網也有關於 Tensorflow Lite 怎麼添加自定義 op 的教程,詳見官方地址

官方教程把"怎麼寫自定義 op 的代碼"講得很清楚,遺憾的是沒有詳細說明怎麼把這些新寫的代碼放入到工程中編譯。

4. 進入正題

第1步,找到目標文件夾位置

首先我們要找到源碼中放置自定義 op 的文件夾位置。有多種尋找的方式:

  1. tensorflow 源碼的目錄結構非常清楚,有過類似框架閱讀經驗的同學應該馬上能猜出位置;
  2. 官方教程告訴我們,自定義 op 的代碼要實現 PrepareEval 這兩個函數,那麼我們使用 grep 命令查找有哪些代碼文件中帶有這兩個函數。

最終,我們找到的位置是 tensorflow/lite/kernels

找到目標文件夾位置以後,把新增代碼放入該文件夾就可以了嗎?顯然,沒有這麼簡單。有幾個方面需要考慮:

  1. 代碼邏輯層面,新增代碼的邏輯怎麼與源碼的邏輯連接起來;
  2. 編譯層面,新增代碼怎麼參與編譯。

第2步,新增代碼的邏輯怎麼與源碼的邏輯連接起來?

有過類似深度學習框架閱讀經驗的同學應該很快能想到,對於"添加自定義op"這個操作,就是個"op註冊"的過程,所以馬上想到去尋找帶"register"字樣的文件。

而沒有深度學習框架閱讀經驗的同學也不用慌,官方教程告訴我們,自定義op在使用前需要調用 AddCustom 函數。那麼很明顯,這個函數就起到了將自定義op的邏輯與源碼邏輯連接起來的任務。所以使用 grep 命令查找有哪些代碼文件中帶有這個函數。

兩種方式殊途同歸,找到關鍵文件 tensorflow/lite/kernels/register.cc

// 文件:tensorflow/lite/kernels/register.cc
// 行數:22

namespace custom {

TfLiteRegistration* Register_AUDIO_SPECTROGRAM();
TfLiteRegistration* Register_LAYER_NORM_LSTM();
TfLiteRegistration* Register_MFCC();
TfLiteRegistration* Register_DETECTION_POSTPROCESS();
TfLiteRegistration* Register_RELU_1();

} 
// 文件:tensorflow/lite/kernels/register.cc
// 行數:278

  // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
  // custom ops aren't always included by default.
  AddCustom("Mfcc", tflite::ops::custom::Register_MFCC());
  AddCustom("AudioSpectrogram",
            tflite::ops::custom::Register_AUDIO_SPECTROGRAM());
  AddCustom("LayerNormLstm", tflite::ops::custom::Register_LAYER_NORM_LSTM());
  AddCustom("Relu1", tflite::ops::custom::Register_RELU_1());
  AddCustom("TFLite_Detection_PostProcess",
            tflite::ops::custom::Register_DETECTION_POSTPROCESS());

嘿嘿嘿,我們發現官方源碼中也放了5個自定義op,而且官方偷懶把自定義op與內置op的註冊過程寫在了一起,那麼我們來看看官方是怎麼寫自定義op的吧,比如 Relu1 這個。

// 文件:tensorflow/lite/kernels/relu1.cc

#include "tensorflow/lite/context.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

namespace tflite {
namespace ops {
namespace custom {
namespace relu1 {

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
  TfLiteTensor* output = GetOutput(context, node, 0);
  output->type = input->type;
  return context->ResizeTensor(context, output,
                               TfLiteIntArrayCopy(input->dims));
}

// This is derived from lite/kernels/activations.cc.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);
  const int elements = NumElements(input);
  const float* in = input->data.f;
  const float* in_end = in + elements;
  float* out = output->data.f;
  for (; in < in_end; ++in, ++out) {
    *out = std::min(std::max(0.f, *in), 1.f);
  }
  return kTfLiteOk;
}

}  // namespace relu1

TfLiteRegistration* Register_RELU_1() {
  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
                                 relu1::Prepare, relu1::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite

可以看到,與官方給出的教程一樣,關鍵點是實現 PrepareEval 這兩個函數。我們自己在自定義op的代碼時,可以把這個文件當做參考模板。

第3步,新增代碼怎麼參與編譯?

這塊需要一些 C++ 大工程開發的知識,Tensorflow 是用 Bazel 作工程編譯的,所以關鍵點在目標文件夾下的 BUILD 文件。

BUILD 文件裏面這麼多的 library,我們的新代碼應該編譯到哪個 library 中呢?還記得 官方留的自定義op "Relu1" 嗎?我們來看看 "Relu1" 是編譯到哪個 library。

// 文件:tensorflow/lite/kernels/BUILD
// 行數:278

cc_library(
    name = "builtin_op_kernels",
    srcs = [
        ...     // 這裏有很多其他的源文件
        "mfcc.cc",
        "relu1.cc",
        ...     // 把新寫的代碼文件加到這邊就可以了
    ],
    hdrs = [
    ],
    copts = tflite_copts() + tf_opts_nortti_if_android() + EXTRA_EIGEN_COPTS,
    visibility = ["//visibility:private"],
    deps = [
        ":activation_functor",
        ":eigen_support",
        ":kernel_util",
        ":lstm_eval",
        ":op_macros",
        ":padding",
        "//tensorflow/lite:framework",
        "//tensorflow/lite:string_util",
        "//tensorflow/lite/c:c_api_internal",
        "//tensorflow/lite/kernels:gemm_support",
        "//tensorflow/lite/kernels/internal:audio_utils",
        "//tensorflow/lite/kernels/internal:kernel_utils",
        "//tensorflow/lite/kernels/internal:optimized",
        "//tensorflow/lite/kernels/internal:optimized_base",
        "//tensorflow/lite/kernels/internal:quantization_util",
        "//tensorflow/lite/kernels/internal:reference_base",
        "//tensorflow/lite/kernels/internal:tensor",
        "//tensorflow/lite/kernels/internal:tensor_utils",
        "@farmhash_archive//:farmhash",
        "@flatbuffers",
    ],
)

如果你熟悉 Bazel 或者熟悉類似的編譯工具的話,能夠很快明白,只要把新的代碼文件添加到 src=[] 裏,新的代碼就能參與到編譯過程中了。

5. 總結

Tensorflow Lite v1.13.2 中,官方偷了個懶,自定義 op 與內置 op 寫在同一個位置,都是編譯爲 builtin_op_kernels 庫。

Tensorflow Lite 的自定義 op 添加方式如下:

  1. 參照 官方教程 以及 tensorflow/lite/kernels/relu1.cc 編寫 op 代碼;
  2. 將 op 代碼放入 tensorflow/lite/kernels 文件夾下;
  3. 修改 tensorflow/lite/kernels/register.cc,完成新增 op 在代碼邏輯上的"註冊";
  4. 修改 tensorflow/lite/kernels/BUILD,將新代碼文件加入到 builtin_op_kernels 庫的編譯過程中;
  5. 參照 官方教程 重新編譯整個項目。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章