AI 編譯器CINN中的OpLowering優化Pass

一、Lower 主邏輯

OpLower::Lower() 接口中,主要分爲兩大類:

  • Elementwise類,主要涉及的 OpPattern 包括:kElementwisekBroadcastkInjective
  • Reduce 類,主要涉及的OpPattern包括:kReduction
std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {
  VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind;
  group->input_names.clear();
  group->output_names.clear();
  if (FLAGS_cinn_ir_schedule) {
    switch (group->op_pattern_kind) {
      case framework::kElementWise:
      case framework::kBroadcast:
      case framework::kInjective:
        return IRLowerOp(&OpLowerer::IRElementwiseCompute, &OpLowerer::IRElementwiseSchedule, group);   // << --- 第一大類 Elementwise相關
      case framework::kReduction:
        return IRLowerOp(&OpLowerer::IRReduceCompute, &OpLowerer::IRReduceSchedule, group);             // << --- 第二大類 Reduce 相關
      case framework::kOutFusible:
        LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
      case framework::kNonFusible:
        return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
      default:
        LOG(FATAL) << "Group Pattern Kind Is Unknown!";
    }
  } else {
    LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
  }
}

二、Optimize 邏輯

op_lowering.cc 中的 IRLowerOp 的最後,會創建一個 LoweredFunc 對象,並對其調用 optim::Optimize() 函數。

std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(compute, schedule, group){
  // .... 省略
  auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
  auto func =
      ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers);
  func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref();   // <----在函數最後會調用 Optimizer 對函數表達式進行優化,注意與 target 相關
  return {func};
}

其接口實現是在 optimize.cc 文件中,主要是對 LoweredFunc 對應的Expr應用各種優化 pass:

// 中層是 optimize.cc 中的:
Expr Optimize(Expr e, Target target, bool runtime_debug_info, bool remove_gpu_for_loops) {
    auto copied = IRCopy(e);
    // 與 target 無關的通用優化
    FoldCINNCallArguments(&copied);
    TransformPolyForToFor(&copied);
    ReplaceConstParamToInteger(&copied);
    CastSimplify(&copied);
    Simplify(&copied);
    UnrollLoop(&copied);
    // 與 target 有關的優化
    VectorizeLoops(&copied, target);
    MapExternCall(&copied, target);            // <---- 此處是這裏要關注和討論的 MapExternCall 優化
    
    // 僅在 CUDA 上的優化
    ir::SetCudaAxisInfo(&copied);
    RemoveGpuForloopsAxis(&copied);
    CudaSyncThreadsDropIfThenElse(&copied);
    
    // 又是與 target 無關的通用優化
    RemoveNestedBlock(&copied);
    ExternCallMultiOutputShallowStore(&copied);
    CastSimplify(&copied);
    Simplify(&copied);
    IfSimplify(&copied);
    
    // 與調試相關通用優化
    InsertDebugLogCallee(&copied);
}

三、各個優化Pass

接下來,我們逐個來研究每個 pass 的角色和作用。

3.1 FoldCINNCallArguments

此 Pass 的功能是通過 FoldCINNCallArgumentsMutator 來實現的:

void FoldCINNCallArguments(Expr* expr) { FoldCINNCallArgumentsMutator()(expr); }
此 Mutator 只關心ir::Block和ir::Store兩種類型節點:
struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> {
  void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }

 private:
  void Visit(const ir::Block* op, Expr* expr);   // <----- Block
  void Visit(const ir::Store* op, Expr* expr);   // <----- Store

  void MutateCall(ir::Call* call);

 private:
  // To avoid the same call triggered duplicately.
  std::unordered_set<std::string> visited_call_; 
};

