ONNX Runtime 源碼閱讀:模型結點串行執行順序的確定

概要

ONNX模型中的結構是一個有向圖,包含了很多節點。每個節點執行一個特定的操作,最終就得到了推理結果。ONNX模型格式標準並沒有要求所有節點按照拓撲順序來存儲,進行模型解析的時候也基本不要求解析出來的節點一定要符合拓撲順序排列。有些模型很簡單,從輸入到輸出,可能只有一條通路;有些模型很複雜,不僅輸入和輸出節點間存在多條通路,還有可能存在多個輸入和輸出節點。ONNX Runtime 是如何確定模型中各個節點執行的先後順序的呢?怎麼確保某個節點被執行之前,其所有先導節點都已經被執行?這就是今天需要解決的疑惑。ONNX Runtime 執行模型的方式主要有兩種:串行和並行,好像有點廢話了。通過初始化的時候傳遞個InferenceSession的構造函數的結構體SessionOptions中的ExecutionMode成員來控制。今天主要研究串行執行時節點執行順序。

涉及文件

onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
onnxruntime\onnxruntime\core\session\inference_session.cc
onnxruntime\onnxruntime\core\framework\sequential_executor.cc
onnxruntime\onnxruntime\core\framework\session_state_initializer.cc
onnxruntime\onnxruntime\core\graph\graph_viewer.cc
onnxruntime\onnxruntime\core\framework\session_state.cc
onnxruntime\onnxruntime\core\graph\graph.cc

正文

舉個栗子,有一個簡單的模型,如圖1所示:
圖1 一個簡單的模型
在這個簡單的模型裏面,一共有六個節點,從輸入到輸出有兩條通路。由於ONNX模型格式標準並沒有要求所有節點按照拓撲順序來存儲,因此模型再次加載到內存以後,節點的順序的排列完全是隨機的,有可能是1、3、2、4、6、5,也可能是其他的順序。因此,必須要先確定節點的拓撲結構並按照結構存儲起來,這樣才能在跑的時候知道那個是輸入,哪些節點必須先跑完。

代碼調用

在上一篇文章ONNX Runtime 源碼閱讀:模型推理過程概覽中我們說過,模型節點執行順序的確定是在InferenceSession實例化完畢後,在初始化階段完成的。

// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
  py::class_<InferenceSession>(m, "InferenceSession", R"pbdoc(This is the main class used to run a model.)pbdoc")
      .def(
          "load_model", [](InferenceSession* sess, std::vector<std::string>& provider_types) {
            OrtPybindThrowIfError(sess->Load());
            InitializeSession(sess, provider_types);
          },
          R"pbdoc(Load a model saved in ONNX format.)pbdoc")

從上面代碼中可以看到,初始化也分爲兩個階段:1)模型加載 2)InferenceSession實例初始化。
模型加載?模型不是在生成InferenceSession實例的時候已經加載到內存了麼?其實在InferenceSession實例化階段加載的模型知識編譯proto文件得到的類ModelProto的一個實例,直接使用還是不太方便,因此還需要對它進行進一步解析和封裝,OrtPybindThrowIfError(sess->Load());這句話主要做的就是這件事。
我們接着來看InitializeSession(sess, provider_types);:

// onnxruntime\onnxruntime\python\onnxruntime_pybind_state.cc
void InitializeSession(InferenceSession* sess, const std::vector<std::string>& provider_types) {
  if (provider_types.empty()) {
    // use default registration priority.
    RegisterExecutionProviders(sess, GetAllProviders());
  } else {
    RegisterExecutionProviders(sess, provider_types);
  }
  OrtPybindThrowIfError(sess->Initialize());
}

可以看到,InitializeSession(sess, provider_types)在註冊Provider後,最終調用到了onnxruntime\onnxruntime\core\session\inference_session.cc中類InferenceSessionInitiablize()方法。
Initiablize()方法體非常長,但是有兩行非常刺眼,session_initializer.CreatePlan; InitializeSubgraphSessions(graph, *session_state_),字面意思就是創建執行計劃,開個上帝視角執行順序這的是在這裏創建的。由於方法體很長,這就貼一部分重要的好了:

// onnxruntime\onnxruntime\core\session\inference_session.cc # InferenceSession::Initialize()
onnxruntime::Graph& graph = model_->MainGraph();

    // Collect the kernel registries from execution provider instances;
    // There are 2 kinds of kernel registries with priority from high to low as below,
    // 1. Custom execution provider type specific kernel registries.
    // 2. common execution provider type specific kernel registries.
    // The 1st and 2nd ones are shared across sessions.
    // The 1st ones should have already been registered via session-level API into KernelRegistryManager.
    //
    // Register 2nd registries into KernelRegistryManager.
    ORT_RETURN_IF_ERROR_SESSIONID_(kernel_registry_manager_.RegisterKernels(execution_providers_));

    SessionStateInitializer session_initializer(session_options_.enable_mem_pattern, model_location_, graph,
                                                *session_state_, execution_providers_, kernel_registry_manager_);

    // create SessionState for subgraphs as it's needed by the transformers
    ORT_RETURN_IF_ERROR_SESSIONID_(CreateSubgraphSessionState(graph, *session_state_));

    // apply any transformations to the main graph and any subgraphs
    ORT_RETURN_IF_ERROR_SESSIONID_(TransformGraph(graph, *graph_transformation_mgr_,
                                                  execution_providers_, kernel_registry_manager_,
                                                  insert_cast_transformer_,
                                                  *session_state_));

    // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
    ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());

    if (!session_options_.optimized_model_filepath.empty()) {
      // Serialize optimized ONNX model.
      ORT_RETURN_IF_ERROR_SESSIONID_(Model::Save(*model_, session_options_.optimized_model_filepath));
      if (session_options_.graph_optimization_level >= TransformerLevel::Level3) {
        LOGS(*session_logger_, WARNING) << "Serializing Optimized ONNX model with Graph Optimization"
                                           " level greater than ORT_ENABLE_EXTENDED. The generated"
                                           " model may contain hardware and execution provider specific"
                                           " optimizations, and should only be used in the same environment"
                                           " the model was optimized for.";
      }
    }

    ORT_RETURN_IF_ERROR_SESSIONID_(session_initializer.CreatePlan(nullptr, nullptr, session_options_.execution_mode));

    // handle any subgraphs
    ORT_RETURN_IF_ERROR_SESSIONID_(InitializeSubgraphSessions(graph, *session_state_));
    is_inited_ = true;

但是,開上帝視角之前,我們是怎麼知道這一段就是我們心心念唸的代碼?一方面,我們從模型推理時的方法調用中發現執行的時候發現直接取到了一個已經按照拓撲順序存儲的結點序列,

// onnxruntime\onnxruntime\core\framework\sequential_executor.cc#SequentialExecutor::Execute()
 const SequentialExecutionPlan& seq_exec_plan = *session_state.GetExecutionPlan();

和這裏的CreatePlan可謂遙相呼應,更重要的是,這個序列是從SessionState的實例中取出來的,有出肯定有入,我們需要緊盯着這個序列什麼時候被放進去的。恰好,在SessionStateInitializer的實例中SessionState和模型中取出的主圖同時出現,讓人不得不將焦點聚集在這;另一方面,這裏的代碼命名非常好,可謂顧名思義。不禁讓人感嘆,寫的出代碼是一回事兒,讓人容易看懂又是另一回事兒了,畢竟,良好的代碼不僅要高效還要易讀。
代碼的開始,先從模型中取到主圖,然後將主圖和一個SessionState的實例session_state_和其他參數一起傳遞給了SessionStateInitializer的構造函數,該構造函數僅僅是做了些簡單的賦值操作,然後就執行到了SessionStateInitializer的方法CreatePlan()

// onnxruntime\onnxruntime\core\framework\session_state_initializer.cc#SessionStateInitializer::CreatePlan()
common::Status SessionStateInitializer::CreatePlan(
    const Node* parent_node,
    const ConstPointerContainer<std::vector<NodeArg*>>* outer_scope_node_args,
    ExecutionMode execution_mode) {
  session_state_.SetGraph(graph_);
  const GraphViewer* graph_viewer = session_state_.GetGraphViewer();

  // populate the SessionState OrtValueNameIdxMap
  const auto& ort_value_name_idx_map = session_state_.GetOrtValueNameIdxMap();

  // ignore any outer scope args we don't know about. this can happen if a node contains multiple subgraphs.
  std::vector<const NodeArg*> valid_outer_scope_node_args;
  if (outer_scope_node_args) {
    std::for_each(outer_scope_node_args->cbegin(), outer_scope_node_args->cend(),
                  [&ort_value_name_idx_map, &valid_outer_scope_node_args](const NodeArg* node_arg) {
                    int idx;
                    if (ort_value_name_idx_map.GetIdx(node_arg->Name(), idx).IsOK()) {
                      valid_outer_scope_node_args.push_back(node_arg);
                    };
                  });
  }

  std::unique_ptr<SequentialExecutionPlan> exec_plan;
  SequentialPlannerContext context(execution_mode);
  ORT_RETURN_IF_ERROR(SequentialPlanner::CreatePlan(parent_node, *graph_viewer, valid_outer_scope_node_args,
                                                    execution_providers_, kernel_registry_manager_,
                                                    ort_value_name_idx_map, context, exec_plan));
  session_state_.SetExecutionPlan(std::move(exec_plan));

  const auto* exec_plan_ptr = session_state_.GetExecutionPlan();
  ORT_ENFORCE(exec_plan_ptr, "Execution plan was not found in SessionState. CreatePlan must be called first.");
// omitting other code 
// ....
}

