概要
mnist 數據集鏈接:http://yann.lecun.com/exdb/mnist/
fashion_mnist:https://github.com/zalandoresearch/fashion-mnist
mnist 已經被用爛了,也太簡單了。所以現在準備採用fashion_mnist。
兩者的讀取方式完全一致。這裏以fashion mnist作爲例子。
FashionMNIST 是一個替代 MNIST 手寫數字集 的圖像數據集。 它是由 Zalando(一家德國的時尚科技公司)旗下的研究部門提供。其涵蓋了來自 10 種類別的共 7 萬個不同商品的正面圖片。
FashionMNIST 的大小、格式和訓練集/測試集劃分與原始的 MNIST 完全一致。60000/10000 的訓練測試數據劃分,28x28 的灰度圖片。你可以直接用它來測試你的機器學習和深度學習算法性能,且不需要改動任何的代碼。
說白了就是手寫數字沒有衣服鞋子之類的更復雜。
數據格式和mnist完全一致:
標註編號描述
0:T-shirt/top(T恤)
1:Trouser(褲子)
2:Pullover(套衫)
3:Dress(裙子)
4:Coat(外套)
5:Sandal(涼鞋)
6:Shirt(汗衫)
7:Sneaker(運動鞋)
8:Bag(包)
9:Ankle boot(踝靴)
代碼
# -*- coding: utf-8 -*-
from sklearn import neighbors
from read_data import DataUtils
import datetime
import numpy as np
import struct
import matplotlib.pyplot as plt
def read_image(file_name):
'''
:param file_name: 文件路徑
:return: 訓練或者測試數據
如下是訓練的圖片的二進制格式
[offset] [type] [value] [description]
0000 32 bit integer 0x00000803(2051) magic number
0004 32 bit integer 60000 number of images
0008 32 bit integer 28 number of rows
0012 32 bit integer 28 number of columns
0016 unsigned byte ?? pixel
0017 unsigned byte ?? pixel
........
xxxx unsigned byte ?? pixel
'''
file_handle=open(file_name,"rb") #以二進制打開文檔
file_content=file_handle.read() #讀取到緩衝區中
head = struct.unpack_from('>IIII', file_content, 0) # 取前4個整數,返回一個元組
offset = struct.calcsize('>IIII')
imgNum = head[1] #圖片數
width = head[2] #寬度
height = head[3] #高度
bits = imgNum * width * height # data一共有60000*28*28個像素值
bitsString = '>' + str(bits) + 'B' # fmt格式:'>47040000B'
imgs = struct.unpack_from(bitsString, file_content, offset) # 取data數據,返回一個元組
imgs_array=np.array(imgs).reshape((imgNum,width*height)) #最後將讀取的數據reshape成 【圖片數,圖片像素】二維數組
return imgs_array
def out_image(img):
'''
:param img: 圖片像素組
:return:
'''
plt.figure()
plt.imshow(img)
plt.show()
def read_label(file_name):
'''
:param file_name:
:return:
標籤的格式如下:
[offset] [type] [value] [description]
0000 32 bit integer 0x00000801(2049) magic number (MSB first)
0004 32 bit integer 60000 number of items
0008 unsigned byte ?? label
0009 unsigned byte ?? label
........
xxxx unsigned byte ?? label
The labels values are 0 to 9.
'''
file_handle = open(file_name, "rb") # 以二進制打開文檔
file_content = file_handle.read() # 讀取到緩衝區中
head = struct.unpack_from('>II', file_content, 0) # 取前2個整數,返回一個元組
offset = struct.calcsize('>II')
labelNum = head[1] # label數
bitsString = '>' + str(labelNum) + 'B' # fmt格式:'>47040000B'
label = struct.unpack_from(bitsString, file_content, offset) # 取data數據,返回一個元組
return np.array(label)
def get_data():
# 文件獲取
train_image = "./mnist/train-images-idx3-ubyte"
test_image = "./mnist/t10k-images-idx3-ubyte"
train_label = "./mnist/train-labels-idx1-ubyte"
test_label = "./mnist/t10k-labels-idx1-ubyte"
# 讀取數據
train_x = read_image(train_image)
test_x = read_image(test_image)
train_y = read_label(train_label)
test_y = read_label(test_label)
print(train_y[0:10])
print(test_y[0:10])
out_image(np.array(test_x[0]).reshape(28, 28))
return train_x,train_y,test_x,test_y
if __name__ == "__main__":
get_data()
結果展示
C:\ProgramData\Anaconda3\python.exe E:/hw/hw0.py
[9 0 0 3 0 2 7 2 5 5]
[9 2 1 1 6 1 4 6 5 7]
Process finished with exit code 0
從結果上看,9——裸靴,圖片看正確。解析ok
後面就是採用各種機器學習算法進行分類。