Tensorflow源碼解析6 -- TensorFlow本地運行時

1 概述

TensorFlow後端分爲四層,運行時層、計算層、通信層、設備層。運行時作爲第一層,實現了session管理、graph管理等很多重要的邏輯,是十分關鍵的一層。根據任務分佈的不同,運行時又分爲本地運行時和分佈式運行時。本地運行時,所有任務運行於本地同一進程內。而分佈式運行時,則允許任務運行在不同機器上。

Tensorflow的運行,通過session搭建了前後端溝通的橋樑,前端幾乎所有操作都是通過session進行。session的生命週期由創建、運行、關閉、銷燬組成,前文已經詳細講述過。可以將session看做TensorFlow運行的載體。而TensorFlow運行的核心對象,則是計算圖Graph。它由計算算子和計算數據兩部分構成,可以完整描述整個計算內容。Graph的生命週期包括構建和傳遞、剪枝、分裂、執行等步驟,本文會詳細講解。理解TensorFlow的運行時,重點就是理解會話session和計算圖Graph。

本地運行時,client master和worker都在本地機器的同一進程內,均通過DirectSession類來描述。由於在同一進程內,三者間可以共享內存,通過DirectSession的相關函數實現調用。

client前端直接面向用戶,負責session的創建,計算圖Graph的構造。並通過session.run()將Graph序列化後傳遞給master。master收到後,先反序列化得到Graph,然後根據反向依賴關係,得到幾個最小依賴子圖,這一步稱爲剪枝。之後master根據可運行的設備情況,將子圖分裂到不同設備上,從而可以併發執行,這一步稱爲分裂。最後,由每個設備上的worker並行執行分裂後的子圖,得到計算結果後返回。

2 Graph構建和傳遞

session.run()開啓了後端Graph的構建和傳遞。在前文session生命週期的講解中,session.run()時會先調用_extend_graph()將要運行的Operation添加到Graph中,然後再啓動運行過程。extend_graph()會先將graph序列化,得到graph_def,然後調用後端的TF_ExtendGraph()方法。下面我們從c_api.cc中的TF_ExtendGraph()看起。

// 增加節點到graph中,proto爲序列化後的graph
void TF_ExtendGraph(TF_DeprecatedSession* s, const void* proto,
                    size_t proto_len, TF_Status* status) {
  GraphDef g;
  // 先將proto轉換爲GrapDef。graphDef是圖的序列化表示,反序列化在後面。
  if (!tensorflow::ParseProtoUnlimited(&g, proto, proto_len)) {
    status->status = InvalidArgument("Invalid GraphDef");
    return;
  }

  // 再調用session的extend方法。根據創建的不同session類型,多態調用不同方法。
  status->status = s->session->Extend(g);
}

後端系統根據生成的Session類型,多態的調用Extend方法。如果是本地session,則調用DirectSession的Extend()方法。下面看DirectSession的Extend()方法。

Status DirectSession::Extend(const GraphDef& graph) {
  // 保證線程安全,然後調用ExtendLocked()
  mutex_lock l(graph_def_lock_);
  return ExtendLocked(graph);
}

// 主要任務就是創建GraphExecutionState對象。
Status DirectSession::ExtendLocked(const GraphDef& graph) {
  bool already_initialized;

  if (already_initialized) {
    TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph.library()));

    // 創建GraphExecutionState
    std::unique_ptr<GraphExecutionState> state;
    TF_RETURN_IF_ERROR(execution_state_->Extend(graph, &state));
    execution_state_.swap(state);
  }
  return Status::OK();
}

最終創建了GraphExecutionState對象。它主要工作有

  1. 負責將GraphDef反序列化爲graph,從而構造出graph。在初始化方法InitBaseGraph()中
  2. 執行部分op編排工作,在初始化方法InitBaseGraph()中
Status GraphExecutionState::InitBaseGraph(const BuildGraphOptions& options) {
  const GraphDef* graph_def = &original_graph_def_;

  // graphDef反序列化得到graph
  std::unique_ptr<Graph> new_graph(new Graph(OpRegistry::Global()));
  GraphConstructorOptions opts;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, *graph_def, new_graph.get()));

  // 恢復有狀態的節點
  RestoreStatefulNodes(new_graph.get());

  // 構造優化器的選項 optimization_options
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &new_graph;
  optimization_options.flib_def = flib_def_.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::PRE_PLACEMENT, optimization_options));

  // plaer執行op編排
  Placer placer(new_graph.get(), device_set_, session_options_);
  TF_RETURN_IF_ERROR(placer.Run());

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_PLACEMENT, optimization_options));

  // 報春狀態節點
  SaveStatefulNodes(new_graph.get());
  graph_ = new_graph.release();
  return Status::OK();
}

