TensorFlow學習筆記(一)

最近致力於深度學習,希望在移動領域能夠找出更多的應用點.其中TensorFlow作爲目前的一個熱點值得我們重點關注.

機器學習

機器學習是人工智能的一個分支,也是用來實現人工只能的一種方法。簡單來說,機器學習就是通過算法,使得機器能從大量歷史數據中學習規律,從而對新的樣本做智能識別或對未來做預測,與傳統的使用特定指令集手寫軟件不同,我們使用大量數據和算法來“訓練”機器,由此帶來機器學習如何完成任務.從1980年代末期以來,機器學習的發展大致經歷了兩次浪潮:淺層學習(Shallow Learning)和深度學習(Deep Learning)。

淺層學習

90年代,各種各樣的淺層機器學習模型相繼被提出,比如SVM、Boosting、最大熵方法等。這些模型在是理論分析或者工程應用領域都獲得了巨大的成功最成功的應用,比如搜索廣告系統的廣告點擊率CTR預估、網頁搜索排序、垃圾郵件過濾系統、基於內容的推薦系統等。

深度學習

深度學習是實現機器學習的一種技術,現在所說的深度學習很大層次上是指神經網絡。神經網絡是受人類大腦的啓發:神經元之間的相互連接。對比看來,人類大腦中的神經元與特定範圍內的任意神經元連接,而人工神經網絡中數據傳播要經歷不同的層,且傳播方向也不同.

現在來說說在神經網絡算法中,每個神經元的作用:每個神經元都會給其輸入指定一個權重:相對於執行的任務該神經元的正確和錯誤程度。最終的輸出由這些權重共同決定。

這裏寫圖片描述

現在來看看上面提到的停止標誌示例。一張停止標誌圖像的屬性,被一一細分,然後被神經元“檢查”:形狀、顏色、字符、標誌大小和是否運動。神經網絡的任務是判斷這是否是一個停止標誌。它將給出一個“概率向量”,這其實是基於權重做出的猜測結果。

爲什麼需要深度學習

淺層模型有一個重要特點,就是假設靠人工經驗來抽取樣本的特徵,在模型的運用不出差錯的前提下(如假設互聯網公司聘請的是機器學習的專家),特徵的好壞就成爲整個系統性能的瓶頸。要發現一個好的特徵,就要求開發人員對待解決的問題要有很深入的理解。

深度學習的實質,是通過構建具有很多隱層的機器學習模型和海量的訓練數據,來學習更有用的特徵,從而最終提升分類或預測的準確性。換句話說,深度學習的目的是爲了特徵學習。

區別於傳統的淺層學習,深度學習的不同在於:1. 強調了模型結構的深度,通常有5層、6層,甚至10多層的隱層節點;2. 明確突出了特徵學習的重要性,也就是說,同過逐層特徵變換,將樣本在原空間的特徵表示變換到一個新特徵空間,使分類或預測更加容易。

與人工規則構造特徵的方法相比,利用大數據來學習特徵,更能刻畫數據豐富的內在信息,目前來看,深度學習在搜索廣告CRT預估,自然語言處理,圖像識別,語音識別以及無人駕駛上有了廣泛的應用.

TensorFlow基礎

是Google開源的機器學習庫,基於DistDelief進行研發的第二代人工智能系統,用來幫助我們快速的實現DL和CNN等各種算法公式,.其名字本身描述了它自身的執行原理:Tensor(張量)意味着N維數組,Flow(流)意味着基於數據流圖的計算.數據流圖中的圖就是我們所說的有向圖,我們知道,在圖這種數據結構中包含兩種基本元素:節點和邊.這兩種元素在數據流圖中有自己各自的作用.節點用來表示要進行的數學操作,另外,任何一種操作都有輸入/輸出,因此它也可以表示數據的輸入的起點/輸出的終點.邊表示節點與節點之間的輸入/輸出關係,一種特殊類型的數據沿着這些邊傳遞.這種特殊類型的數據在TensorFlow被稱之爲tensor,即張量,所謂的張量通俗點說就是多維數組.當我們向這種圖中輸入張量後,節點所代表的操作就會被分配到計算設備完成計算.

