問題描述
通過 Tensorflow Slim API 實現的 MobileFaceNet 模型,在轉換爲 tflite 過程中,出現不支持 switch 的問題。
分析
通過 tensorboard 搜索 graph 中 control_flow_ops.switch ,都出現在 BatchNorm/cond/switch
節點中。顯然,BN 使用了 cond 封裝下的 Switch,所以想要知道誰調用了Switch,關鍵就是找到 BatchNorm
、Cond
、Switch
這三個 namescope,因此在讀源碼過程中,抽絲剝繭,縷了一下 BatchNorm 使用 switch 的內容。
問題從 slim.batch_norm 開始:
BatchNorm
在slim.batch_norm 方法中 variable_scope 定義了 BatchNorm
,第一個scope找到,在with上下文語法中,繼續查找Cond
。
264| with variable_scope.variable_scope(
265| scope, 'BatchNorm', [inputs], reuse=reuse) as sc:
結合Graph 中的內容,發現 switch
包含在Cond
之內,而 Switch
之後,就緊跟着 FusedBatchNorm 方法。
cond
因此存在 Cond
及 Switch
的調用,在以下兩條語句中。
# (tensorflow/contrib/layers/python/layers/layers.py)
# slim.batch_norm()
371| nn.fused_batch_norm(
…… |
381| utils.smart_cond()
進入 utils.smart_cond()
# (tensorflow/contrib/layers/python/layers/utils.py)
# utils.smart_cond()
def smart_cond(pred, fn1, fn2, name=None):
214| return static_cond(pred_value, fn1, fn2)
217| return control_flow_ops.cond(pred, fn1, fn2, name)
switch
在段落最後,出現了 control_flow_ops.cond() ,而 control_flow_ops.cond() 就是調用 switch 的地方。
# (tensorflow/python/ops/control_flow_ops.py)
# cond()
1927| def cond(pred, true_fn=None, false_fn=None, strict=False,
1928| name=None, fn1=None, fn2=None):
…… |
2020| with ops.name_scope(name, "cond", [pred]):
…… |
2026| # Add the Switch to the graph.
2027| if isinstance(pred, bool):
2028| raise TypeError("pred must not be a Python bool")
2029| p_2, p_1 = switch(pred, pred)
2030| pivot_1 = array_ops.identity(p_1, name="switch_t")
結論
至此, cond
和 switch_t
都出現了,所有 BatchNorm 所以依賴的 Switch 的調用邏輯爲:
slim.batch_norm (tensorflow/contrib/layers/python/layers/layers.py)
|-> utils.smart_cond() (tensorflow/contrib/layers/python/layers/utils.py)
|-> control_flow_ops.cond() (tensorflow/python/ops/control_flow_ops.py)
下一步是看怎麼樣替代掉 slim.batch_norm 爲其他不包括Switch 的 BN實現(希望很渺茫),或者替換掉 control_flow_ops.cond (難度比較大) 。