構造Graph:反序列化GraphDef爲Graph

由於client傳遞給master的是序列化後的計算圖,所以master需要先反序列化。通過ConvertGraphDefToGraph實現。代碼在graph_constructor.cc中,如下

Status ConvertGraphDefToGraph(const GraphConstructorOptions& opts,
                              const GraphDef& gdef, Graph* g) {
  ShapeRefiner refiner(gdef.versions().producer(), g->op_registry());
  return GraphConstructor::Construct(
      opts, gdef.node(), &gdef.versions(), &gdef.library(), g, &refiner,
      /*return_tensors=*/nullptr, /*return_nodes=*/nullptr,
      /*missing_unused_input_map_keys=*/nullptr);
}

編排OP

Operation編排的目的是,將op以最高效的方式,放在合適的硬件設備上,從而最大限度的發揮硬件能力。通過Placer的run()方法進行,算法很複雜,在placer.cc中,我也看得不大懂,就不展開了。

3 Graph剪枝

反序列化構建好Graph,並進行了Operation編排後,master就開始對Graph剪枝了。剪枝就是根據Graph的輸入輸出列表,反向遍歷全圖,找到幾個最小依賴的子圖,從而方便並行計算。

Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {

  std::unique_ptr<Graph> ng;
  Status s = OptimizeGraph(options, &ng);
  if (!s.ok()) {
    // 1 複製一份原始的Graph
    ng.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, ng.get());
  }

  // 2 剪枝,根據輸入輸出feed fetch,對graph進行增加節點或刪除節點等操作。通過RewriteGraphForExecution()方法
  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(subgraph::RewriteGraphForExecution(
        ng.get(), options.feed_endpoints, options.fetch_endpoints,
        options.target_nodes, device_set_->client_device()->attributes(),
        options.use_function_convention, &rewrite_metadata));
  }

  // 3 處理優化選項optimization_options
  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &ng;
  optimization_options.flib_def = flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  // 4 複製一份ClientGraph
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types));
  CopyGraph(*ng, &dense_copy->graph);

  *out = std::move(dense_copy);
  return Status::OK();
}

剪枝的關鍵在RewriteGraphForExecution()方法中,在subgraph.cc文件中。

Status RewriteGraphForExecution(
    Graph* g, const gtl::ArraySlice<string>& fed_outputs,
    const gtl::ArraySlice<string>& fetch_outputs,
    const gtl::ArraySlice<string>& target_node_names,
    const DeviceAttributes& device_info, bool use_function_convention,
    RewriteGraphMetadata* out_metadata) {

  std::unordered_set<string> endpoints;

  // 1 構建節點的name_index,從而快速索引節點。爲FeedInputs,FetchOutputs等步驟所使用
  NameIndex name_index;
  name_index.reserve(g->num_nodes());
  for (Node* n : g->nodes()) {
    name_index[n->name()] = n;
  }

  // 2 FeedInputs,添加輸入節點
  if (!fed_outputs.empty()) {
    FeedInputs(g, device_info, fed_outputs, use_function_convention, &name_index, &out_metadata->feed_types);
  }

  // 3 FetchOutputs,添加輸出節點
  std::vector<Node*> fetch_nodes;
  if (!fetch_outputs.empty()) {
    FetchOutputs(g, device_info, fetch_outputs, use_function_convention, &name_index, &fetch_nodes, &out_metadata->fetch_types);
  }

  // 4 剪枝,形成若干最小依賴子圖
  if (!fetch_nodes.empty() || !target_node_names.empty()) {
    PruneForTargets(g, name_index, fetch_nodes, target_node_names);
  }

  return Status::OK();
}

主要有4步

  1. 構建節點的name_index,從而快速索引節點。爲FeedInputs,FetchOutputs等步驟所使用
  2. FeedInputs,添加輸入節點。輸入節點的數據來源於session.run()時的feed列表。
  3. FetchOutputs,添加輸出節點。輸出節點在session.run()時通過fetches所給出
  4. 剪枝PruneForTargets,形成若干最小依賴子圖。這是剪枝算法最關鍵的一步。

PruneForTargets()從輸出節點反向搜索,按照BFS廣度優先算法,找到若干個最小依賴子圖。

