讀TensorFlow 源碼筆記(2): tensorflow的控制流算子(control_flow_op)

讀TensorFlow 源碼筆記(2): tensorflow的控制流算子(control_flow_op)

在閱讀TensorFlow源碼時,遇到了很多複雜又晦澀的概念,今兒整理以下內容,分享給大家:

  • 介紹專門爲處理控制流而添加的五個TensorFlow原語運算符,

  • 演示如何將高級控制流結構編譯爲包含這五個原語的數據流圖

  • 解釋TensorFlow運行時如何執行這些數據流圖,包括在一組混合設備(如CPU、GPU和TPU)上的分佈式執行,並描述自動區分對控制流結構的作用。

控制流原語

在TensorFlow中,控制流的基本設計原則中引入一組非常小的簡單、原始的運算符,這些運算符廣泛應用在TensorFlow的複雜控制流中。我們希望這些原語具有靈活性和表現力,成爲高級領域特定語言(DSL)的良好編譯目標。它們應該很好地適應TensorFlow的數據流模型,並且應該能夠並行和分佈式執行以及自動區分。

有五個控制流原語運算符,如下所示。它們與Dennis和Arvind開發的數據流機器中引入的控制流原語非常相似。Switch和Merge的結合允許我們實現條件語句。所有五個原語一起允許我們實現while循環。
在這裏插入圖片描述

在TensorFlow中,每個操作都在一個執行幀(execution frame)中執行,控制流原語負責創建和管理這些執行幀(execution frame)。直觀地說,對於每個while循環,TensorFlow運行時設置一個執行幀(execution frame),並在執行幀(execution frame)內運行屬於while循環的所有操作(op)。執行幀(execution frame)可以嵌套。嵌套的while循環在嵌套的執行幀(execution frame)中運行。來自不同執行幀(execution frame)的操作(op)可以並行運行,只要它們之間沒有數據依賴關係。

Switch : switch算子根據控制輸入p的布爾張量(tensor of bool)將輸入張量d(input tensor)轉發到其輸出之一。當switch的兩個輸入pd都可用時,switch被啓用以執行。

Merge : merge運算符將其一個可用輸入轉發到其輸出。當merge的任何輸入可用時,將啓用該merge以執行。如果有多個可用輸入,則不指定輸出哪個可用輸入。

Enter(name) : Enter運算符將其輸入轉發到由給定名稱(name)唯一標識的執行幀(execution frame)。此Enter op用於將一個執行幀(execution frame)中的張量傳遞給子執行幀。同一子執行幀可以有多個Enter操作,每個Enter操作都使該子執行幀中的張量可用(異步)。當輸入可用時,將啓用Enter執行。當對該幀執行第一個Enter操作時,將在TensorFlow運行時實例化一個新的執行幀。

Exit : Exit運算符將值從執行幀轉發到其父執行幀。此Exit op用於將子執行幀中計算的張量返回到其父幀。父幀可以有多個退出操作,每個都異步地將張量傳遞迴父幀。退出在其輸入可用時啓用。

NextIteration : NextIteration運算符將其輸入轉發到當前執行幀中的下一個迭代。TensorFlow運行時在執行幀中跟蹤迭代。在執行幀中執行的任何操作都有一個唯一的迭代id,它允許我們在迭代計算中唯一地標識同一操作的不同調用。注意,在一個執行幀中可以有多個NextIteration操作。TensorFlow運行時在迭代N執行第一個NextIteration操作時啓動迭代N+1。隨着更多的Tensor通過執行NextIteration操作進入迭代,該迭代中的更多操作將準備好執行。當輸入可用時,將啓用NextIteration。

控制流結構的編譯

通過添加以上這五個控制流原語,條件(cond)和循環(while_)等高級編程結構現在就可以編譯成數據流圖,這些數據流圖可以由TensorFlow運行時執行。

條件(cond)運算符

下面是構建cond(pred,fn1,fn2)數據流圖的高級僞代碼。爲了簡單起見,此中忽略了實際實現中的許多重要問題。讀者可以在control_flow_ops.py中找到實現。

# Build the graph for the true branch
context_t = CondContext(pred, branch=1)
res_t = context_t.Call(fn1)
# Build the graph for the false branch
context_f = CondContext(pred, branch=0)
res_f = context_f.Call(fn2)
# Add the Merge nodes for the outputs
merges = [Merge([f, t]) for (f, t) in zip(res_f, res_t)]
return merges