到現在,你應該對TensorFlow有了一個淺顯的人是,下面我們再來看一張圖,他會幫助你更好的認識TensorFlow工作過程.

這裏寫圖片描述

到現在,我們只是認識了TensorFlow,接下來我總結了TensorFlowd的四個特性,來說明我們爲什麼要使用TensorFlow:

  • 靈活性: 非嚴格的“神經網絡”庫。這意味這我們的計算只要能夠表示爲數據流圖,就能夠使用.
  • 可移植性:底層核心採用C++便宜,可以運行在臺式機、服務器、手機移動等設備上,提供對分佈式的支持,能夠快速構建深度學習集羣.
  • 多語言支持:前端支持Python,C/C++,Java以及Go,以及非官方支持的Scala,但是目前對Python接口支持最好.
  • 高效:提供對線程、隊列、異步操作支持,同時支持運行在CPU和GPU上,能夠充分發揮硬件潛力.

可以說正是由於以上四個特性,使得TensorFlow的使用逐漸流行開來.其中我認爲tensorflow最關鍵的一點是允許我們將計算的過程描述爲一張圖(Graph),我將其稱這張圖爲”計算圖”,能讓我們很容易的操作其中的網絡結構.

現在我們對Tensor的特性有了基本的瞭解,但是如何利用TensorFlow呢?爲了能夠容易的理解,我用下水管道結構圖來類比.

如果你是一名城市管道設計者,當你想要解決這個城市排水問題時,你會做點什麼呢?(這個例子來自早期看到一位作者的解釋)

這裏寫圖片描述

不出意外,你的腦海中隱約的浮現出管道圖.如果能到這一步,說明什麼呢?這意味着你已經開始觸及TensorFlow領域,看吧,其實TensorFlow的工作過程我們天生就懂.

TensorFlow中的計算圖就像此處的管道結構,我們考慮設計管道結構的過程就是在構建計算圖.

現在來看管道中的閥門.閥門可以用來控制水流動的強度和方向,這和神經網絡中的權重和偏移的作用一致.唯一的不同是,管道中閥門需要人爲調整,而神經網絡的”閥門”會根據數據自我調整/更新.

我們知道在管道中是水在流動,那麼計算圖流動的是什麼呢?
計算圖流動的就是我們上文說到的tensor,tensor本質上就是多維數組.(其實每一個tensor包含又一個靜態類型,一個rank和一個shape,關於這點我們就不做解釋了,有興趣的同學可以查閱相關paper)

和管道不同,在計算圖中,我們可以從任意一個節點處取出”液體”,也就是獲得當前tensor.

現在,我們來個稍微正規點的解釋:

TensorFlow使用Graph來描述計算任務,圖中的節點被稱之爲op.一個op可以接受0或多個tensor作爲輸入,也可產生0或多個tensor作爲輸出.任何一個Graph要想運行,都必須藉助上下文Session.通過Session啓動Graph,並將Graph中的op分發到CPU或GPU上,藉助Session提供執行這些op.op被執行後,將產生的tensor返回.藉助Session提供的feed和fetch操作,我們可以爲op賦值或者獲取數據.計算過程中,通過變量(Variable)來維護計算狀態.

爲了方便大家理解TensorFlow中相關的概念,這裏我列了一張表格:

類型 描述 用途
Session 會話 圖必須在稱之爲“會話”的上下文中執行。會話將圖的op分發到諸如CPU或者GPU上計算
Graph 描述計算過程 必須在Session中啓動
tensor 數據 數據類型之一,代表多維數組
op 操作 圖中的節點被稱之爲op,一個op獲得0或者多個Tensor,執行計算,產生0或者多個Tensor
Variable 變量 數據類型之一,運行過程中可以被改變,用於維護狀態
feed 賦值 爲op的tensor賦值
fetch 取值 從op的tensor中取值
Constant 常量 數據類型之一,不可變

TensorFlow環境