其中對於 ir::Block 類型的節點,找到所有的 CallType::CINN 的 Expr,然後判斷是否已經存在過,若是,則刪除 Block 中的此 statement 語句。
這裏補充下 ir::Call 節點中的 CallType 枚舉類型的值都有哪些:CINNIntrinsicExternISL

  void Visit(const ir::Block* op, Expr* expr) override {
    auto* node = expr->As<ir::Block>();
    for (auto it = node->stmts.begin(); it != node->stmts.end();) {
      if (it->As<ir::Store>()) {
        auto* call = it->As<ir::Store>()->value.As<ir::Call>(); // <---- 針對 x = cinn_call_func(args) 場景?
        if (call && call->is_cinn_call()) {
          // remove the duplicate calls.
          std::string key = utils::GetStreamCnt(Expr(call));
          if (visited_call_.count(key)) {
            it = node->stmts.erase(it);
            continue;
          }

          ir::IRMutator<>::Visit(&(*it), &(*it));     // <--- 這裏會觸發下面的 ir::Store 的處理邏輯
          visited_call_.insert(key);
          continue;
        }
      }

      ir::IRMutator<>::Visit(&(*it), &(*it));
      ++it;
    }
  }

對於 ir::Store 類型的節點,僅針對 CallType::CINN 類型的節點調用 MutateCall 函數進行修改和替換;

  void Visit(const ir::Store* op, Expr* expr) override {
    auto* node = expr->As<ir::Store>();
    if (node->value.As<ir::Call>()) {
      auto* call = node->value.As<ir::Call>();
      switch (call->call_type) {
        case ir::CallType::CINN:
          MutateCall(call);
          *expr = node->value;
          break;
        case ir::CallType::Intrinsic:
          break;
        case ir::CallType::Extern:
          break;
        default:
          CINN_NOT_IMPLEMENTED
      }
    }
  }

MuteCall 函數是此 Pass 的最核心邏輯,其作用是 call 節點中所有的輸入、輸出 args 中的 Tensor 類型,確認其都 defined 了 buffer ,並將 buffer 作爲真正的 args 替換原來的 read_args 和 write_args 。

思考:爲什麼要單獨對CINN類型的CallType多做這樣一件事情呢?背景是什麼?

  void MutateCall(ir::Call* call) {
    if (call->call_type == ir::CallType::Extern) return;

    std::vector<Expr> read_args;
    std::vector<Expr> write_args;
    for (auto& arg : call->read_args) {
      if (arg.as_tensor()) {
        CHECK(arg.as_tensor()->buffer.defined()) << "arg tensor [" << arg.as_tensor()->name << "] not has buffer";
        read_args.push_back(arg.as_tensor()->buffer);
      } else {
        read_args.push_back(arg);
      }
    }

    for (auto& arg : call->write_args) {
      if (arg.as_tensor()) {
        write_args.push_back(arg.as_tensor()->buffer);
      } else {
        write_args.push_back(arg);
      }
    }

    call->read_args  = read_args;
    call->write_args = write_args;
  }

3.2 ReplaceConstParamToInteger

這個 Pass 比較簡單,只針對 ir::Var 類型的節點,如果其 name 是以 _const_ 開頭的,則取其具體的值,轉爲Expr(如 Intmm)

static const char* kIslParamConstPrefix = "_const_";

struct Mutator : public ir::IRMutator<> {
  using ir::IRMutator<>::Visit;

  void Visit(const ir::_Var_* op, Expr* expr) override {
    if (utils::Startswith(op->name, poly::kIslParamConstPrefix)) {
      std::string value = op->name.substr(strlen(poly::kIslParamConstPrefix));
      *expr             = Expr(std::stoi(value));   // <----- 這裏強轉爲 int 類型,是隻存在類似 _const_12 這種情況
    }
  }
};

}  // namespace

void ReplaceConstParamToInteger(Expr* e) {
  Mutator mutator;
  mutator.Visit(e, e);
}

那這個const 前綴字符串拼接是在哪裏做的呢?是在 cinn::poly::ast_gen 中做的,相關邏輯代碼如下:

isl::set TransIdentityExtentToContextId(isl::set set) {
  std::vector<std::tuple<int, int>> iden_dim_offsets;
  for (int i = 0; i < isl_set_dim(set.get(), isl_dim_set); i++) {
    if (isl_set_axis_has_noparam_constant_bound(set.get(), i)) {
      auto range = isl_set_get_axis_range(set.get(), i);
      auto& minv = std::get<0>(range);
      auto& maxv = std::get<1>(range);

      int min_iv = minv.get_num_si();
      int max_iv = maxv.get_num_si();
      if (max_iv == min_iv) {
        iden_dim_offsets.emplace_back(i, max_iv);
      }
    }
  }

  isl::set res_set = set;
  for (auto offset_val : iden_dim_offsets) {
    auto& offset = std::get<0>(offset_val);
    auto& val    = std::get<1>(offset_val);   // <---- 是個 int 類型
    res_set      = isl::manage(isl_set_drop_constraints_involving_dims(res_set.copy(), isl_dim_set, offset, 1));

    std::string const_param_name = llvm::formatv("{0}{1}", kIslParamConstPrefix, val);  //<---- 在此處進行拼接的

    std::string cond_str = llvm::formatv(
        "{0} <= {1} <= {2}", val, isl_set_get_dim_name(res_set.get(), isl_dim_set, offset), const_param_name);
    std::string param_cond_str = llvm::formatv("{0} <= {1} < {2}", val, const_param_name, val + 2);

    std::string set_repr = llvm::formatv("[{0}] -> { {1}[{2}]: {3} and {4} }",
                                         const_param_name,
                                         isl_set_get_tuple_name(res_set.get()),
                                         utils::Join(isl_get_dim_names(res_set.get()), ","),
                                         cond_str,
                                         param_cond_str);

    VLOG(4) << "repr: " << set_repr;

    isl::set new_set(res_set.ctx(), set_repr);

    res_set = res_set.intersect(new_set);
  }
  return res_set;
}

注:通過檢索Bert 模型中的GLOG_v=10 的日誌,並沒有發現 ReplaceConstParamToInteger 生效的地方。

如下是一個 CINN 框架裏的單測,可以輔助幫助理解上面這個函數的作用效果,樣例中 j=0,其中會把 0 這個常量值先創建一個 _const_0,然後做了變換?

TEST(TransIdentityExtentToContextId, basic) {
  isl_ctx* ctx = isl_ctx_alloc();
  isl::set set(ctx, "{ s[i,j=0,k] : 0<=i<12 and 12<k<32 }");
  auto new_set = TransIdentityExtentToContextId(set);
  LOG(INFO) << new_set;

  ASSERT_EQ(utils::GetStreamCnt(new_set),
            "[_const_0] -> { s[i, j, k] : _const_0 <= 1 and 0 <= i <= 11 and 0 <= j <= _const_0 and 13 <= k <= 31 }");
}

3.3 CastSimplify

此Pass 僅會對 constant 的Expr進行處理,比如 IntImm、FloatImm、UIntImm,作用是將其持有的value值按照 ir::Cast.type() 進行數值類型轉換,然後包裹一個Expr重新返回。

void CastSimplify(Expr* e) {
  Mutator mutator;
  mutator.Visit(e, e);
}

struct Mutator : ir::IRMutator<> {
  using ir::IRMutator<>::Visit;