按照我們之前的理論,我們繼續跟隨SequentialPlanner::CreatePlan()這個方法:

// onnxruntime\onnxruntime\core\framework\allocation_planner.cc#SequentialPlanner::CreatePlan()
Status SequentialPlanner::CreatePlan(const Node* parent_node, const onnxruntime::GraphViewer& graph_viewer,
                                     const std::vector<const NodeArg*>& outer_scope_node_args,
                                     const ExecutionProviders& providers, const KernelRegistryManager& kernel_registry,
                                     const OrtValueNameIdxMap& ort_value_name_idx_map,
                                     const ISequentialPlannerContext& context,
                                     std::unique_ptr<SequentialExecutionPlan>& plan) {
  // allocate/reset here so we know it's clean
  plan = onnxruntime::make_unique<SequentialExecutionPlan>();

  PlannerImpl planner(parent_node, graph_viewer, outer_scope_node_args, providers, kernel_registry,
                      ort_value_name_idx_map, context, *plan);

  return planner.CreatePlan();
}

這個方法生成一個PlannerImpl實例後,接着套娃:

// onnxruntime\onnxruntime\core\framework\allocation_planner.cc#PlannerImpl::CreatePlan()
Status PlannerImpl::CreatePlan() {
  auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder();

  int num_ml_values = ort_value_name_idx_map_.MaxIdx() + 1;

  Initialize(p_graph_nodes.size(), static_cast<size_t>(num_ml_values));

  // Determine execution order: we use the default topological sort order for now. We can later
  // explore more efficient orderings (from a memory usage perspective).
  for (auto n : p_graph_nodes) {
    plan_.execution_plan.emplace_back(n);
  }
// omitting some code
// ......
}

看到auto& p_graph_nodes = graph_viewer_.GetNodesInTopologicalOrder();這句,有種守的雲開見月明的感覺。可惜,進去一看,裏面已經是一個進行了拓撲排序的列表。沒道理啊?怎麼可能在我們眼皮底下偷摸的把拓撲關係做了?難道我們上帝視角也出了問題?答案當然不是,只不過是因爲保存網絡節點拓撲關係的SessionState對象非常勤奮,在它獲取到模型結構圖的時候,就把節點按拓撲排序排了,根本不管你deadline是什麼時候。
我們回到上面SessionStateInitializer::CreatePlan()這個方法,方法體第一句session_state_.SetGraph(graph_);把模型結構圖給了SessionState,而SessionState馬上又把模型結構圖給了他的小弟GraphViewer,進入GraphViewer我們終於發現,尋他千百度的拓撲排序就在這裏。從字面上看,graph.ReverseDFSFrom()用的拓撲排序算法就是深度優先搜索算法。
進入SessionState.SetGraph()

// onnxruntime\onnxruntime\core\framework\session_state.cc
Status SessionState::SetGraph(const Graph& graph) {
  graph_viewer_ = onnxruntime::make_unique<onnxruntime::GraphViewer>(graph);
  auto& logger = Logger();
  // use graph_viewer_ to initialize ort_value_name_idx_map_
  LOGS(logger, INFO) << "SaveMLValueNameIndexMapping";
  int idx = 0;
// omitted some code
// ...
}

// onnxruntime\onnxruntime\core\graph\graph_viewer.cc
GraphViewer::GraphViewer(const Graph& graph) {
  graph_ = &graph;
  std::vector<const Node*> leaf_nodes;
  for (auto& node : graph_->Nodes()) {
    if (node.OutputNodesBegin() == node.OutputNodesEnd()) {
      // This is a leaf node (without any output node).
      leaf_nodes.push_back(&node);
    }
  }
  graph.ReverseDFSFrom(
      leaf_nodes,
      nullptr,
      [this](const Node* n) {
        nodes_in_topological_order_.push_back(n->Index());
      },
      NodeCompare());

  for (auto& node : graph_->Nodes()) {
    if (node.InputEdgesBegin() == node.InputEdgesEnd()) {
      root_nodes_.push_back(node.Index());
    }
  }
}

