The Subgraph Compilation Cache Mechanism in CINN

Notes from studying the CINN open-source framework, recorded in question-and-answer form.

Q: Where is the entry point for subgraph compilation in CINN?

  for (const auto& node_vec : clusters) {  // <------- iterate over each subgraph (cluster) one by one
    // Classify var node to inputs, outputs, and internals.
    GraphNodeSet cluster_set(node_vec.begin(), node_vec.end());

    GraphNodeSet cluster_inputs, cluster_outputs, cluster_internals;
    AnalyseClusterVariables(cluster_set,
                            deny_var_set,
                            &cluster_inputs,
                            &cluster_outputs,
                            &cluster_internals,
                            is_inference_stage,
                            all_skip_gc_vars);

    auto subgraph = CreateNewSubGraph(
        cluster_set, cluster_internals, cluster_inputs, cluster_outputs);

    if (graph->Has(kSkipGcVarNames)) {
      auto& sub_skip_gc_vars =
          subgraph->GetOrInit<std::unordered_set<std::string>>(kSkipGcVarNames);
      sub_skip_gc_vars = all_skip_gc_vars;
    }
    auto compilation_key = cinn_compiler->AddGraph(std::move(subgraph));  // <------ register the subgraph (its shapes may still contain -1, i.e. dynamic dims)
    VLOG(4) << "Compilation Key:\n"
            << cinn_compiler->ReadableKey(compilation_key);

    // Replace the found cluster to a new cinn op node
    ReplaceSubGraphWithCinnOpNode(cluster_set,     // <------- replace each subgraph with a cinn_launch op node; its compiled result is cached later under compilation_key
                                  cluster_inputs,
                                  cluster_outputs,
                                  cluster_internals,
                                  compilation_key,
                                  graph);
  }

Q: What does AddGraph do?

int64_t CinnCompiler::AddGraph(std::unique_ptr<Graph> graph) {
  int64_t graph_key = std::hash<Graph *>()((&(*graph)));
  graphs_[graph_key] = std::move(graph);  // <------ the compile-time static graph stored here may still contain -1 dims
  return graph_key;
}
// After AddGraph, the matched subgraph in the original Graph is replaced by a [cinn_launch] op.

Q: Can subgraph compilation results be reused across different Programs in CINN? Does the hash key couple with var_name?

size_t CinnCacheKeyByStructure::HashGraph(const ir::Graph& graph) {
  // sort grad node by name and id.
  auto compare = [](ir::Node* n1, ir::Node* n2) {
    return (n1->Name() == n2->Name()) ? (n1->id() < n2->id())
                                      : (n1->Name() < n2->Name());
  };

  // graph.Nodes() return unordered_set, here using set to avoid the same graph
  // may return different result
  std::set<ir::Node*, bool (*)(ir::Node*, ir::Node*)> node_set(compare),
      output_set(compare);
  node_set.insert(graph.Nodes().begin(), graph.Nodes().end());

  std::string hash_str;
  for (ir::Node* n : node_set) {
    hash_str.append(n->Name());

    output_set.clear();
    output_set.insert(n->outputs.begin(), n->outputs.end());
    for (auto* out : output_set) {
      hash_str.append(out->Name()); // <------ couples the hash with var_name in the graph
    }
  }

  VLOG(1) << "The hash graph:\n" << hash_str;

  size_t hash_val = std::hash<std::string>()(hash_str);
  VLOG(4) << "The graph's hash value by graph structure is: " << hash_val;
  return hash_val;
}

A concrete hash_key example from BERT: cumsumcumsum_0.tmp_0cumsum_0.tmp_0elementwise_subelementwise_subtmp_0feedinput_idsfetchfill_any_likefull_like_0.tmp_0full_like_0.tmp_0cumsumelementwise_subinput_idsfill_any_liketmp_0fetch
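So the answer to the reuse question seems to be: the structure-based key is only reusable across Programs when the variable names also match, because the hash string embeds var_name. Below is a minimal standalone sketch (a toy node type, not paddle::ir::Node, and the sorting step is omitted) that mirrors the concatenation scheme of HashGraph and shows that two graphs differing only in a variable name produce different cache keys.

#include <functional>
#include <iostream>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::vector<const ToyNode*> outputs;
};

// Mirrors HashGraph's scheme: append each node's name followed by the names of its
// output nodes, then hash the accumulated string.
size_t ToyHashGraph(const std::vector<const ToyNode*>& nodes) {
  std::string hash_str;
  for (const ToyNode* n : nodes) {
    hash_str.append(n->name);
    for (const ToyNode* out : n->outputs) hash_str.append(out->name);
  }
  return std::hash<std::string>()(hash_str);
}

int main() {
  ToyNode var_a{"tmp_0", {}};
  ToyNode op_a{"elementwise_add", {&var_a}};
  ToyNode var_b{"tmp_1", {}};                       // same structure, different var name
  ToyNode op_b{"elementwise_add", {&var_b}};
  std::cout << ToyHashGraph({&op_a, &var_a}) << "\n"
            << ToyHashGraph({&op_b, &var_b}) << "\n";  // two different values: no cache reuse
  return 0;
}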

size_t CinnCacheKey::Hash::operator()(const CinnCacheKey& key) const {
  std::ostringstream has_str;

  for (const auto& name_shape : key.input_shapes_) {  // <------- input shape info
    has_str << name_shape.first;
    has_str << std::hash<phi::DDim>()(name_shape.second);
  }

  has_str << key.graph_hash_val_;   // graph structure info
  has_str << key.arch_str_;        // target info
  return std::hash<std::string>()(has_str.str());
}

Q: When does the main framework trigger "compilation"?

template <typename DeviceContext, typename T>
class CinnLaunchOpKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& compilation_key = ctx.template Attr<int64_t>(kCompilationKey);
    // triggered with the concrete shapes of the input tensors; dims that were -1 are resolved at this point
    const auto& cinn_compiled_object = CinnCompiler::GetInstance()->Compile(
        compilation_key, inputs_name2tensor, target, stream);
  }

Q: How does CINN eliminate dynamic shapes?

void CinnGraphSymbolization::RunOp(const CinnOpDesc& op_desc,
                                   const OpMapperContext& ctx) const {
  const auto& op_type = op_desc.Type();
  auto* kernel = ::cinn::frontend::OpMapperRegistry::Global()->Find(op_type);
  VLOG(4) << "Running Op " << op_type;
  kernel->Run(op_desc, ctx);  // dispatched by NetBuilder->build() to the concrete API, which calls infer_shape
}
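To make the "resolving -1" idea concrete, here is a toy helper (my own sketch, not a CINN API): once the feed tensors arrive with concrete shapes at kernel launch time, any dim recorded as -1 during the pass can be replaced by the runtime extent before infer_shape propagates shapes forward.

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Toy sketch: substitute runtime extents for dims that were -1 at compile time.
std::vector<int64_t> ResolveDynamicDims(const std::vector<int64_t>& compile_time_shape,
                                        const std::vector<int64_t>& runtime_shape) {
  assert(compile_time_shape.size() == runtime_shape.size());
  std::vector<int64_t> resolved = compile_time_shape;
  for (std::size_t i = 0; i < resolved.size(); ++i) {
    if (resolved[i] == -1) resolved[i] = runtime_shape[i];  // take the fed extent
  }
  return resolved;
}

// e.g. ResolveDynamicDims({-1, 128}, {16, 128}) returns {16, 128}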

Q: Where inside CINN is the caching mechanism triggered?

const CinnCompiledObject &CinnCompiler::Compile(
    const Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    void *stream) {
  VLOG(4) << "-- The graph to be compiled is:\n" << VizGraph(graph);
  CinnCacheKeyByAddress cur_key_by_address(
      graph, input_tensors, target.arch_str());   // first, look up by graph pointer + input shapes + target
  CinnCacheKeyByStructure cur_key_by_struct;      // on a miss, fall back to graph structure + input shapes + target

  if (!cache_by_address_.count(cur_key_by_address)) {
    // generate the structure cache key
    cur_key_by_struct.SetKey(graph, input_tensors, target.arch_str());
    if (!cache_by_struct_.count(cur_key_by_struct)) {
      std::int64_t compiled_num = real_compiled_num_.fetch_add(1);
      auto compiled_res =
          CompileGraph(graph, input_tensors, target, compiled_num, stream); // the core work is delegated to CompileGraph
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_struct_
      if (!cache_by_struct_.count(cur_key_by_struct)) {
        cache_by_struct_[cur_key_by_struct] = compiled_num;
        index2cache_.emplace(compiled_num, std::move(compiled_res));
      }
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    } else {
      std::unique_lock<std::mutex> guard(lock_);
      // double check cache_by_address_
      if (!cache_by_address_.count(cur_key_by_address)) {
        cache_by_address_[cur_key_by_address] =
            cache_by_struct_.at(cur_key_by_struct);
      }
    }
  }
  return *index2cache_.at(cache_by_address_.at(cur_key_by_address));
}

Q: What does CompileGraph do at its core, and is there any further caching inside it?

std::unique_ptr<CinnCompiledObject> CinnCompiler::CompileGraph(
    const ir::Graph &graph,
    const std::map<std::string, const phi::DenseTensor *> &input_tensors,
    const Target &target,
    std::int64_t compiled_num,
    void *stream) const {
  CinnGraphSymbolization symbol{compiled_num, graph, target, input_tensors};
  auto frontend_program = symbol();
  auto fetch_ids = symbol.GetFetchIds();
  VLOG(4) << "All fetch var ids in CINN: "
          << string::join_strings(fetch_ids, ',');

  auto cinn_graph = Optimize(&frontend_program, fetch_ids, target); // done only once for a given ir::Graph
  VLOG(4) << "-- The " << compiled_num << "-th compilation ("
          << target.arch_str() << "), and its related graph:\n"
          << cinn_graph->Visualize();

  auto scope = BuildScope(target, cinn_graph);
  auto graph_compiler =
      std::make_unique<GraphCompiler>(target, scope, cinn_graph); // GraphCompiler does one-shot work, but it stays alive because compiled_obj holds it
  GraphCompiler::CompileOptions options;
  options.with_instantiate_variables = false;
  if (!FLAGS_enable_pe_launch_cinn) {
    options.with_buffer_handle_instruction_inserted = true;
  }
  std::unique_ptr<AutoTuner> auto_tuner;
  if (FLAGS_enable_cinn_auto_tune) {
    VLOG(4) << "Compile with auto-tune";
    auto_tuner = std::make_unique<AutoTuner>(target, cinn_graph.get());
    auto_tuner->Initialize(AutoTuner::Config(), graph_compiler.get());
    ::cinn::auto_schedule::TuningOptions tuning_options;
    tuning_options.num_measure_trials = 0;
    auto tuning_result = auto_tuner->Tune(tuning_options);
    options.Apply(tuning_result);
  }
  auto compiled_res =
      graph_compiler->Build(options, std::move(fetch_ids), stream);
  auto compiled_obj = std::make_unique<CinnCompiledObject>();
  *compiled_obj = {std::move(graph_compiler),
                   std::move(auto_tuner),
                   std::move(compiled_res.runtime_program),
                   scope,
                   symbol.var_model_to_program_map()};  // <------ corresponds to paddle2cinn_varmap
  compiled_obj->cached_index = compiled_num;
  compiled_obj->launch_context =
      std::make_unique<operators::details::CinnLaunchContext>(graph,
                                                              *compiled_obj);
  CheckCompiledValid(graph, input_tensors, *compiled_obj);
  return compiled_obj;
}

Q: GraphCompiler hands the compile-and-link work over to backends::Compiler; does this backend Compiler have a compilation cache of its own?

A: The host module appears to hold mainly function declarations and call logic, while the device module holds the function definitions.

CodeGen emits source like this into a file, which is then handed to the compilation engine; when there are multiple functions, they are placed in the same file and compiled and linked together.
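A hand-written C++ illustration of that host/device split (the names fn_elementwise_add_0 and fn_elementwise_add_0_impl are assumptions, not actual CINN CodeGen output): the device side carries the function definition, the host side only the declaration plus the call logic.

// ---- "device module" part: definition of the fused subgraph function ----
extern "C" void fn_elementwise_add_0_impl(const float* x, const float* y,
                                          float* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = x[i] + y[i];  // fused computation of the subgraph
}

// ---- "host module" part: declaration + call logic only ----
extern "C" void fn_elementwise_add_0_impl(const float*, const float*, float*, int);

extern "C" void fn_elementwise_add_0(const float* x, const float* y, float* out, int n) {
  // in the real flow this wrapper would pack arguments and launch the device kernel
  fn_elementwise_add_0_impl(x, y, out, n);
}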

From the code, my understanding is that each CINN subgraph corresponds to one GraphCompiler, which compiles it into a function whose name follows the pattern fn_xxx_yyy_zzz:

  • it describes the overall computation logic of all ops in the subgraph
  • after operator Decompose, optimization passes, and so on, it may be split into several sub-functions
  • the sub-functions go into one host file and one CUDA file, which are compiled and linked together into a single function pointer
  • to be confirmed: does this mean there is no cache at the lower_func level?

The figure above is taken from where engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols)); is constructed.

Appendix: how compilation is implemented in TVM

Q: What plays a role similar to GraphCompiler in TVM?

A: After a rough pass over the TVM source, it appears to be TECompilerImpl, which inherits from TECompilerNode and exposes the following core interfaces:

  // Lower the function.
  CachedFunc Lower(const CCacheKey& key) {
    return LowerInternal(key, global_var_supply_)->cached_func;
  }
  
  // For now, build one module per function.
  PackedFunc JIT(const CCacheKey& key) final {
    CCacheValue value = LowerInternal(key, GlobalVarSupply(NameSupply("")));
    if (value->packed_func != nullptr) {
      return value->packed_func;
    }
    auto m = build(value->cached_func->funcs, key->target, Target(nullptr));   // <------ m here is a runtime::Module object
    value->packed_func = m.GetFunction(value->cached_func->prim_fn_var->name_hint);
    return value->packed_func;
  }

  CachedFunc LowerShapeFunc(const CCacheKey& key) final {
    return LowerShapeFuncInternal(key)->cached_func;
  }

Notably, TECompilerImpl holds two cache-related data structures:

  /*! \brief internal compiler cache */
  std::unordered_map<CCacheKey, CCacheValue> cache_;
  
  /*! \brief internal compiler cache for shape funcs */
  std::unordered_map<CCacheKey, CCacheValue> shape_func_cache_;
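The lookup-or-lower pattern these caches back is roughly the following (a paraphrased, self-contained sketch with a toy string key, not the actual LowerInternal code): a hit reuses the previously lowered function, a miss lowers once and memoizes the result.

#include <string>
#include <unordered_map>
#include <utility>

struct ToyCachedFunc { std::string lowered_name; };

class ToyTECompiler {
 public:
  const ToyCachedFunc& Lower(const std::string& key) {
    auto it = cache_.find(key);
    if (it != cache_.end()) return it->second;       // cache hit: reuse the lowering
    ToyCachedFunc value{"fused_" + key};             // stand-in for the real lowering work
    return cache_.emplace(key, std::move(value)).first->second;
  }

 private:
  std::unordered_map<std::string, ToyCachedFunc> cache_;  // mirrors TECompilerImpl::cache_
};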

Q: What is the build() method above used for? Does it play the same role as Paddle's backends::Compiler?

A: I believe it does, and the runtime::Module object it returns seems comparable to RuntimeProgram in Paddle's CINN.

// Build for heterogeneous execution when targets are specified as
// objects.  This wrapper around the internal API is maintained for
// backwards compatibility.
runtime::Module build(const Map<Target, IRModule>& input, const Target& target_host) {
  return TIRToRuntime(input, target_host);
}

runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
                             const Target& target_host_arg) {   // <------- the underlying implementation
  std::vector<runtime::Module> device_modules;
  Map<Target, IRModule> inputs = inputs_arg;
  Target target_host = target_host_arg;

  // Fetch previous defined target host in targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  if (!target_host.defined()) {
    for (const auto& it : inputs) {
      if (it.first->GetTargetDeviceType() == kDLCPU ||
          it.first->GetTargetDeviceType() == kDLMicroDev) {
        target_host = it.first;
        break;
      }
    }
  }

  if (!target_host.defined()) {
    target_host = DefaultTargetHost(target_host);
  }

  // Update target host for all targets
  CheckAndUpdateHostConsistency(&inputs, &target_host);

  // Take the attrs from the first module so the eventual modules have them.
  // Ideally this would just be one unified module all the way through;
  IRModule first_module = (*inputs.begin()).second;
  IRModule mhost_all = IRModule(Map<GlobalVar, BaseFunc>(), {}, {}, {}, first_module->attrs);

  ICHECK(mhost_all.defined()) << "The host module must be defined";

  for (const auto& it : inputs) {
    if (it.second.defined()) {
      const Target& target = it.first;
      const IRModule& ir_module = it.second;
      auto pair = SplitMixedModule(ir_module, target, target_host);
      auto& host_mod = pair.first;
      auto& device_mod = pair.second;

      ICHECK(host_mod.defined()) << "The split host module must be defined";

      ICHECK(mhost_all.defined()) << "The host module must be defined";

      // We don't want library modules going back into host codegen
      // unless they're supposed to. Here if we overrode the target host
      // to allow lowering previously we check that it's meant to be placed
      // back into the host Module.
      bool overrides_host_target =
          target->GetTargetDeviceType() == target_host->GetTargetDeviceType();
      bool non_host_target_kind = target->kind != target_host->kind;
      if (overrides_host_target && non_host_target_kind) {
        device_modules.push_back(codegen::Build(host_mod, it.first));
      } else {
        mhost_all->Update(host_mod);
      }

      if (device_mod->functions.size() != 0) {
        device_modules.push_back(codegen::Build(device_mod, it.first));
      }
    }
  }

  runtime::Module mhost = codegen::Build(mhost_all, target_host);   // <----- the actual compilation happens here?
  for (const auto& it : device_modules) {
    if (it.operator->()) {
      mhost.Import(it);
    }
  }

  return mhost;
}

runtime::Module Build(IRModule mod, Target target) {
  if (transform::PassContext::Current()
          ->GetConfig<Bool>("tir.disable_assert", Bool(false))
          .value()) {
    mod = tir::transform::SkipAssert()(mod);
  }

  auto target_attr_map = tvm::TargetKind::GetAttrMap<FTVMTIRToRuntime>("TIRToRuntime");
  if (target_attr_map.count(target->kind)) {
    return target_attr_map[target->kind](mod, target);
  }

  // the build function.
  std::string build_f_name = "target.build." + target->kind->name;
  const PackedFunc* bf = runtime::Registry::Get(build_f_name);
  ICHECK(bf != nullptr) << build_f_name << " is not enabled";
  return (*bf)(mod, target);
}
TVM_REGISTER_GLOBAL("target.build.cuda").set_body_typed(BuildCUDA);

runtime::Module BuildCUDA(IRModule mod, Target target) {
  using tvm::runtime::Registry;
  bool output_ssa = false;
  CodeGenCUDA cg;
  cg.Init(output_ssa);

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto calling_conv = f->GetAttr<Integer>(tvm::attr::kCallingConv);
    ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch)
        << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch";
    cg.AddFunction(f);
  }

  std::string code = cg.Finish();

  if (const auto* f = Registry::Get("tvm_callback_cuda_postproc")) {
    code = (*f)(code).operator std::string();
  }
  std::string fmt = "ptx";
  std::string ptx;
  const auto* f_enter = Registry::Get("target.TargetEnterScope");
  (*f_enter)(target);
  if (const auto* f = Registry::Get("tvm_callback_cuda_compile")) {
    ptx = (*f)(code).operator std::string();
    // Dirty matching to check PTX vs cubin.
    // TODO(tqchen) more reliable checks
    if (ptx[0] != '/') fmt = "cubin";
  } else {
    ptx = NVRTCCompile(code, cg.need_include_path());
  }
  const auto* f_exit = Registry::Get("target.TargetExitScope");
  (*f_exit)(target);
  return CUDAModuleCreate(ptx, fmt, ExtractFuncInfo(mod), code);
}

Q: Where does TVM invoke execution from?

A: I came across a GraphExecutor data structure that appears to play this role.
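From the user side, the C++ deployment pattern around GraphExecutor looks roughly as follows (a sketch based on TVM's C++ deploy example; the library path "deploy_lib.so", the input name "x", and the tensor shapes are assumptions): a library exported via relay.build + export_library is loaded, the "default" function of the graph-executor factory creates the executor for a device, and set_input / run / get_output drive execution.

#include <dlpack/dlpack.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>

void RunWithGraphExecutor() {
  DLDevice dev{kDLCPU, 0};
  // the .so is assumed to have been produced by relay.build(...) + export_library(...)
  tvm::runtime::Module factory = tvm::runtime::Module::LoadFromFile("deploy_lib.so");
  tvm::runtime::Module gmod = factory.GetFunction("default")(dev);  // graph executor instance
  tvm::runtime::PackedFunc set_input = gmod.GetFunction("set_input");
  tvm::runtime::PackedFunc run = gmod.GetFunction("run");
  tvm::runtime::PackedFunc get_output = gmod.GetFunction("get_output");

  auto x = tvm::runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, dev);
  auto y = tvm::runtime::NDArray::Empty({2, 2}, DLDataType{kDLFloat, 32, 1}, dev);
  set_input("x", x);   // feed the input named "x"
  run();               // execute the compiled graph
  get_output(0, y);    // fetch the first output
}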
