從零開始深度學習 0612 —— PyTorch 入門之模型保存與加載 + 批訓練

----------------------------- Save and reload -------------------------

##################################################################################################################################

 

 


# Seeding is left off on purpose; uncomment to make the noise below repeatable.
# torch.manual_seed(1)

# Synthetic regression target: 100 evenly spaced inputs on [-1, 1] stored as a
# column vector, and a parabola x**2 perturbed by uniform noise in [0, 0.2).
# Plain tensors are enough here: autograd no longer needs Variable wrappers.
x = torch.linspace(-1, 1, 100).unsqueeze(1)            # shape (100, 1)
y = x ** 2 + 0.2 * torch.rand(x.size())                # shape (100, 1)


def save():
    """Train a small regression net (``net1``), plot its fit, and save it twice.

    Fits the module-level ``x``/``y`` data with 100 full-batch SGD steps
    (lr=0.5, MSE loss). Draws the fit into the first of three side-by-side
    panels (subplot 131 of figure 1), then writes:

    * ``./net.pkl``        -- the whole pickled module (``torch.save(net1)``);
    * ``./net_params.pkl`` -- only its ``state_dict()`` (the parameters).
    """
    net1 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = torch.nn.MSELoss()

    for _ in range(100):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The last `prediction` produced inside the loop was computed *before* the
    # final optimizer.step(), so it does not match the parameters that get
    # saved below. Re-evaluate with the trained weights so this panel agrees
    # with what the reloaded networks produce. No gradients are needed here.
    with torch.no_grad():
        prediction = net1(x)

    # plot result
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title('Net1')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.numpy(), 'r-', lw=5)

    # 2 ways to save the net
    torch.save(net1, './net.pkl')  # save entire net (pickled module object)
    torch.save(net1.state_dict(), './net_params.pkl')   # save only the parameters


def restore_net():
    """Reload the entire pickled network from ``./net.pkl`` as ``net2`` and plot it.

    Draws into subplot 132 (the middle of three side-by-side panels).

    NOTE: loading a whole module runs Python's unpickler on the file, which can
    execute arbitrary code. Only load ``./net.pkl`` files you created yourself.
    """
    import inspect  # local: only needed for the torch.load version check below

    # PyTorch >= 2.6 defaults torch.load(weights_only=True), which refuses to
    # unpickle a full nn.Module. Opt out explicitly where the keyword exists;
    # older releases that predate it load full objects by default.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    net2 = torch.load('./net.pkl', **load_kwargs)

    with torch.no_grad():  # inference only; no autograd graph needed
        prediction = net2(x)

    # plot result
    plt.subplot(132)
    plt.title('Net2')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.numpy(), 'r-', lw=5)


def restore_params():
    """Rebuild the network as ``net3``, load saved parameters into it, and plot.

    Only tensors are read from ``./net_params.pkl``. The layer structure is
    re-declared here and must match the architecture that produced the file.
    Draws into subplot 133 (the last of three panels) and then shows the figure.
    """
    import inspect  # local: only needed for the torch.load version check below

    net3 = torch.nn.Sequential(
        torch.nn.Linear(1, 10),
        torch.nn.ReLU(),
        torch.nn.Linear(10, 1)
    )

    # A state_dict contains only tensors, so it can be read with the restricted
    # weights-only unpickler. Request it explicitly where torch.load supports the
    # keyword. It is the default from PyTorch 2.6; being explicit also avoids
    # unrestricted unpickling on older versions that support the flag.
    load_kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = True

    # copy the saved parameters into net3
    net3.load_state_dict(torch.load('./net_params.pkl', **load_kwargs))

    with torch.no_grad():  # inference only; no autograd graph needed
        prediction = net3(x)

    # plot result
    plt.subplot(133)
    plt.title('Net3')
    plt.scatter(x.data.numpy(), y.data.numpy())
    plt.plot(x.data.numpy(), prediction.numpy(), 'r-', lw=5)
    plt.show()

# Run the demo in order: train and save net1, reload the whole pickled net
# (possibly the slower route), then rebuild the architecture and load only
# its parameters.
for _demo_step in (save, restore_net, restore_params):
    _demo_step()
del _demo_step

 

 

 

------------------------batch train---------------------------------------------

 

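for step, (batch_x, batch_y) in enumerate(loader):

注意：上面這個遍歷 DataLoader 的循環不能直接寫在模組頂層，一定要包在函數裡，並在 `if __name__ == '__main__':` 之下調用。原因是 `num_workers=2` 會啟動子進程來讀取數據，而以 spawn 方式（例如 Windows 上）啟動的子進程會重新導入主模組；若循環寫在頂層，就會在子進程中被再次執行並報錯。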


torch.manual_seed(1)    # fixes the shuffle order so every run prints the same batches

BATCH_SIZE = 5          # alternative tried in the original: 8

# Ten paired samples: x counts up 1..10 while y counts down 10..1.
x = torch.linspace(start=1, end=10, steps=10)
y = torch.linspace(start=10, end=1, steps=10)

# Pair x[i] with y[i], then serve the pairs in shuffled mini-batches of
# BATCH_SIZE, read by two worker subprocesses.
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)





def show_batch(epochs=3, data_loader=None):
    """Print every mini-batch produced by a loader, for ``epochs`` passes.

    Args:
        epochs: number of full passes over the data. Defaults to 3, the value
            the original hard-coded.
        data_loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs.
            Defaults to the module-level ``loader``.

    Call this from inside the ``if __name__ == '__main__':`` guard rather than
    at import time. The module-level ``loader`` is configured with worker
    subprocesses (``num_workers=2``).
    """
    if data_loader is None:
        data_loader = loader

    for epoch in range(epochs):   # train entire dataset `epochs` times
        for step, (batch_x, batch_y) in enumerate(data_loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())





# Entry-point guard. The DataLoader above is built with num_workers=2, so
# iterating it starts worker subprocesses. Workers started with the 'spawn'
# method (e.g. on Windows) re-import this module, and the guard stops them
# from calling show_batch() themselves.
if __name__ == '__main__':

    show_batch()

 

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章