static Status PruneForTargets(Graph* g, const subgraph::NameIndex& name_index,
                              const std::vector<Node*>& fetch_nodes,
                              const gtl::ArraySlice<string>& target_nodes) {
  string not_found;
  std::unordered_set<const Node*> targets;

  // 1 AddNodeToTargets添加節點到targets中,從輸出節點按照BFS反向遍歷。
  for (Node* n : fetch_nodes) {
    AddNodeToTargets(n->name(), name_index, &targets);
  }

  // 2 剪枝,得到多個最小依賴子圖子圖
  PruneForReverseReachability(g, targets);

  // 修正Source和Sink節點的依賴邊,將沒有輸出邊的節點連接到sink node上
  FixupSourceAndSinkEdges(g);

  return Status::OK();
}

主要有3步

  1. AddNodeToTargets,從輸出節點按照BFS反向遍歷圖的節點,添加到targets中。
  2. PruneForReverseReachability,剪枝,得到多個最小依賴子圖子圖
  3. FixupSourceAndSinkEdges,修正Source和Sink節點的依賴邊,將沒有輸出邊的節點連接到sink node上

PruneForReverseReachability()在algorithm.cc文件中,算法就不分析了,總體是按照BFS廣度優先算法搜索的。

bool PruneForReverseReachability(Graph* g,
                                 std::unordered_set<const Node*> visited) {
  // 按照BFS廣度優先算法,從輸出節點開始,反向搜索節點的依賴關係
  std::deque<const Node*> queue;
  for (const Node* n : visited) {
    queue.push_back(n);
  }
  while (!queue.empty()) {
    const Node* n = queue.front();
    queue.pop_front();
    for (const Node* in : n->in_nodes()) {
      if (visited.insert(in).second) {
        queue.push_back(in);
      }
    }
  }

  // 刪除不在"visited"列表中的節點,說明最小依賴子圖不依賴此節點
  std::vector<Node*> all_nodes;
  all_nodes.reserve(g->num_nodes());
  for (Node* n : g->nodes()) {
    all_nodes.push_back(n);
  }

  bool any_removed = false;
  for (Node* n : all_nodes) {
    if (visited.count(n) == 0 && !n->IsSource() && !n->IsSink()) {
      g->RemoveNode(n);
      any_removed = true;
    }
  }

  return any_removed;
}


4 Graph分裂

剪枝完成後,master即得到了最小依賴子圖ClientGraph。然後根據本地機器的硬件設備,以及op所指定的運行設備等關係,將圖分裂爲多個Partition Graph,傳遞到相關設備的worker上,從而進行並行運算。這就是Graph的分裂。

Graph分裂的算法在graph_partition.cc的Partition()方法中。算法比較複雜,我們就不分析了。圖分裂有兩種

  1. splitbydevice按設備分裂,也就是將Graph分裂到本地各CPU GPU上。本地運行時只使用按設備分裂。

    static string SplitByDevice(const Node* node) {
      return node->assigned_device_name();
    }
  2. splitByWorker 按worker分裂, 也就是將Graph分裂到各分佈式任務上,常用於分佈式運行時。分佈式運行時,圖會經歷兩次分裂。先splitByWorker分裂到各分佈式任務上,一般是各分佈式機器。然後splitbydevice二次分裂到分佈式機器的CPU GPU等設備上。

    static string SplitByWorker(const Node* node) {
      string task;
      string device;
      DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task, &device);
      return task;
    }


5 Graph執行

Graph經過master剪枝和分裂後,就可以在本地的各CPU GPU設備上執行了。這個過程的管理者叫worker。一般一個worker對應一個分裂後的子圖partitionGraph。每個worker啓動一個執行器Executor,入度爲0的節點數據依賴已經ready了,故可以並行執行。等所有Executor執行完畢後,通知執行完畢。

各CPU GPU設備間可能需要數據通信,通過創建send/recv節點來解決。數據發送方創建send節點,將數據放在send節點內,不阻塞。數據接收方創建recv節點,從recv節點中取出數據,recv節點中如果沒有數據則阻塞。這又是一個典型的生產者-消費者關係。

Graph執行的代碼邏輯在direct_session.cc文件的DirectSession::Run()方法中。代碼邏輯很長,我們抽取其中的關鍵部分。