  void Visit(const ir::Cast* op, Expr* expr) {
    auto* node = expr->As<ir::Cast>();

    Visit(&node->v(), &node->v());    // <<--- 類似 AST 的 generic_visit,深度優先遞歸處理 node->v() 節點

    if (op->type() == op->v().type()) {
      *expr = op->v();              // Caset 1: 如果 value 類型已經與dst_type 一致,則直接返回 node->v() 以替換當前節點。
      return;
    }

#define __CAST_TO_TYPE(type__)                                          \
  if (auto* i = op->v().As<ir::IntImm>()) {                             \
    *expr = Expr(static_cast<type__>(i->value));                        \
  } else if (auto* f = op->v().As<ir::FloatImm>()) {                    \
    *expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \    // <<----- 注意:這裏對Float類型進行了特殊處理,因爲存在轉Float16的場景
  } else if (auto* u = op->v().As<ir::UIntImm>()) {                     \
    *expr = Expr(static_cast<type__>(u->value));                        \
  } else {                                                              \
    CINN_NOT_IMPLEMENTED                                                \
  }

    if (op->v().is_constant()) {      // <----- 注意:此pass僅支持 ir::Cast->v()爲常量類型的場景
      if (op->type() == type_of<int8_t>()) {
        __CAST_TO_TYPE(int8_t)
      } else if (op->type() == type_of<int16_t>()) {
        __CAST_TO_TYPE(int16_t)
      } else if (op->type() == type_of<int32_t>()) {
        __CAST_TO_TYPE(int32_t)
      } else if (op->type() == type_of<int64_t>()) {
        __CAST_TO_TYPE(int64_t)
      } else if (op->type() == type_of<uint8_t>()) {
        __CAST_TO_TYPE(uint8_t)
      } else if (op->type() == type_of<uint16_t>()) {
        __CAST_TO_TYPE(uint16_t)
      } else if (op->type() == type_of<uint32_t>()) {
        __CAST_TO_TYPE(uint32_t)
      } else if (op->type() == type_of<uint64_t>()) {
        __CAST_TO_TYPE(uint64_t)
      } else if (op->type() == type_of<float>()) {
        __CAST_TO_TYPE(float)
      } else if (op->type() == type_of<double>()) {
        __CAST_TO_TYPE(double)
      } else if (op->type() == type_of<bool>()) {
        __CAST_TO_TYPE(bool)
      } else if (op->type() == type_of<uint32_t>()) {
        __CAST_TO_TYPE(uint32_t)
      } else if (op->type() == type_of<uint64_t>()) {
        __CAST_TO_TYPE(uint64_t)
      } else if (op->type() == type_of<float16>()) {
        // Cannot simplify!!! pass
        __CAST_TO_TYPE(float16)
      } else {
        CINN_NOT_IMPLEMENTED
      }
    }
#undef __CAST_TO_TYPE
  }
};

在上面流程代碼中,我們可以看出對於 FloatImm 類型的處理額外借助了 NormCastValue 這個函數,原因是對於 Float32 到 Float16 的轉寫,要考慮上溢、下溢、NanInf 的場景:

template <typename CastType, typename T>
CastType NormCastValue(T value) {
  if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
    // not support uint
    return static_cast<CastType>(value);
  }

  if (std::isinf(value)) {
    return std::numeric_limits<CastType>::infinity();
  } else if (std::isnan(value)) {
    return std::numeric_limits<CastType>::signaling_NaN();
  } else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
    return std::numeric_limits<CastType>::max();
  } else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
    return std::numeric_limits<CastType>::lowest();
  }
  return static_cast<CastType>(value);
}

3.4 Simplify

這個 Pass 包括的子邏輯比較多,單測文件 ir_simplify_test.cc 裏可以幫助理解效果:

void Simplify(Expr* expr) {
  optim::CastSimplify(expr);    // 先調用了 CastsSimplify,這個似乎會比較多餘?在遞歸調用時更會導致效率低下
  SimplifyRampMutator()(expr);
  SimplifyLoadMutator()(expr);
  SimplifyStoreMutator()(expr);
  SimplifyIfThenElseMutator()(expr);

  common::cas_intervals_t var_intervals;
  SimplifyButStoreLoadMutator mutator(var_intervals);   // 又額外來了一遍,這裏似乎也比較低效?
  mutator(expr);

  ReplaceFracWithDivMutator()(expr);  //  這裏將 ir::Frac 替換爲了 ir::Div,似乎也不是必要的,沒有看到哪裏構造了 ir::Frac
}

效果:

// case 1:
C = 1.  //shape = [100, 20]
B = C[i, 0] + 1 * 0 + 100 + 24.5