對於cond的每個分支,我們爲條件語句創建一個新的控制流上下文,並在上下文中調用其圖形構造函數(fn1或fn2)。條件上下文允許捕獲任何外部張量(不是在上下文中創建的)並插入適當的Switch op來保護其進入分支(branch)。這確保了分支中的任何操作都只能在執行該分支時執行。由於TensorFlow的異步執行模型,這些外部張量可能在非常不同的時間變得可用,因此還需要爲每個外部張量使用一個switch操作以最大限度地提高並行性。

每個分支返回一個張量列表作爲結果(ref_t或res_f);然後添加一個合併節點列表,分別合併(merge)每個輸出的true和false值。同樣,輸出可以在非常不同的時間進行計算,因此我們對每個輸出使用一個合併(merge)操作,這允許能夠儘快啓用下游計算。

作爲一個例子,看看這個簡單的程序。
在這裏插入圖片描述

tf.cond(x < y, lambda : tf.add(x, z), lambda : tf.square(y))

在生成的數據流圖中,在true/false分支上插入開關(switch)操作以控制張量x、y 和z 的流。由於add的輸入來自switch ops的true輸出,因此僅當x < y 爲真時才執行add op。類似地,Square op僅在x < y 爲false時執行。最後的Merge op發出Add或Square的結果。如果有多個輸出,將有多個merge操作,每個輸出一個結果。

有多種方法可以使用Switch和Merge對cond進行編碼。這裏的編碼主要是因爲它使cond的自動區分變得更簡單。

while_循環運算符

下面是構建while_循環(pred,body,loop_vars)數據流圖的高級僞代碼:

while_context = WhileContext()
while_context.Enter()
# Add the Enter nodes for each loop variable.
enter_vars = [Enter(x, frame_name) for x in loop_vars]
# Add the Merge nodes. Note that input[1] will be updated later.
merge_vars = [Merge([x, x]) for x in enter_vars]
# Build the loop pred subgraph.
pred_result = pred(*merge_vars)
# Add the Switch nodes.
switch_vars = [Switch(x, pred_result) for x in merge_vars]
# Build the loop body subgraph.
body_result = body(*[x[1] for x in switch_vars])
# Add the NextIteration nodes.
next_vars = [NextIteration(x) for x in body_result]
# Form the cycles for the loop.
for m, v in zip(merge_vars, next_vars):
m.op._update_input(1, v)
# Add the Exit nodes.
exit_vars = [Exit(x[0]) for x in switch_vars]
while_context.Exit()
return exit_vars

整個while循環圖是在while循環的控制流上下文中創建的。這裏的基本思想很簡單。
從循環變量開始,我們爲每個變量添加一個Enter操作,然後添加一個Merge操作。然後使用結果(merge_vars)構建pred子圖,該子圖計算循環終止條件。

在添加開關(switch)操作之後,我們使用switch的true輸出爲while循環的主體構建子圖。循環體的結果需要進入下一個迭代,因此我們添加next iteration操作並將它們連接回Merge操作的第二個輸入。這形成了循環,允許我們在執行圖時多次重複運行同一個操作。

開關(switch)操作的false輸出是整個while循環的輸出,因此我們將exit操作添加到它們並返回exit操作的輸出。與cond類似,while循環上下文用於跟蹤predbody lambdas中使用的外部張量。這些外部張量被視爲循環常量,自動爲每個這樣的外部張量插入一個Enter op,使其在while循環上下文中可訪問。嵌套循環需要添加嵌套的Enter ops

爲一個簡單程序生成的圖.
在這裏插入圖片描述

tf.while_loop(lambda i : i< 10, lambda i : tf.add(i,1),[0])

對於這個例子,只有一個循環變量。如果有多個循環變量,我們將有多個Enter、Merge、Switch、NextIteration和Exit操作。這使得可以跨多個循環和循環內的多個迭代執行並行。

cond和while_循環的這種轉換支持條件句和循環的任意嵌套。例如,循環體可以調用另一個while_循環,該循環將遞歸地轉換爲嵌套子圖。轉換確保每個循環靜態地分配一個唯一的幀名(frame name)。

實現

