TensorFlow實操之--服裝圖像識別問題-基於Keras

問題描述

這裏我們還是以MNIST數據爲例,與上一篇的不同在於上一篇是關於手寫體數字識別,本篇是關於服裝的簡單識別。Fashion Mnist數據集由70,000張黑白圖片構成,每張圖片大小爲 28x28,由十類服飾圖片構成。另一個MNIST數據集是手寫數字,Fashion MNIST 與之相比更有挑戰性,適合用來驗證算法

解決思路

一般滴,識別思路是

  1. 獲取數據,包括訓練數據和測試數據。
  2. 模型建立
  3. 訓練
  4. 模型驗證

這篇文檔使用高級APItf.keras在TensorFlow中搭建和訓練模型。 下面簡單介紹一下Keras。

Keras簡介

Keras 是一個用 Python 編寫的高級神經網絡 API,它能夠以 TensorFlow, CNTK, 或者 Theano 作爲後端運行。Keras 的開發重點是支持快速的實驗。能夠以最小的時延把你的想法轉換爲實驗結果,是做好研究的關鍵。官方網址是https://keras.io/zh/。使用Keras主要下面幾個步驟

  1. 準備好數據,並做好預處理。
  2. 構建模型。Keras 的核心數據結構是 model,一種組織網絡層的方式。最簡單的模型是 Sequential 順序模型,它由多個網絡層線性堆疊。對於更復雜的結構,你應該使用 Keras 函數式 API,它允許構建任意的神經網絡圖。可以簡單地使用 .add() 來堆疊模型。
  3. 在完成了模型的構建後, 可以使用 .compile() 來配置學習過程
  4. 訓練,給模型喂數據
  5. 只需一行代碼就能評估模型性能:loss_and_metrics = model.evaluate(x_test, y_test, batch_size=128)
  6. 對新的數據生成預測:classes = model.predict(x_test, batch_size=128)
    Keras的基本架構和使用層次如下圖所示:
    在這裏插入圖片描述

神經網絡服裝識別

下面我們利用keras採用神經網絡的方法對服裝圖像進行識別。

引入必要的包

from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

獲取數據

#獲取數據
fashion_mnist=keras.datasets.fashion_mnist
(train_images,train_labels),(test_images,test_labels)=fashion_mnist.load_data()
print(train_images.shape)
# print(train_images[1])
print(train_labels[1])

我們獲取了數據,通過shape可以查看大小,並可以看看內部的數值。

數據預處理

在訓練網絡之前,必須對數據進行預處理。當前數據的像素值在0-255之間,統一不同數據間的量綱,有助於我們接下來對數據進行分析和計算,即對數據進行歸一化處理。

#對數據進行歸一化處理
train_images=train_images/255
test_images=test_images/255

構建網絡模型

通過keras的Sequential創建模型,並添加神經網絡的層次

#構建網絡模型
model=keras.Sequential([keras.layers.Flatten(input_shape=(28,28)),
                       keras.layers.Dense(128,activation='relu'),
                       keras.layers.Dense(10,activation='softmax')
                        ])

配置網絡模型

keras的配置網絡也是一句話的事兒,可以看到設置了優化器,損失函數等。

model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])

評估模型

也是一句話

test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)

預測

如有必要,我們對新的數據進行預測,這裏舉個例子,所有的測試數據進行預測。

predictions = model.predict(test_images)#對所有標籤與圖片進行預測

神經網絡圖像識別小結

訓練模型在整個測試數據集的表現情況如下,測試準確率接近88%。

60000/60000 [==============================] - 2s 35us/step - loss: 0.2384 - acc: 0.9115
Test accuracy: 0.8796

從對keras的神經網絡圖像識別可以看到

  • Keras使用起來簡直太簡單了,太友好了。把深度學習的細節完全封裝好,甚至把學習框架也封裝了。可以說是非常容易上手。但是這種簡單的,又讓人害怕和擔心,長期使用keras豈不是比掉包俠還調包蝦,所以對深度學習的內部原理還是要細細研究的。
  • 上面的方法如手寫數字識別的第一個方法一樣,沒用利用隱藏層,沒有加入正則化,沒有還用CNN,相信通過把模型建立的更合理更科學,效果會更好。我們後續再操作。
發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章