TensorFlow目前支持三種平臺:Linux系列,Mac OS以及Window.並提供了多種安裝方式,目前常見的安裝方式有三種:pip,docker,Anacona,和源碼編譯安裝.TensorFlow支持CPU計算和GPU計算:

  • CPU 支持 :系統沒有 NVIDIA CUDA® GPU,我們只能安裝該版本。
  • GPU 支持: TensorFlow 程序通常在 GPU 比在 CPU 上運行快得多。如果系統具有 NVIDIA CUDA GPU 那麼可以安裝該版本

爲了方便,這裏只演示CPU.更多的資料你可以在TensorFlow官網官網上找到答案.

window上通過安裝

在開始之前首先確保我們安裝了python3.

pip3 install tensorflow
pip3 install tensorflow-gpu
pip3 install tensorlayer //上面二選一,後安裝tensorlayer,也可以不裝

一般來說,你會遇到以下兩個錯誤:

錯誤1:

   return importlib.import_module(mname)                                       
 File "C:\Users\liudongdong-iri\AppData\Local\Programs\Python\Python35\lib\impo
ib\__init__.py", line 126, in import_module                                    
   return _bootstrap._gcd_import(name[level:], package, level)                 
 File "<frozen importlib._bootstrap>", line 986, in _gcd_import                
 File "<frozen importlib._bootstrap>", line 969, in _find_and_load             
 File "<frozen importlib._bootstrap>", line 958, in _find_and_load_unlocked    
 File "<frozen importlib._bootstrap>", line 666, in _load_unlocked             
 File "<frozen importlib._bootstrap>", line 577, in module_from_spec           
 File "<frozen importlib._bootstrap_external>", line 906, in create_module     
 File "<frozen importlib._bootstrap>", line 222, in _call_with_frames_removed  
mportError: DLL load failed: 找不到指定的模塊。 

錯誤2:

File "<frozen importlib._bootstrap>", line 222, in _call_with_frames_removed
ImportError: DLL load failed: 找不到指定的模塊。

此時只需要下載https://www.microsoft.com/en-us/download/details.aspx?id=53587 安裝後重新使用pip命令安裝.

Mac OS X 通過pi安裝

在mac上通過pip來安裝,python 2.7和python 3.3+二者選一即可

pip install tensorflow

如果是python3使用以下命令:

pip3 install tensorflow

注意:如果在之前安裝過TensorFlow<0.71的版本,首先要使用pip uninstall卸載TensorFlow以及protobuf.

檢測安裝

在安裝成功後,先來運行個簡單Hello,TensorFlow程序一睹風采:

import tensorflow as tf

hello = tf.constant('Hello,TensorFlow')
sess = tf.Session()
print(sess.run(hello))

不出意外,我們將看到如下輸出:

Hello,TensorFlow

TensorFlow實踐

上面我們簡單的介紹了TensorFlow的工作原理以及相關的概念,接下來呢,我們從實踐的角度觸發,來進一步解釋相關概念,在這之前先來說明構建TensorFlow程序的基本過程.,通常分爲兩步:構建階段和一個執行階段。在構建階段,我們組織多個op,最終形成Graph。在執行階段,使用會話執行op.

先來看個簡單的示例,來大體有個瞭解:

import tensorflow as tf

# 定義‘符號’變量,也稱爲佔位符
a = tf.placeholder("int32")
b = tf.placeholder("int32")

# 構造一個op節點
y = tf.multiply(a, b)

# 建立會話
sess = tf.Session()

# 運行會話,輸入數據,並計算節點,同時打印結果
print(sess.run(y, feed_dict={a: 3, b: 3}))

# 任務完成, 關閉會話.
sess.close()

構建階段

構造階段的主要目的是爲了構建一張計算圖.構建圖的第一步是創建源op.源op不需要任何輸入,源op的輸出被傳遞給其他op作爲輸入.例如常量(Constant).Python 庫中, op 構造器的返回值代表被構造出的 op 的輸出, 這些返回值可以傳遞給其它 op 構造器作爲輸入..在TensorFlow中存在一個默認圖(defalut graph).op構造器可以爲其增加節點.很多時候我們會直接使用該圖,可以通過tf.Graph.as_default()來獲取.