Status DirectSession::Run(const RunOptions& run_options,
                          const NamedTensorList& inputs,
                          const std::vector<string>& output_names,
                          const std::vector<string>& target_nodes,
                          std::vector<Tensor>* outputs,
                          RunMetadata* run_metadata) {

  // 1 將輸入tensor的name取出,組成一個列表,方便之後快速索引輸入tensor
  std::vector<string> input_tensor_names;
  input_tensor_names.reserve(inputs.size());
  for (const auto& it : inputs) {
    input_tensor_names.push_back(it.first);
  }

  // 2 傳遞輸入數據給executor,通過FunctionCallFrame方式。
  // 2.1 創建FunctionCallFrame,用來輸入數據給executor,並從executor中取出數據。
  FunctionCallFrame call_frame(executors_and_keys->input_types,
                               executors_and_keys->output_types);
  // 2.2 構造輸入數據feed_args
  gtl::InlinedVector<Tensor, 4> feed_args(inputs.size());
  for (const auto& it : inputs) {
    if (it.second.dtype() == DT_RESOURCE) {
      Tensor tensor_from_handle;
      ResourceHandleToInputTensor(it.second, &tensor_from_handle);
      feed_args[executors_and_keys->input_name_to_index[it.first]] = tensor_from_handle;
    } else {
      feed_args[executors_and_keys->input_name_to_index[it.first]] = it.second;
    }
  }

  // 2.3 將feed_args輸入數據設置到Arg節點上
  const Status s = call_frame.SetArgs(feed_args);


  // 3 開始執行executor
  // 3.1 創建run_state, 和IntraProcessRendezvous
  RunState run_state(args.step_id, &devices_);
  run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
  CancellationManager step_cancellation_manager;
  args.call_frame = &call_frame;

  // 3.2 創建ExecutorBarrier,它是一個執行完成的計數器。同時註冊執行完成的監聽事件executors_done.Notify()
  const size_t num_executors = executors_and_keys->items.size();
  ExecutorBarrier* barrier = new ExecutorBarrier(
      num_executors, run_state.rendez, [&run_state](const Status& ret) {
        {
          mutex_lock l(run_state.mu_);
          run_state.status.Update(ret);
        }
        // 所有線程池計算完畢後,會觸發Notify,發送消息。
        run_state.executors_done.Notify();
      });

  args.rendezvous = run_state.rendez;
  args.cancellation_manager = &step_cancellation_manager;
  args.session_state = &session_state_;
  args.tensor_store = &run_state.tensor_store;
  args.step_container = &run_state.step_container;
  args.sync_on_finish = sync_on_finish_;

  // 3.3 創建executor的運行器Runner
  Executor::Args::Runner default_runner = [this,
                                           pool](Executor::Args::Closure c) {
    SchedClosure(pool, std::move(c));
  };

  // 3.4 依次啓動所有executor,開始運行
  for (const auto& item : executors_and_keys->items) {
    item.executor->RunAsync(args, barrier->Get());
  }

  // 3.5 阻塞,收到所有executor執行完畢的通知
  WaitForNotification(&run_state, &step_cancellation_manager, operation_timeout_in_ms_);

  // 4 接收執行器執行完畢的輸出值
  if (outputs) {
    // 4.1 從RetVal節點中得到輸出值sorted_outputs
    std::vector<Tensor> sorted_outputs;
    const Status s = call_frame.ConsumeRetvals(&sorted_outputs);

    // 4.2 處理原始輸出sorted_outputs,保存到最終的輸出outputs中
    outputs->clear();
    outputs->reserve(sorted_outputs.size());
    for (int i = 0; i < output_names.size(); ++i) {
      const string& output_name = output_names[i];
      if (first_indices.empty() || first_indices[i] == i) {
        outputs->emplace_back(
            std::move(sorted_outputs[executors_and_keys->output_name_to_index[output_name]]));
      } else {
        outputs->push_back((*outputs)[first_indices[i]]);
      }
    }
  }

  // 5 保存輸出的tensor
  run_state.tensor_store.SaveTensors(output_names, &session_state_));

  return Status::OK();
}

主要步驟如下

  1. 將輸入tensor的name取出,組成一個列表,方便之後快速索引輸入tensor
  2. 傳遞輸入數據給executor,通過FunctionCallFrame方式。本地運行時因爲在同一個進程中,我們採用FunctionCallFrame函數調用的方式來實現數據傳遞。將輸入數據傳遞給Arg節點,從RetVal節點中取出數據。
  3. 開始執行executor,並註冊監聽器。所有executor執行完畢後,會觸發executors_done.Notify()事件。然後當前線程wait阻塞,等待收到執行完畢的消息。
  4. 收到執行完畢的消息後,從RetVal節點中取出輸出值,經過簡單處理後,就可以最終輸出了
  5. 保存輸出的tensor,方便以後使用。

6 總結

本文主要講解了TensorFlow的本地運行時,牢牢抓住session和graph兩個對象即可。Session的生命週期前文講解過,本文主要講解了Graph的生命週期,包括構建與傳遞,剪枝,分裂和執行。Graph是TensorFlow的核心對象,很多問題都是圍繞它來進行的,理解它有一定難度,但十分關鍵。文章中可能有一些理解不正確的地方,希望小夥伴們不吝賜教。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章