pytorch(二):分類

import torch
import torch.nn.functional as f
from torch.autograd import Variable
import matplotlib.pyplot as plt


# 建造數據集
data = torch.ones((100, 2))
x0 = torch.normal(2*data, 1)
y0 = torch.zeros(100)  # y0是標籤  shape(100,),是一維
x1 = torch.normal(-2*data, 1)
y1 = torch.ones(100)  # y1也是標籤 shape(100,),是一維
x = torch.cat((x0, x1), 0).type(torch.FloatTensor)  # 參數0表示維度,在縱向方向將x0,x1合併,合併後shape(200, 2))
y = torch.cat((y0, y1), 0).type(torch.LongTensor)  # 標籤是0或1,類型爲整數,LongTensor = 64-bit integer,
x, y = Variable(x), Variable(y)  # 訓練神經網絡只能接受變量輸入,故要把x, y轉化爲變量
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1],  # 這兩個參數分別代表x,y軸座標
            c=y.data.numpy(), s=100, cmap='RdYlGn')  # c爲color,y有兩種標籤,代表兩種顏色的點,'RdYlGn'紅色和綠色
plt.show()


# 建造神經網絡模型
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.out = torch.nn.Linear(n_hidden, n_output)
        
    def forward(self, x):
        x = f.relu(self.hidden(x))
        y = self.out(x)
        return y


# 定義神經網絡
net = Net(n_feature=2, n_hidden=10, n_output=2) 
# n_output=2,因爲它返回一個元素爲2的列表。[0, 1]表示學習到的內容爲標籤1,[1, 0]表示學習到的內容爲標籤0。
print(net)


# 訓練神經網絡模型並將訓練過程可視化
optimizer = torch.optim.SGD(net.parameters(), lr=0.02)
loss_func = torch.nn.CrossEntropyLoss()
plt.ion()
for i in range(100):
    out = net(x)
    loss = loss_func(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # 繪圖
    if i % 2 == 0:
        plt.cla()
        # torch.max(a,1) 返回每一行中最大值的那個元素,且返回其索引(返回最大元素在這一行的列索引
        # f.softmax(out)是將out的內容以概率表示。
        # torch.max()返回的是兩個Variable,第一個Variable存的是最大值,第二個存的是其對應的位置索引index。這裏我們想要得到的是索引,所以後面用[1]。
        prediction = torch.max(f.softmax(out), 1)[1]
        pred_y = prediction.data.numpy().squeeze()
        target_y = y.data.numpy()
        plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=pred_y, s=100, cmap='RdYlGn')
        accuracy = sum(pred_y == target_y)/200
        plt.text(1.5, -4, 'accuracy=%.2f'%accuracy, fontdict={'size':10, 'color':'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章