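使用 Gluon 實現 MLP(多層感知機)比較簡單,本章不作過多理論說明。下面是相關代碼(各片段中用到的 transform、mnist_train、mnist_test、show_image、get_fashion_mnist_labels 等定義,見文末的完整源碼):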
1、數據讀取
# Number of samples per mini-batch, shared by both loaders below.
batch_size = 100
# NOTE(review): this ToTensor transform is created but never applied in the
# code shown here; batches are converted by transform() in the loops instead.
transformer = gn.data.vision.transforms.ToTensor()
# Training batches are reshuffled; test batches keep the dataset order.
train_data = gn.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gn.data.DataLoader(mnist_test, batch_size, shuffle=False)
2、模型定義
# Two-layer perceptron: flatten -> 256-unit ReLU hidden layer -> 10 outputs.
net = gn.nn.Sequential()
with net.name_scope():
    # The layers are instantiated inside the name scope, in the same order
    # as they are stacked.
    for layer in (
        gn.nn.Flatten(),
        gn.nn.Dense(256, activation='relu'),
        gn.nn.Dense(10),
    ):
        net.add(layer)
net.initialize()  # initialize the model parameters
print(net)
3、定義準確率
# Accuracy metric
def accuracy(output, label):
    """Return the fraction of samples whose arg-max along axis 1 equals ``label``.

    Args:
        output: Network scores; the arg-max is taken along axis 1.
        label: Class indices, compared element-wise with the arg-max.

    Returns:
        The mean of the element-wise equality, extracted as a scalar.
    """
    predicted = output.argmax(axis=1)
    hits = predicted == label
    return nd.mean(hits).asscalar()
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    The result is weighted by sample count rather than averaged per batch,
    so a shorter final batch does not skew it.

    Args:
        data_iter: Iterable yielding ``(data, label)`` batches; each batch is
            passed through ``transform`` before being fed to the network.
        net: Callable mapping a data batch to per-class scores.

    Returns:
        Correct predictions divided by total samples, or ``0.0`` when
        ``data_iter`` yields no samples (instead of dividing by zero).
    """
    correct = 0.0
    seen = 0
    for data, label in data_iter:
        data, label = transform(data, label)
        output = net(data)
        # Count matches in this batch instead of averaging them, so every
        # sample carries the same weight in the final ratio.
        correct += nd.sum(output.argmax(axis=1) == label).asscalar()
        seen += label.shape[0]
    if not seen:
        return 0.0
    return correct / seen
4、損失函數和梯度下降
# Computing softmax and cross-entropy separately can be numerically unstable,
# so the combined loss is used.
cross_loss = gn.loss.SoftmaxCrossEntropyLoss()
# Optimizer: plain SGD over all network parameters.
train_step = gn.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={"learning_rate": 0.1},
)
5、訓練
# Train for a fixed number of epochs, reporting the mean batch loss, mean
# batch training accuracy and test accuracy after each epoch.
epochs = 20
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    for image, y in train_data:
        image, y = transform(image, y)  # type conversion and normalization
        with ag.record():
            output = net(image)
            loss = cross_loss(output, y)
        loss.backward()
        # Trainer.step rescales the gradient by 1/batch_size, so pass the
        # size of this batch; it may be smaller than the configured batch
        # size on a final partial batch.
        train_step.step(image.shape[0])
        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, y)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d, Loss:%f, Train acc:%f, Test acc:%f"
          % (epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
訓練結果:
6、預測
# Once training has finished, the network can be used to predict samples.
image_10, label_10 = mnist_test[0:10]  # take the first 10 test samples
show_image(image_10)
print("真實樣本標籤:", label_10)
print("真實數字標籤對應的服飾名:", get_fashion_mnist_labels(label_10))
# Apply the same conversion as during training before running the network.
image_10, label_10 = transform(image_10, label_10)
predict_label = nd.argmax(net(image_10), axis=1)
print("預測樣本標籤:", predict_label.astype("int8"))
print("預測數字標籤對應的服飾名:", get_fashion_mnist_labels(predict_label.asnumpy()))
預測結果:
最後附上所有源碼:
import mxnet.autograd as ag
import mxnet.ndarray as nd
import mxnet.gluon as gn
def transform(data, label):
    """Cast a batch to float32 and scale the data by 1/255.

    Args:
        data: Array-like with an ``astype`` method; cast to float32 and
            divided by 255 (sample normalization).
        label: Array-like with an ``astype`` method; cast to float32.

    Returns:
        A ``(data, label)`` tuple of the converted arrays.
    """
    scaled = data.astype("float32") / 255
    label_f32 = label.astype("float32")
    return scaled, label_f32
# Fashion-MNIST training and test splits from Gluon's vision datasets.
mnist_train = gn.data.vision.FashionMNIST(train=True)
mnist_test = gn.data.vision.FashionMNIST(train=False)
# Peek at the first nine training samples.
data, label = mnist_train[:9]
print(data.shape, label)  # inspect the data dimensions
import matplotlib.pyplot as plt
def show_image(image):  # display images
    """Display every image in ``image`` side by side in a single row.

    Args:
        image: Batch of images indexed along axis 0. Each element is
            reshaped to 28x28 and converted with ``asnumpy()`` before being
            drawn (presumably an MXNet NDArray of 28x28 images).
    """
    n = image.shape[0]
    # squeeze=False makes subplots() always return a 2-D array of axes.
    # Without it a single image (n == 1) yields a bare Axes object and the
    # indexing below fails.
    _, figs = plt.subplots(1, n, figsize=(15, 15), squeeze=False)
    for i in range(n):
        figs[0][i].imshow(image[i].reshape((28, 28)).asnumpy())
    plt.show()
def get_fashion_mnist_labels(labels):  # text labels for images
    """Map numeric class indices to their Fashion-MNIST clothing names.

    Args:
        labels: Iterable of class indices; each is converted with ``int()``
            before the lookup, so float-valued indices are accepted.

    Returns:
        A list with the clothing name for each entry of ``labels``.
    """
    names = (
        't-shirt', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    )
    result = []
    for raw in labels:
        result.append(names[int(raw)])
    return result
#
# show_image(data)
# print(get_fashion_mnist_labels(label))
'''----數據讀取----'''
# Number of samples per mini-batch, shared by both loaders below.
batch_size = 100
# NOTE(review): this ToTensor transform is created but never applied in the
# code shown here; batches are converted by transform() in the loops instead.
transformer = gn.data.vision.transforms.ToTensor()
# Training batches are reshuffled; test batches keep the dataset order.
train_data = gn.data.DataLoader(mnist_train, batch_size, shuffle=True)
test_data = gn.data.DataLoader(mnist_test, batch_size, shuffle=False)
'''---定義模型---'''
# Two-layer perceptron: flatten -> 256-unit ReLU hidden layer -> 10 outputs.
net = gn.nn.Sequential()
with net.name_scope():
    # The layers are instantiated inside the name scope, in the same order
    # as they are stacked.
    for layer in (
        gn.nn.Flatten(),
        gn.nn.Dense(256, activation='relu'),
        gn.nn.Dense(10),
    ):
        net.add(layer)
net.initialize()  # initialize the model parameters
print(net)
# Accuracy metric
def accuracy(output, label):
    """Return the fraction of samples whose arg-max along axis 1 equals ``label``.

    Args:
        output: Network scores; the arg-max is taken along axis 1.
        label: Class indices, compared element-wise with the arg-max.

    Returns:
        The mean of the element-wise equality, extracted as a scalar.
    """
    predicted = output.argmax(axis=1)
    hits = predicted == label
    return nd.mean(hits).asscalar()
def evaluate_accuracy(data_iter, net):
    """Return the classification accuracy of ``net`` over ``data_iter``.

    The result is weighted by sample count rather than averaged per batch,
    so a shorter final batch does not skew it.

    Args:
        data_iter: Iterable yielding ``(data, label)`` batches; each batch is
            passed through ``transform`` before being fed to the network.
        net: Callable mapping a data batch to per-class scores.

    Returns:
        Correct predictions divided by total samples, or ``0.0`` when
        ``data_iter`` yields no samples (instead of dividing by zero).
    """
    correct = 0.0
    seen = 0
    for data, label in data_iter:
        data, label = transform(data, label)
        output = net(data)
        # Count matches in this batch instead of averaging them, so every
        # sample carries the same weight in the final ratio.
        correct += nd.sum(output.argmax(axis=1) == label).asscalar()
        seen += label.shape[0]
    if not seen:
        return 0.0
    return correct / seen
# Computing softmax and cross-entropy separately can be numerically unstable,
# so the combined loss is used.
cross_loss = gn.loss.SoftmaxCrossEntropyLoss()
# Optimizer: plain SGD over all network parameters.
train_step = gn.Trainer(
    params=net.collect_params(),
    optimizer='sgd',
    optimizer_params={"learning_rate": 0.1},
)
'''---訓練---'''
# Train for a fixed number of epochs, reporting the mean batch loss, mean
# batch training accuracy and test accuracy after each epoch.
epochs = 20
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    for image, y in train_data:
        image, y = transform(image, y)  # type conversion and normalization
        with ag.record():
            output = net(image)
            loss = cross_loss(output, y)
        loss.backward()
        # Trainer.step rescales the gradient by 1/batch_size, so pass the
        # size of this batch; it may be smaller than the configured batch
        # size on a final partial batch.
        train_step.step(image.shape[0])
        train_loss += nd.mean(loss).asscalar()
        train_acc += accuracy(output, y)
    test_acc = evaluate_accuracy(test_data, net)
    print("Epoch %d, Loss:%f, Train acc:%f, Test acc:%f"
          % (epoch, train_loss / len(train_data), train_acc / len(train_data), test_acc))
'''----預測-------'''
# Once training has finished, the network can be used to predict samples.
image_10, label_10 = mnist_test[0:10]  # take the first 10 test samples
show_image(image_10)
print("真實樣本標籤:", label_10)
print("真實數字標籤對應的服飾名:", get_fashion_mnist_labels(label_10))
# Apply the same conversion as during training before running the network.
image_10, label_10 = transform(image_10, label_10)
predict_label = nd.argmax(net(image_10), axis=1)
print("預測樣本標籤:", predict_label.astype("int8"))
print("預測數字標籤對應的服飾名:", get_fashion_mnist_labels(predict_label.asnumpy()))