Graph

Graph:要組裝的結構,由許多操作組成,其中的每個連接點代表一種操作

方法 用途
tf.Graph.as_graph_def() 返回一個圖的序列化的GraphDef,表示序列化的GraphDef可以導入到另外一個圖(使用import_graph_def())
tf.Graph.get_operations() 返回圖中的操作節點列表
tf.Operation.name 操作節點op的名稱
tf.Operation.type 操作節點op的類型
tf.Operation.inputs 操作節點的輸入與輸出
tf.Operation.run(session=None,feed_dict=None) 在會話中執行該操作
tf.add_to_collection(name,value) 基於默認的圖,其功能便爲Graph.add_to_collection()
tf.get_collection(key,scope=None) 基於默認的圖,其功能便爲Graph.get_collection()

op

op:接受(流入)零個或多個輸入(液體),返回(流出)零個或多個輸出

數據類型

數據類型:主要分爲tensor,variable,constant.

tensor:多維array或list

# 創建
tensor_name=tf.placeholder(type, shape, name)

variable:通常可以將一個統計模型中的參數表示爲一組變量。例如,你可以將一個神經網絡的權重當作一個tensor存儲在變量中。在訓練圖的重複運行過程中去更新這個tensor

# 創建變量
name_variable = tf.Variable(value, name)

# 初始化單個變量
init_op=variable.initializer()

# 初始化所有變量
init_op=tf.initialize_all_variables()

# 更新操作
update_op=tf.assign(variable to be updated, new_value)

簡單示例

這裏我們來創建一個包含三個op的圖,其中兩個constant op,一個matmul op.

import tensorflow as tf

# 創建作爲第一個常量op,該op會被加入到默認的圖中
# 1*2的矩陣,構造器的返回值代表該常量op的返回值
matrix_1 = tf.constant([[3., 3.]])

# 創建第二個常量op,該op會被加入到默認的圖中
# 2*1的矩陣
matrix_2 = tf.constant([[2.], [2.]])

# 創建第三個op,爲矩陣乘法op,接受matrix_1和matrix_2作爲輸入,product代表乘法矩陣結果
product = tf.matmul(matrix_1, matrix_2)

到現在我們已經創建好了包含三個op的圖.下面我們要通啓動該圖,執行運算.

執行階段

在實現上, TensorFlow 將圖形定義轉換成分佈式執行的操作, 以充分利用可用的計算資源(如 CPU 或 GPU). 一般你不需要顯式指定使用 CPU 還是 GPU, TensorFlow 能自動檢測. 如果檢測到 GPU, TensorFlow 會儘可能地利用找到的第一個 GPU 來執行操作. 但是今天我們暫時不關注該部分.

Session

首先我們需要創建一個Session對象.在不傳參數的情況下,該Session的構造器將啓動默認的圖.之後我們可以通過Session對象的run(op)來執行我們想要的操作.

方法 用途
tf.Session.run(fetches,feed-dict=Noe,options=Node,run_metadata=None) 運行fetches中的操作節點並求其
tf.Session.close() 關閉會話
tf.Session.graph 返回加載該會話的圖()
tf.Session.as_default() 設置該對象爲默認會話,並返回一個上下文管理器

簡單示例

# 創建會話
sess = tf.Session()

# 賦值操作
sess.run([output], feed_dict={input1:value1, input2:value1})

# 用創建的會話執行操作
sess.run(op)

# 取值操作

# 關閉會話
sess.close()

完整示例

最終結合構建階段和執行階段,完整代碼如下:

import tensorflow as tf

# 創建作爲第一個常量op,該op會被加入到默認的圖中
# 1*2的矩陣,構造器的返回值代表該常量op的返回值
matrix_1 = tf.constant([[3., 3.]])
# 創建第二個常量op,該op會被加入到默認的圖中
# 2*1的矩陣
matrix_2 = tf.constant([[2.], [2.]])
# 創建第三個op,爲矩陣乘法op,接受matrix_1和matrix_2作爲輸入,product代表乘法矩陣結果
product = tf.matmul(matrix_1, matrix_2)
# 獲取sess
sess = tf.Session()
# 來執行矩陣乘法op
result = sess.run(product)
# 輸出矩陣乘法結果
print("result:",result)

