歡迎點擊「算法與編程之美」↑關注我們!
本文首發於微信公衆號:"算法與編程之美",歡迎關注,及時瞭解更多此係列文章。
問題描述
在成功調用官網打包好的tensorflowjs模型後,怎麼調用自己的模型呢?又需要做哪些處理呢?
解決方案
1)安裝好python和tensorflow
2)安裝tensorflowjs : pip install tensorflowjs
注:如果你的tensorflow版本是2.0的,在下載tfjs時可能會被更新爲1.15版本的。可以考慮新建個python環境。
3)準備已經訓練好的模型,並通過 model.save(“模型命名.h5”) 代碼將模型保存爲h5格式的文件。
下面是本文使用的mnist手寫數字集的模型代碼案例:
import tensorflow as tf mnist = tf.keras.datasets.mnist (x_train, y_train),(x_test, y_test) = mnist.load_data() x_train, x_test = x_train / 255.0, x_test / 255.0 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(input_shape=(28, 28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dropout(0.2), tf.keras.layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model.fit(x_train, y_train, epochs=5) model.evaluate(x_test, y_test) model.save('D:\\test/mnist.h5') |
4)通過tensorflowjs_converter命令將h5格式的模型文件轉換爲json格式的文件。
1.打開pycharm的Terminal指令框
2. 輸入轉換指令:
tensorflowjs_converter--input_format=keras D:\\test/mnist.h5 D:\\test
註釋:tensorflowjs_converter –模型格式 模型地址 保存地址
3.查看model.json是否生成
5)將模型放在服務器上,如果沒有可以在本地創建,步驟如下 :
1.打開pycharm的Terminal的指令框
2.輸入python3 -m http.server 8000
3.打開瀏覽器輸入 localhost:8000 輸出如下界面
如果出現localhost拒絕訪問,可能是你的系統沒開啓iis服務,只能手動開啓了。
未開啓的建議依次按以下步驟來:
1 .百度:如何安裝iss服務
2 .打開管理工具
3.進入管理工具界面,單擊“Internet Information Services (IIS)管理器”。
4.右鍵單擊“網站”,選擇“添加網站”。
5.在彈出的界面中輸入網站名稱、選擇物理路徑(model.json所在的文件地址)、IP地址輸入爲127.0.0.1、端口爲8000,然後點擊確定。
6.打開目錄展示功能:目錄瀏覽—打開功能—啓用
6) 在項目中安裝相應的庫
詳細過程請參考之前發佈的博客《微信小程序與tensorflow.js準備工作》在項目目錄下使用npm安裝對應包,安裝代碼如下:
npm install fetch-wechat npm install @tensorflow/tfjs-converter npm install @tensorflow/tfjs-core npm install @tensorflow/tfjs-layers npm install regenerator-runtime |
7) 效果
8) 代碼部分
代碼較爲簡單,說明以註釋方式放在代碼旁邊(只展示主體代碼部分,完成項目代碼下載鏈接:
https://pan.baidu.com/s/18VcMiNaiEjC_Y_Yz1gJ_1g)
const regeneratorRuntime = require('regenerator-runtime') const tf = require('@tensorflow/tfjs-core') const tfl = require('@tensorflow/tfjs-layers')
//index.js Page({ async onReady() { //加載相機 const camera = wx.createCameraContext(this) // 加載模型 const net = await this.loadModel() this.setData({result: 'Loading'}) let count = 0 //每隔10幀獲取一張相機捕捉到的圖片 const listener = camera.onCameraFrame((frame) => { count++ if (count === 10) { if (net) { //對圖片內容進行預測 this.predict(net, frame) } count = 0 } }) listener.start() }, //加載模型 async loadModel() { const net = await tfl.loadLayersModel('https://yuantao.store/model.json') net.summary() return net }, async predict(net, frame){ //圖像預處理,API說明和用發可到tensorflow.google.cn查看 const imgData = {data: new Uint8Array(frame.data), width: frame.width, height: frame.height} const x = tf.tidy(() => { const imgTensor = tf.browser.fromPixels(imgData, 4) //轉換爲Tensor,微信小程序相機獲取的圖片有4維 const d = Math.floor((frame.height - frame.width) / 2) const imgSlice = imgTensor.slice([d, 0, 0], [frame.width, -1, 3]) //截取正方形區域,並丟掉最後一個維度,只保留3個維度 const imgResize = tf.image.resizeBilinear(imgSlice, [28, 28]) return imgResize.mean(2)//對最後一個維度去均值,將三通道轉換爲單通道 }) // console.log(x) const y = await net.predict(x.expandDims(0)).argMax(1) //預測,並獲取預測值最大的下標,及預測結果 const res = y.dataSync()[0]//預測結果爲一個對象,我們只需要值部分 this.setData({result: res}) }
})
|
END
主 編 | 王文星
責 編 | 馬原濤
where2go 團隊
微信號:算法與編程之美
長按識別二維碼關注我們!
溫馨提示:點擊頁面右下角“寫留言”發表評論,期待您的參與!期待您的轉發!