要在多個設備上運行,TensorFlow會自動將操作(op)分配給設備集(device set)。根據設備的位置,TensorFlow自動將數據流圖劃分爲一組子圖(subgraph),每個設備分配一個子圖(subgraph)。當一條邊被分區(partition)破壞時,會自動插入一對發送(send)和接收(recv)節點,用於跨設備傳輸張量(tensor)。一對send和recv使用一個唯一的密鑰(key)進行通信,recv主動從send中拉去數據。例如,下面是將一個圖分區到兩個設備上的結果。TensorFlow對分區沒有任何限制:只要可以在設備上進行計算,就可以將節點分配給該設備。
在這裏插入圖片描述

子圖的執行由子圖分配給的設備的本地執行器(executor)管理。執行器(executor)從源節點(source node)開始並重復執行就緒節點(ready
nodes)。當一個節點的所有輸入都可用時,該節點(合併節點(merge node)除外)將準備就緒。注意,子圖中的所有recv節點都被視爲源節點(source node)。
在沒有控制流的情況下,圖的執行在概念上是非常簡單的:每個節點只執行一次,執行是在所有節點都執行時完成的。控制流引入了相當多的複雜性。一個節點現在可以執行任意次數,包括0。執行器需要能夠管理同一節點的多個實例(可能併發)的執行,並確定圖形執行的完成。

爲了跟蹤執行期間生成的張量,執行器中的張量表示爲一個元組d=(value,is_dead,tag),其中value是實際的張量,is_dead是一個布爾值,指示張量是否在條件的未指定分支上,tag是唯一標識張量(以及生成張量的節點的執行實例)的字符串。直觀地說,tag定義了一個執行上下文(executor context),在一個執行上下文中,節點最多執行一次。tag是send/recv對的通信密鑰(key)的一部分,用於區分同一send/recv對的多個調用。
執行器遵循以下執行規則(注意:節點的所有輸入必須具有相同的標記(tag)):

  • Switch(p, d) = (r1, r2 ) :

    r1 = (value(d), p || is_dead(d), tag(d))

    r2 = (value(d), !p || is_dead(d), tag(d))

  • Merge(d1 , d2 ) = r :

    r = if is_dead(d1 ) then d2 else d1

  • Enter(d, frame_name) = r :

    value® = value(d)

    is_dead® = is_dead(d)

    tag® = tag(d)/frame_name/0

  • Exit(d) = r :

    value® = value(d)

    is_dead® = is_dead(d)

    tag® = tag1 where tag(d) = tag1 /frame_name/n

  • NextIteration(d) = d1:

    value(d1 ) = value(d)

    is_dead(1 ) = is_dead(d)

    tag(1 ) = tag1 /frame_name/(n+1) where tag(d) = tag1 /frame_name/n

  • Op(d1 , …, dm ) = (r1 , …, rn ) :

    value(ri ) = Op.Compute(value(d1 ), …, value(dm)) if !is_dead(ri)

    is_dead(ri ) = any(is_dead(d1 ), … is_dead(dm )), for all i

    tag(ri ) = tag(d1 ), for all i

    最後一條規則適用於所有非控制流節點。請注意,實際計算僅在所有輸入都未停止時執行。如果存在死區輸入,我們將跳過計算並在下游傳播死區信號。這種死區傳播用於支持控制流的分佈式執行。

    分佈式條件執行

    對於分佈式執行,可以將cond分區到多個設備上,如下所示。
    在這裏插入圖片描述

    由於任何recv節點都是源節點,並且可以無條件地啓動,因此即使設備B上的recv位於cond的untaken分支上,它也是可以啓動的。爲了使untaken分支上的recv被激活,TensorFlow在從send到recv的設備之間傳播is_dead標誌。傳播可以在任意數量的設備上繼續。這個簡單的傳播方案處理嵌套條件的分佈式執行,並與while循環的分佈式執行進行良好的交互。

分佈式While循環

對於分佈式執行,一個while循環,特別是循環體,可以被劃分到多個設備上。如果純粹地應用跨設備邊添加send/recv節點的分區方案,設備上的本地執行器將沒有足夠的信息來正確運行while循環。
在這裏插入圖片描述

讓我們用一個簡單的例子來說明這些問題。在上面的例子中,Op在循環體中,並被分配給設備B。一個簡單的分區只需要使用一對send/recv節點就可以將邊緣從Switch斷開到Op。但是,這將不起作用,因爲設備B不知道recv和Op節點是while循環的一部分,並且將在一次迭代後終止執行。解決方案是要重寫數據流圖,在每個分區中添加一個控制循環狀態機(如下設備B的右下角所示)。標量張量0用作控制循環的輸入節點。
在這裏插入圖片描述