# 任務完畢,關閉Session
sess.close()

除了通過Session的close()的手動關閉外,也可以使用with代碼塊:

with tf.Session() as sess:
    result=sess.run(product)
    print("result:",result)

現在來運行該代碼,不出意外我們將獲得結果:

result: [[ 12.]]

其他

通過上面的過程,我們只需要瞭解了tensorflow的構建過程和執行過程,進一步瞭解了Graph,op以及Session各自所擔當的職責,但是仍然有很多點我們無法一一細聊.

爲了方便起見,這裏我將對一些概念再次進行解釋.

Tensor

在TensorFlow中,用tensor來表示其所使用的數據結構,簡單點理解tensor就是一個多維數組.任何一個物體,我們都可以用幾個特徵來描述它.每個特徵可以劃分成一個維度.比如:一小組圖像集表示爲一個四維浮點數數組, 這四個維度分別是 [batch, height,width, channels].

方法 用途
tf.Tensor.dtype tensor中數據類型
tf.Tensor.name tensor名稱
tf.Tensor.op 產生該tensor的op
tf.Tensor.graph 該tensor所在的

Variables

TensorFlow使用Variables來維護圖執行過程中的狀態信息.下面我們演示一個計數器:

import tensorflow as tf

# 創建一個變量,初始化爲0
state = tf.Variable(0, name="counter")

# 創建一個常量
one = tf.constant(1)
new_value = tf.add(state, one)
update = tf.assign(state, new_value)

# 變量初始化
init_op = tf.initialize_all_variables()

sess = tf.Session()
# 運行init_op
sess.run(init_op)
# 運行state,打印state初始值
print(sess.run(state))
for _ in range(10):
    sess.run(update)
    print(sess.run(state))

執行該程序,不出意外輸出結果:

0
1
2
3
4
5
6
7
8
9
10

Fetch

爲了獲取操作輸出的內容,可以在使用Session對象的run(op)時,傳入一些tensor,這些tensor用來取回我們想要的結果.

import tensorflow as tf

value_1 = tf.constant(3.0)
value_2 = tf.constant(2.0)
value_3 = tf.constant(5.0)

# 2.0+5.0
temp_value=tf.add(value_2,value_3)

# 3.0+(2.0+5.0)
result=tf.add(value_1,temp_value)

sess = tf.Session()
print(sess.run([temp_value,result]))

Feed

我們可以通過TensorFlow對象的placeholder()爲變量創建指定數據類型佔位符,在執行run(op)時通過feed_dict來爲變量賦值.

import tensorflow as  tf

input_1 = tf.placeholder(tf.float32)
input_2 = tf.placeholder(tf.float32)
output = tf.add(input_1, input_2)

with tf.Session() as sess:
    # 通過feed_dict來輸入,outpu表示輸出
    print(sess.run([output],feed_dict={input_1:[7.],input_2:[2.]}))

placeholder

TensorFlow提供一種佔位符操作,在執行時需要爲其提供數據.這有點類似我們編寫sql語句時使用?佔位符一樣,你可以理解爲預編譯.

方法 用途
tf.placeholder(dtype,shape=None,name=None) 爲一個tensor插入一個佔位符
input_value = tf.placeholder(tf.float32,shape=(1024,1024))

模型保存於恢復

在tensorflow中最簡單的保存與加載模型的方式是通過Saver對象.

方法 用途
tf.train.Saver.save(sess,save_path,global_step=None,latest_filename=None,meta_graph_suffix=’meta’,write_meta_graph=True) 保存變量
tf.train.Saver.restore(sess,save_path) 恢復變量
tf.train.Saver.last_checkpoints() 列出最近未刪除的checkpoint文件名
tf.train.Saver.set_last_checkpoints(last_checkpoints) 設置checkpoint文件名列表
tf.train.Saver.set_last_checkpoints_with_time(last_checkpoints_with_time) 設置checkpoint文件名列表和時間戳