// 經過此 Pass 後:
B = C[i, 0] + 124.5

// case 2:
{
   serial for (i, 0, 100)
   {
     serial for (j, 0, 20)
     {
       B[i, j] = (X[i + 0, j + 0] + Y[i, j * 0] * 1.f + 0.f * X[i, j] + 25.f + 100.f - 0.f +
                 9.f * 10000.f * 1.f * 1.f * 0.f
)
    }
   }
}
// 經過此 Pass 後:
{
   serial for (i, 0, 100)
   {
     serial for (j, 0, 20)
     {
       B[i, j] = (125.000000f + (X[i, j] + y[i, 0]))
    }
   }
}

首先看 SimplifyRampMutator 的角色作用,從源碼上來看,只關心兩種節點:ir::Rampir::Add

  • 對於ir::Add節點,如果兩個操作數都是 ir::Ramp 類型,且其 lanes 屬性值是一樣的話,則會構建一個 ir::Ramp 節點來替換掉 ir::Add 節點
  • 對於ir::Ramp 節點,則遞歸對其 basestride 屬性調用 Simplify 函數。
struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  void Visit(const Ramp* op, Expr* expr) override {
    auto* node = expr->As<ir::Ramp>();

    CHECK(common::IsPureMath(node->base)) << node->base << "is not a pure math!";
    CHECK(common::IsPureMath(node->stride)) << node->stride << "is not a pure math!";
    ;
    Simplify(&node->base);
    Simplify(&node->stride);
  }
  // ramp + ramp
  void Visit(const Add* op, Expr* expr) override {
    auto* node  = expr->As<ir::Add>();
    Expr a      = node->a();
    Expr b      = node->b();
    auto a_ramp = a.As<ir::Ramp>();
    auto b_ramp = b.As<ir::Ramp>();

    if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) {
      Expr base_add   = common::AutoSimplify(a_ramp->base + b_ramp->base);     // 這裏會做CAS
      Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride);
      *expr           = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes);
    }
  }
};

我們這裏瞅一眼 ir::Ramp節點是什麼樣子的:

//! A linear ramp node.
struct Ramp : public ExprNode<Ramp> {
  Expr base, stride;
  int lanes;

  static Expr Make(Expr base, Expr stride, int lanes);

  void Verify() const override;

  static const IrNodeTy _node_type_ = IrNodeTy::Ramp;
};

接下來看第二個 SimplifyLoadMutator 的角色,簡單理解就是對 X[i+0, j+0] 以及 Y[i, j*0] 進行優化,得到 X[i, j]Y[i, 0]

struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
  void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }

  void Visit(const Load* expr, Expr* op) override {
    auto* node = op->As<Load>();
    for (auto& idx : node->indices) {
      if (common::IsPureMath(idx)) {
        PartialSimplify(&idx, var_intervals_);   // << 也是藉助了CAS了
      } else {
        SimplifyButStoreLoadMutator mutator(var_intervals_);  // 根據節點類型,分發調用 PartialSimplify 函數
        mutator(&idx);
      }
    }
  }

  void Visit(const For* op, Expr* expr) override {
    auto* min_i    = op->min.As<IntImm>();
    auto* extent_i = op->extent.As<IntImm>();
    if (min_i && extent_i && extent_i->value > min_i->value) {
      var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
    }

    auto* node = expr->As<For>();

    operator()(&node->body);
    operator()(&node->extent);

    if (min_i && extent_i) {
      var_intervals_.erase(op->loop_var->name);
    }
  }

  common::cas_intervals_t var_intervals_;
};

第三個 SimplifyStoreMutator 的代碼邏輯基本與 SimplifyLoadMutator 一致,這裏我們不再贅述。
第四個 SimplifyIfThenElseMutator ,這個也很好理解,對 condition 調用 CAS 邏輯:

struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }

  using ir::IRMutator<>::Visit;

  void Visit(const IfThenElse* op, Expr* expr) override {
    auto* node      = expr->As<ir::IfThenElse>();
    node->condition = common::AutoSimplify(node->condition);   // 核心點

    if (node->true_case.defined()) Visit(&node->true_case, &node->true_case);   // 訪問者模式分發
    if (node->false_case.defined()) Visit(&node->false_case, &node->false_case); // 訪問者模式分發
  }
};

第五個 SimplifyButStoreLoadMutator 本來在第二、三個子邏輯會局部觸發,這裏爲何單獨觸發了一遍?從函數實現了是對其他節點都遍歷一遍進行簡化處理,唯獨除了 StoreLoad 節點(因爲這兩個節點主要出現在 ir::For 節點中)
第六個 ReplaceFracWithDivMutator ,這個很有意思,是把所有的 ir::FracOp 替換爲 ir::Div ,這兩個不一樣麼?仔細看了下,在一些 CodeGen 模塊裏,如 codegen_llvm.cc 中,是沒有實現 ir::FracOp 裏的代碼生成邏輯的,只有 ir::Div 實現了。那爲什麼不直接把 ir::FracOp 節點刪除呢?

struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
  void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }

  void Visit(const FracOp* op, Expr* expr) override {
    auto* node = expr->As<ir::FracOp>();

    ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0));
    ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1));

    *expr = ir::Div::Make(node->operand(0), node->operand(1));
  }
};

llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); }

在CINN裏我檢索了 lang/pe 等模塊源碼,沒有看到在 IR 層面直接使用或構造 ir::Frac 節點的代碼,只有單測和不相關模塊:

3.5 MapExternCall 邏輯

從調用棧來看, 底層是 map_extern_call.cc 中具體的 MapExternCall 的實現

void MapExternCall(Expr *e, Target target) {
      Mutator m(target);
      m(e);
}

所有的工作都是交給基於 Ast 的 Mutator 來做的,原理:藉助「訪問者模式」僅識別和處理 ir::Call 對象:

    void Visit(const ir::Call *op, Expr *expr) override {
      auto *node = expr->As<ir::Call>();
      CHECK(node);
      OptimizeConstantPow(node);
      if (target.arch == Target::Arch::NVGPU) {
        DealWithNvGpuintrinsics(node, expr);
      } else {
        DealWithCpuintrinsics(node, expr);
      }
    }

我們比較關心 CUDA 上的變換,進一步看 DealWithNvGpuintrinsics 函數:

    void DealWithNvGpuintrinsics(ir::Call *node, Expr *expr) {
      auto arg_size = node->read_args.size();
      if (arg_size == 0UL) {
        // some node like __syncthreads hasn't arguments
        return;
      }
      const auto &dtype = node->read_args.front().type();
      const auto &name  = node->name;

      bool node_in_extern_fp32  = kExternFp32CallsGPU.count(name);
      bool node_in_extern_int32 = kExternInt32CallsGPU.count(name);
      if (!node_in_extern_fp32 && !node_in_extern_int32) {
        return;
      }

      std::string suffix;
      if (dtype.is_int() && node_in_extern_int32) {
        if (dtype.is_int(32)) {
          suffix = "_int32";
        } else if (dtype.is_int(64)) {
          suffix = "_int64";
        }
      } else if (dtype.is_float() && node_in_extern_fp32) {
        if (dtype.is_float(64)) {
          suffix = "_fp64";
        } else if (dtype.is_float(32)) {
          suffix = "_fp32";
        } else if (dtype.is_float(16)) {
          suffix = "_fp16";
        }
      }
      CHECK(!suffix.empty()) << name << " not support data type " << dtype;
      std::string extern_func = "cinn_nvgpu_" + name + suffix;    // <------ 主要是按照OpNode白名單+dtype拼接要替換的外部 API (其實也是在CINN層裏wrapper註冊了一層)

      *expr = lang::CallExtern(extern_func, node->read_args);     // 直接替換 ir::Call 對應的 Expr 對象
    }
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章