這些控制循環提供足夠的信息,允許設備上的執行器像以前一樣獨立運行,通過send/recv節點相互通信。注意虛線是控制邊。

更詳細地說,讓我們首先看看while循環只運行0次迭代的基本情況:

  • 在設備A上,執行從節點Enter、Merge、P和Switch開始。由於P爲false,連接Switch的send會將死區(dead signal)信號傳播到設備B,並且設備A上的Exit也會運行,從而啓用循環外節點的併發執行。連接到P的Send會將布爾張量False發送到設備B。還觸發執行Recv,等待設備B的返回值。
  • 在設備B上,執行從節點Enter和Merge開始。執行Merge將啓用兩個recv。Switch的Recv將收到False,因此Next將得到一個死張量。下一步是停止死亡(dead)的傳播。Op的Recv將得到一個死張量(dead tensor),這樣Op的Send將把一個死張量(dead tensor)發送回設備A。此時,設備B沒有未完成的ops,因此執行終止。
  • 回到設備A,Next的Recv得到一個死張量。下一次運行時,由於它停止了死區的傳播,設備A沒有未完成的操作,因此執行終止。

現在假設while循環運行一個或多個迭代:

  • 在設備A上,由於P在第一次迭代時爲true,因此會向設備B發送一個實張量。執行Recv,等待設備B的值。
  • 在設備B上,控制迴路狀態機(control-loop state machine)運行並啓用recv。Op的Recv從設備A得到一個實張量;Op被執行並且發送一個實張量回設備A。Switch的Recv得到布爾張量True。執行Next和Merge,進一步爲下一次迭代啓用recv。
  • 回到設備A,Recv得到一個真正的張量。接下來,Merge和P被執行。根據P的值,將執行基本情況或新的迭代。

注意,在執行過程中有很多並行性。例如,設備B在接收到P的值後可以開始下一個迭代或退出。參與設備可以並行運行多個迭代,並且兩個參與設備可以在同一循環的不同迭代上工作。

while循環的分佈式執行的開銷是,每個參與設備在每次迭代時都需要從產生P的設備接收布爾張量。考慮到執行過程中的並行性,應該在很大程度上隱藏開銷。

下面顯示了在跨多個設備分區while循環時數據流圖的外觀。控制循環被添加到每個分區,並控制while循環中的recv。重寫後的圖在語義上等價於原始圖。
在這裏插入圖片描述

對於嵌套的while循環,只將控制循環堆棧如下。注意,如果一個設備只有外環的節點,不會爲該設備上的任何內環添加控制環。
在這裏插入圖片描述

自動微分

TensorFlow支持自動微分。例如,用戶可以定義一個具有損失函數的神經網絡,TensorFlow將自動求導並構建反向傳播數據流圖。本節說明了TensorFlow如何在cond和while_循環存在時自動構建反向傳播圖。

反向傳播算法通過反向遍歷前向圖中的ops,通過調用ops的梯度函數逐步構造梯度圖。op的梯度函數定義了計算op的符號梯度的子圖。梯度函數可以使用op的輸入/輸出值,因此在前向計算中產生的一些張量將保留一段時間,直到在backprop中使用爲止。例如,下面顯示了一個正向操作及其梯度圖。G(Op)是Op的梯度子圖,x和y的值將保存在內存中,直到G(Op)被執行。

在這裏插入圖片描述
一旦構建了整個數據流圖,TensorFlow運行時將自動對該圖進行分區,並將執行分佈在多個設備上。因此,TensorFlow中的梯度計算也將分佈到多個設備上運行。
直觀地說,在我們的cond和while_循環的高層結構中,控制流操作符的反向傳播只是按照以下方式反向流動:Exit的梯度是Enter;Switch的梯度是Merge(對於cond)或NextIteration,然後是Merge(對於while_循環);Merge的梯度是Switch;關係的梯度是恆等的,梯度的Enter就是Exit。TensorFlow支持嵌套條件和while循環的反向傳播。

有條件的反向傳播

直觀地講,cond(p,fn1,fn2)的梯度是cond(p,g_fn1,g_fn2),其中g_fn1和g_fn2分別是fn1和fn2的梯度。下面顯示當cond未嵌套在while循環中時cond的基本反向傳播。假設Op在cond的真分支上。嵌套在while循環中的cond需要更多的工作來記住前向循環每次迭代的p值。稍後再看一下while循環的backprop。
在這裏插入圖片描述