保存模型

import tensorflow as tf

def save_model():
    v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        saver_path = saver.save(sess, "./model.ckpt")
        print("model saved in file: ", saver_path)

加載模型

用同一個Saver對象來恢復變量,注意,當你從文件恢復變量是,不需要對它進行初始化,否則會報錯。

import tensorflow as tf

def load_model():
    v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
    v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess,"./model.ckpt")
        print("mode restored")

移植TensorFlow到移動設備

由於TensorFlow核心代碼使用C++編寫的,因此我們可以很容易的將其移植到移動設備中,一般需要經過以下幾步:

  1. PC訓練模型,並將其保存爲pb格式,然後導入該模型文件到Android項目的assets目錄中
  2. 導入TensorFlow的jar包以及so文件到Android項目中,jar包向我們暴露了操作接口,具體的執行引擎算法責備封裝在so文件當中
  3. 定義相關變量,存儲數據,並通過jar包提供的接口加載模型,執行運算即可.

下面我們用個簡單的示例來演示整個移植的過程.

1. 訓練模型

import tensorflow as tf

sess = tf.Session()
matrix_1 = tf.constant([3., 3.], name='input')
add = tf.add(matrix_1, matrix_1, name='output')
sess.run(add)

output_graph_def = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, output_node_names=['output'])
# 保存模型到目錄下的model文件夾中
with tf.gfile.FastGFile('./model/tensorflow_matrix_graph.pb',mode='wb') as f:
    f.write(output_graph_def.SerializeToString())

sess.close()

唯一注意的一點是務必要保成pb格式的文件:

  1. 不能使用 tf.train.write_graph()保存模型,該種方式只是保存了模型的結構,並不保存訓練完畢的參數值
  2. 不能使用 tf.train.saver()保存模型,該種方式只是保存了網絡中的參數值,並不保存模型的結構。

我們需要的是既保存模型的結構,又保存模型中每個參數的值,所以上述的兩種方式都不行:因此我們用一下方式保存:

# 可以把整個sesion當作常量都保存下來,通過output_node_names參數來指定輸出
graph_util.convert_variables_to_constants
# 指定保存文件的路徑以及讀寫方式
tf.gfile.FastGFile('model/test.pb', mode='wb')
# 將固化的模型寫入到文件
f.write(output_graph_def.SerializeToString())

2.編譯所需要的jar和so文件

這裏以Mac OS X平臺爲例,你可以可以在linux平臺編譯,目前不支持window平臺編譯.

  1. 首先克隆 TensorFlow 倉庫到本地:
$ git clone --recurse-submodules https://github.com/tensorflow/tensorflow

--recurse-submodules 參數是必須得, 用於獲取 TesorFlow 依賴的 protobuf 庫.

  1. 安裝Bazel

Bazel是Google開源的一款自動化構建工作,TensorFlow整個工程就是通過它進行構建.其安裝過程也非常簡單如果你和我一樣使用macOS構建,那麼我們可以通過包管理器Homebrew來安裝Bazel

brew install bazel

安裝之後可以通過bazel version來查看其版本,比如當前我這裏是:

Build label: 0.4.5-homebrew
Build target: bazel-out/local-opt/bin/src/main/java/com/google/devtools/build/lib/bazel/BazelServer_deploy.jar
Build time: Thu Mar 16 13:37:54 2017 (1489671474)
Build timestamp: 1489671474
Build timestamp as int: 1489671474

在需要升級的時候可以通過brew upgrade bazel.如果你要在其他平臺安裝,查閱官網:https://bazel.build/versions/master/docs/install-os-x.html

接下來,我們需要修改TensorFlow項目的WORKSPACE文件:

這裏寫圖片描述

修改前:

# Uncomment and update the paths in these entries to build the Android demo.
#android_sdk_repository(
#    name = "androidsdk",
#    api_level = 23,
#    # Ensure that you have the build_tools_version below installed in the 
#    # SDK manager as it updates periodically.
#    build_tools_version = "25.0.2",
#    # Replace with path to Android SDK on your system
#    path = "<PATH_TO_SDK>",
#)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
#android_ndk_repository(
#    name="androidndk",
#    path="<PATH_TO_NDK>",
#    # This needs to be 14 or higher to compile TensorFlow. 
#    # Note that the NDK version is not the API level.
#    api_level=14)

根據本機情況來設置正確的SDK和NDK路徑.

# Uncomment and update the paths in these entries to build the Android demo.
android_sdk_repository(
    name = "androidsdk",
    api_level = 23,
    # Ensure that you have the build_tools_version below installed in the 
    # SDK manager as it updates periodically.
    build_tools_version = "25.0.2",
    # 修改爲自己系統SDK路徑
    path = "/Users/liudongdong/Library/Android/sdk/",
)
#
# Android NDK r12b is recommended (higher may cause issues with Bazel)
android_ndk_repository(
    name="androidndk",
    # 修改爲自己系統NDK路徑
    path="/Users/liudongdong/Library/Android/ndk/",
    # This needs to be 14 or higher to compile TensorFlow. 
    # Note that the NDK version is not the API level.
    api_level=14)

先編譯so文件

bazel build -c opt //tensorflow/contrib/android:libtensorflow_inference.so --crosstool_top=//external:android/crosstool --host_crosstool_top=@bazel_tools//tools/cpp:toolchain --cpu=armeabi-v7a

生成的so文件位於tensorflow目錄下:

bazel-bin/tensorflow/contrib/android/libtensorflow_inference.so

再編譯jar文件

bazel build //tensorflow/contrib/android:android_tensorflow_inference_java

生成的jar文件位於:

bazel-bin/tensorflow/contrib/android/libandroid_tensorflow_inference_java.jar

需要注意一點:在編譯出so文件後要及時地將其拷貝出來,因爲在編譯jar文件的時候會將上面編譯出的so文件刪除.

3.移植模型到Android設備

  1. 將pb模型文件放入assets目錄

  2. 添加jar包到項目的libs目錄下,添加so文件到jniLibs目錄下:

這裏寫圖片描述

  1. 定義變量,然後初始化tensorflow,調用相關api.同樣,我們以剛纔生成的tensorflow_matrix_graph.pb爲例:
public class NumberActivity extends AppCompatActivity {
    // 定義模型文件路徑
    private static final String MODE_FILE = "file:///android_asset/tensorflow_matrix_graph.pb";

    private static final int HEIGHT=1;
    private static final int WIDTH =2;

    // 輸入
    private static final String inputName = "input";
    private float[] inputs = new float[HEIGHT * WIDTH];

    // 輸出
    private static final String outputName = "output";
    private float[] outputs = new float[HEIGHT * WIDTH];

    //tensorflow接口
    TensorFlowInferenceInterface mTensorFlowInferenceInterface;

    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);
        setContentView(R.layout.activity_number);
        findViewById(R.id.btn_matrix).setOnClickListener(new View.OnClickListener() {
            @Override
            public void onClick(View v) {
                float[] result = getResult();
                Toast.makeText(NumberActivity.this, result[0] + "  " + result[1], Toast.LENGTH_SHORT).show();
            }
        });
        mTensorFlowInferenceInterface = new TensorFlowInferenceInterface(getAssets(), MODE_FILE);
    }


    private float[] getResult() {
        inputs[0]=4;
        inputs[1]=3;

        Trace.beginSection("feed");
        // 輸入數據
        mTensorFlowInferenceInterface.feed(inputName, inputs, WIDTH, HEIGHT);
        Trace.endSection();

        Trace.beginSection("run");
        String[] outputNames = new String[]{outputName};
        // 執行數據
        mTensorFlowInferenceInterface.run(outputNames);
        Trace.endSection();

        Trace.beginSection("fetch");
        // 取出數據
        mTensorFlowInferenceInterface.fetch(outputName,outputs);
        Trace.endSection();

        return outputs;
    }
}

總結

此文更多是各前輩的總結,如有不妥之處,多多指教.

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章