算法

下面讓我們來看看具體的算法實現的吧:

// onnxruntime\onnxruntime\core\graph\graph.cc#Graph::ReverseDFSFrom()
void Graph::ReverseDFSFrom(const std::vector<const Node*>& from,
                           const std::function<void(const Node*)>& enter,
                           const std::function<void(const Node*)>& leave,
                           const std::function<bool(const Node*, const Node*)>& comp) const {
  using WorkEntry = std::pair<const Node*, bool>;  // bool represents leave or not
  std::vector<WorkEntry> stack(from.size());
  for (size_t i = 0; i < from.size(); i++) {
    stack[i] = WorkEntry(from[i], false);
  }

  std::vector<bool> visited(MaxNodeIndex(), false);
  while (!stack.empty()) {
    const WorkEntry last_entry = stack.back();
    stack.pop_back();
    const Node& n = *last_entry.first;
    if (last_entry.second) {
      // leave node
      leave(&n);
      continue;
    }

    if (visited[n.Index()]) continue;

    visited[n.Index()] = true;

    if (enter) enter(&n);

    if (leave) stack.emplace_back(&n, true);

    if (comp) {
      std::vector<const Node*> sorted_nodes;
      for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
        sorted_nodes.push_back(&(*iter));
      }
      std::sort(sorted_nodes.begin(), sorted_nodes.end(), comp);
      for (const auto* in : sorted_nodes) {
        const NodeIndex idx = in->Index();
        if (!visited[idx]) {
          stack.emplace_back(in, false);
        }
      }
    } else {
      for (auto iter = n.InputNodesBegin(); iter != n.InputNodesEnd(); ++iter) {
        const NodeIndex idx = (*iter).Index();
        if (!visited[idx]) {
          stack.emplace_back(GetNode(idx), false);
        }
      }
    }
  }
}

算法中通過一個站存儲節點,每個節點有一個標誌位表示該節點是否可以被取走放入拓撲隊列,我們可以稱之爲可入隊列標誌,另外再用一個列表表示某個節點是否已經被訪問過,我們可以稱之爲已訪問標誌。
與一般DFS略有區別的地方,就是它不需要先找到根節點,給定任意一個節點,它最終都能得到一個合理的拓撲列表。它是怎麼實現的呢?很簡單,直接在存儲節點的棧上進行操作:

  1. 開始的時候節點隨機入棧,可如隊列標誌和已訪問標誌都清除;
  2. 棧頂元素出棧,如果:
    • 可入隊標誌位被設置,則該元素進入拓撲隊列,重新開始第二步;
    • 如果該節點已訪問標誌位被設置,說明該節點已經進入拓撲隊列,重新開始第二步;
    • 可入隊標誌位未被設置,設置該節點的已訪問標誌位和可入棧標誌位,重新入棧;並找到該節點所有輸入節點,按一定規則排序後,清空輸入節點的可入棧標誌位,依次入棧。
  3. 重複第二步直到棧中所有元素都已經彈出並放入拓撲隊列中。
    例如我們最開頭的一個簡單模型,假設入棧後其排列爲:1,4,2,6,5,3。其算法過程如圖2圖3所示,其中,黃色表示可入隊標誌被設置,粉紅色表示已訪問標誌被設置,淡藍色表示拓撲隊列裏的內容:
    圖2 DFS1

圖2 DFS2
最終,我們得到了一個拓撲隊列中內容爲:1,2,3,4 ,5 ,6。這個隊列確保了每個節點被執行的時候,它的輸入節點肯定已經被執行。例如,當節點5執行的時候,他的輸入節點3和4已經被執行了。

子圖

如果模型中還有子圖,子圖的處理過程也和主圖類似,這裏就不多說了。

總結

InferenceSession就好似一個統帥,SessionState替他保存推理需要的信息,IExecutor幫他進行推理工作。

算了,就不強行總結了。


本文首發於個人微信公衆號TensorBoy。如果你覺得內容還不錯,歡迎分享並關注我的微信公衆號TensorBoy,掃描下方二維碼獲取更多精彩原創內容!
公衆號二維碼

發佈了26 篇原創文章 · 獲贊 1 · 訪問量 1萬+
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章