forward Merge被轉換成一個Switch,它使用與forward Switch相同的謂詞p。梯度gy分發到Switch的兩個分支上。forward Switch變爲了Merge。如果forward中只使用forward Switch的一個分支,需要添加一個零,如下所示,以確保始終有一個活的梯度流過backprop中的Merge。這個0 由一個Switch來控制,因此只有當p爲false時,它纔會被髮送到Merge中。
在這裏插入圖片描述

While循環的反向傳播

直觀地說,while_loop(pred,body)的梯度類似於一下的while loop形式:

def pred(i, _): return i < N
while_loop(pred, g_body, [0] + g_vars)

其中N是forward while循環運行的迭代次數,g_body是forward循環體的梯度,g_vars是循環變量的初始值。稍後將看到,g_vars包含forward while循環變量的初始梯度。while循環及其backprop while循環的圖形大致如下:
[外鏈圖片轉存失敗,源站可能有防盜鏈機制,建議將圖片保存下來直接上傳(img-JO5iBvYx-1584885390666)(./control_flow_op/while_backprop.jpg)]

backprop循環由N控制,N 爲前向循環將運行的迭代次數,其中假設循環條件pred是不可訓練的。G(Body)是運算主體的梯度。

Body可能再次包含while循環,因此此構造可以遞歸地發生以處理嵌套while循環。
到目前爲止,這一描述相當過於簡單化。例如,N在圖構造時是靜態未知的。更重要的是,G(Body)可能使用由forward循環體生成的值,我們希望保留這些值,以避免在backprop中重新計算它們。TensorFlow中的解決方式是重寫forward while循環的graph,以添加計算和(或)保存backprop中所需的值的邏輯。

爲了計算N,我們將以下子圖添加到forward while循環中。因此,N將由前向循環動態計算,併發送給backprop循環的循環次數計數器作爲變量的初始值。
在這裏插入圖片描述
爲了在backprop循環中重用前向的值,TensorFlow在backprop while循環的構造過程中自動檢測backprop中所需的前向值。對於每個這樣的前向值x,會自動引入一個堆棧,並在前向循環(forward while)中添加節點,以在每次迭代時將其值保存到堆棧中。backprop循環按相反的順序使用堆棧中的值。堆棧位於forward和backprop循環之外,由這兩個循環共享。
在這裏插入圖片描述

實際的圖形構造實際上比這更微妙和複雜。tensorflow在這裏還涉及很多細節,比如還有一下這些問題:

  • 爲了確保正確性,必須確保堆棧推送和pop按其各自循環的迭代順序排列。還要確保先在前向循環中向棧中push值之後,纔可能在堆棧中彈出到backprop,這需要使用控制邊(control edge)強制執行排序。
  • 爲了提高性能,TensorFlow將堆棧推送和彈出操作設爲異步操作,以便它們可以與實際計算並行運行。例如,op(甚至未來的迭代)可以與Push並行運行。
  • 如果op位於while循環中嵌套的cond中,那麼cond的謂詞必須正確地確保push和pop操作正確執行。
  • 如果值立即被backprop中的reduce op(例如Shape、Rank或Size)減少,TensorFlow會將reduce op移動到forward循環以減少內存使用。

對於循環變量,如前所述,反向傳播的梯度的Enter是Exit,以上就是它所做的一切。對於循環常數,TensorFlow還添加一個子圖來累積它們的梯度,如下所示。
在這裏插入圖片描述

假設x是向前的循環常數。在backprop中,每次迭代都會生成x的部分梯度。所以在backprop中添加小的累加子圖來將所有這些部分梯度相加。出口處的最終gx是所有部分梯度的總和。注意,累積是迫不及待地完成的,以並行迭代次數爲界。這與靜態展開不同,在靜態展開中,AddN的使用需要同時激活所有的局部梯度。
這種構造對嵌套條件和循環都有效。對於嵌套在while循環中的cond,TensorFlow引入一個堆棧來保存每次前向迭代時謂詞的值,並在backprop中使用堆棧中的值(以相反的順序)。對於嵌套循環,當遇到嵌套在循環體中的內部while循環時,將遞歸調用此構造。

一個重要的優化是內存交換。正如我們所看到的,對於backprop中需要的每個正向值v,將其在所有迭代v1,…,vN中的值保存在堆棧中,以便在backprop中重用它們。這可能是在內存有限的設備(如gpu)上進行培訓的限制。我們使用內存交換將堆棧中存儲的值從GPU異步移動到CPU,並在backprop中需要時將它們移回GPU內存。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章