微信小程序|調用tensorflow自定義模型

歡迎點擊「算法與編程之美」↑關注我們!

本文首發於微信公衆號:"算法與編程之美",歡迎關注,及時瞭解更多此係列文章。

問題描述

在成功調用官網打包好的tensorflowjs模型後,怎麼調用自己的模型呢?又需要做哪些處理呢?

解決方案

1)安裝好pythontensorflow

2)安裝tensorflowjs : pip install tensorflowjs

注:如果你的tensorflow版本是2.0的,在下載tfjs時可能會被更新爲1.15版本的。可以考慮新建個python環境。

3)準備已經訓練好的模型,並通過 model.save(“模型命名.h5”) 代碼將模型保存爲h5格式的文件。

下面是本文使用的mnist手寫數字集的模型代碼案例:

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([

  tf.keras.layers.Flatten(input_shape=(28, 28)),  

  tf.keras.layers.Dense(128, activation='relu'), 

  tf.keras.layers.Dropout(0.2),  tf.keras.layers.Dense(10, activation='softmax')

])

model.compile(optimizer='adam',

                              loss='sparse_categorical_crossentropy',  

                              metrics=['accuracy'])

model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)

model.save('D:\\test/mnist.h5')

4)通過tensorflowjs_converter命令將h5格式的模型文件轉換爲json格式的文件。

1.打開pycharm的Terminal指令框

2. 輸入轉換指令:

tensorflowjs_converter--input_format=keras D:\\test/mnist.h5  D:\\test

註釋:tensorflowjs_converter –模型格式 模型地址 保存地址

3.查看model.json是否生成

5)將模型放在服務器上,如果沒有可以在本地創建,步驟如下 :

1.打開pycharm的Terminal的指令框

2.輸入python3 -m http.server 8000

3.打開瀏覽器輸入 localhost:8000  輸出如下界面

如果出現localhost拒絕訪問,可能是你的系統沒開啓iis服務,只能手動開啓了。

未開啓的建議依次按以下步驟來:

1 .百度:如何安裝iss服務

2 .打開管理工具

3.進入管理工具界面,單擊“Internet Information Services (IIS)管理器”。

4.右鍵單擊“網站”,選擇“添加網站”。

5.在彈出的界面中輸入網站名稱、選擇物理路徑(model.json所在的文件地址)、IP地址輸入爲127.0.0.1、端口爲8000,然後點擊確定。

6.打開目錄展示功能:目錄瀏覽打開功能啓用

6) 在項目中安裝相應的庫

詳細過程請參考之前發佈的博客《微信小程序與tensorflow.js準備工作》在項目目錄下使用npm安裝對應包,安裝代碼如下:

npm install fetch-wechat

npm install @tensorflow/tfjs-converter

npm install @tensorflow/tfjs-core

npm install @tensorflow/tfjs-layers

npm install regenerator-runtime

7效果

8代碼部分

代碼較爲簡單,說明以註釋方式放在代碼旁邊(只展示主體代碼部分,完成項目代碼下載鏈接:

https://pan.baidu.com/s/18VcMiNaiEjC_Y_Yz1gJ_1g 

const regeneratorRuntime = require('regenerator-runtime')

const tf = require('@tensorflow/tfjs-core')

const tfl = require('@tensorflow/tfjs-layers')

 

//index.js

Page({

  async onReady() {

    //加載相機

    const camera = wx.createCameraContext(this)

    // 加載模型

    const net = await this.loadModel()

    this.setData({result: 'Loading'})

    let count = 0

    //每隔10幀獲取一張相機捕捉到的圖片

    const listener = camera.onCameraFrame((frame) => {

      count++

      if (count === 10) {

        if (net) {

          //對圖片內容進行預測

          this.predict(net, frame)

        }

        count = 0

      }

    })

    listener.start()

  },

  //加載模型

  async loadModel() {

    const net = await tfl.loadLayersModel('https://yuantao.store/model.json')

    net.summary()

    return net

  },

  async predict(net, frame){

    //圖像預處理,API說明和用發可到tensorflow.google.cn查看

    const imgData = {data: new Uint8Array(frame.data), width: frame.width, height: frame.height}

    const x = tf.tidy(() => {

      const imgTensor = tf.browser.fromPixels(imgData, 4)

      //轉換爲Tensor,微信小程序相機獲取的圖片有4

      const d = Math.floor((frame.height - frame.width) / 2)

      const imgSlice = imgTensor.slice([d, 0, 0], [frame.width, -1, 3])

      //截取正方形區域,並丟掉最後一個維度,只保留3個維度

      const imgResize = tf.image.resizeBilinear(imgSlice, [28, 28])

      return imgResize.mean(2)//對最後一個維度去均值,將三通道轉換爲單通道

    })

    // console.log(x)

    const y = await net.predict(x.expandDims(0)).argMax(1)

    //預測,並獲取預測值最大的下標,及預測結果

    const res = y.dataSync()[0]//預測結果爲一個對象,我們只需要值部分

    this.setData({result: res})

  }

  

})

 

 

 

END


主       編   |   王文星

責       編   |   馬原濤

 where2go 團隊


   

微信號:算法與編程之美          

長按識別二維碼關注我們!

溫馨提示:點擊頁面右下角“寫留言”發表評論,期待您的參與!期待您的轉發!

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章