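A simple forecasting exercise using World Bank data loaded through pandas_datareader and the LSTM in PyTorch. The series is GDP per capita (indicator NY.GDP.PCAP.KD) for 11 countries: the model is trained on 1970 to 2000 and then used to predict 2001 to 2016.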
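Before the full script, here is a minimal sketch of how the series is turned into input/label pairs: each year's value is the input and the following year's value is the label, and the split is made on the label's year. The toy series and the names `series`, `label_years`, `train_len` and `test_len` are assumptions for illustration only; the real script works on a DataFrame with one column per country.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for one country's series, indexed by year strings
# the way the downloaded data is (the values are made up).
years = pd.Index([str(y) for y in range(1970, 2017)], name='year')
series = pd.Series(np.linspace(1.0, 2.0, len(years)), index=years)

# Inputs cover 1970..2015; labels are the same series shifted by one year
# (1971..2016), so every input is paired with the value of the next year.
inputs = series.iloc[:-1]
labels = series.iloc[1:]

# Split on the year of the label: 1971..2000 for training, 2001..2016 for testing.
label_years = labels.index
train_len = int(((label_years >= '1971') & (label_years <= '2000')).sum())
test_len = int((label_years > '2000').sum())

print(len(inputs), len(labels))
print(train_len, test_len)
print(inputs.index[0], '->', labels.index[0])
```

Note that with this setup every test prediction is one step ahead: the model still sees the actual value of the previous year, including years after 2000.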
from pandas_datareader import wb
import torch.nn
import torch
import torch.optim
import csv
from IPython.display import display
import pandas as pd
import numpy
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
    def __init__(self, input_size, hidden_size):
        super(Net, self).__init__()
        self.rnn = torch.nn.LSTM(input_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        x = x[:, :, None]    # (seq_len, batch) -> (seq_len, batch, 1)
        x, _ = self.rnn(x)   # (seq_len, batch, hidden_size)
        x = self.fc(x)       # (seq_len, batch, 1)
        x = x[:, :, 0]       # (seq_len, batch)
        return x
countries=['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL','JP', 'SA', 'GB', 'US']
dat = wb.download(indicator='NY.GDP.PCAP.KD',country=countries,start=1970,end=2016)
df = dat.unstack().T
df.index = df.index.droplevel(0)
year_num, sample_num = df.shape
countries = df.columns
years = df.index
# Optionally save the downloaded data to CSV:
# df.to_csv('F:/tracker_programe/lstm_test/1970_2016.csv')
# print(df)
net = Net(input_size=1,hidden_size=5)
print(net)
df_scaled = df / df.loc['2000']  # scale each country by its own value in 2000
train_seq_len = sum((years >= '1971') & (years <= '2000'))
test_seq_len = sum(years > '2000')
print('training set length = {}, test set length = {}'.format(train_seq_len, test_seq_len))
inputs=torch.tensor(df_scaled.iloc[:-1].values,dtype=torch.float32)
labels=torch.tensor(df_scaled.iloc[1:].values,dtype=torch.float32)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters())
for step in range(10001):
    if step:
        # Apply the gradients of the loss computed in the previous iteration,
        # so the losses printed at step 0 are those of the untrained network.
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    preds = net(inputs)
    train_preds = preds[:train_seq_len]
    train_labels = labels[:train_seq_len]
    train_loss = criterion(train_preds, train_labels)
    # Slice the last test_seq_len rows (2001 to 2016), not a single row
    test_preds = preds[-test_seq_len:]
    test_labels = labels[-test_seq_len:]
    test_loss = criterion(test_preds, test_labels)
    if step % 500 == 0:
        print('step {}: loss (train) = {}, loss (test) = {}'.format(step, train_loss, test_loss))
preds = net(inputs)
df_pred_scaled = pd.DataFrame(preds.detach().numpy(), index=years[1:], columns=df.columns)
df_pred = df_pred_scaled * df.loc['2000']  # undo the scaling
# Relative error of the predictions for 2001 to 2016
display((df_pred.loc['2001':] - df.loc['2001':]) / df.loc['2001':])
Output
Net(
  (rnn): LSTM(1, 5)
  (fc): Linear(in_features=5, out_features=1, bias=True)
)
training set length = 30, test set length = 16
step 0: loss (train) = 0.5748428106307983, loss (test) = 0.9173784852027893
step 500: loss (train) = 0.051534395664930344, loss (test) = 0.017740309238433838
step 1000: loss (train) = 0.012391290627419949, loss (test) = 0.0012724784901365638
step 1500: loss (train) = 0.004916149191558361, loss (test) = 0.000621090061031282
step 2000: loss (train) = 0.003049308666959405, loss (test) = 0.0010105886030942202
step 2500: loss (train) = 0.0023313837591558695, loss (test) = 0.0013889336260035634
step 3000: loss (train) = 0.001878467039205134, loss (test) = 0.001620301860384643
step 3500: loss (train) = 0.00158506422303617, loss (test) = 0.0016249381005764008
step 4000: loss (train) = 0.0013669944601133466, loss (test) = 0.0014133700169622898
step 4500: loss (train) = 0.001150963595137, loss (test) = 0.0011018009390681982
step 5000: loss (train) = 0.0009576826705597341, loss (test) = 0.0008135449606925249
step 5500: loss (train) = 0.0008130817441269755, loss (test) = 0.0006490239757113159
step 6000: loss (train) = 0.0007054363377392292, loss (test) = 0.0005351279396563768
step 6500: loss (train) = 0.0006444878526963294, loss (test) = 0.0004590977623593062
step 7000: loss (train) = 0.0006134260329417884, loss (test) = 0.00043717006337828934
step 7500: loss (train) = 0.0005939950933679938, loss (test) = 0.0004407825763337314
step 8000: loss (train) = 0.0005697330925613642, loss (test) = 0.00044946340494789183
step 8500: loss (train) = 0.0005507085588760674, loss (test) = 0.0004614938225131482
step 9000: loss (train) = 0.0005342827062122524, loss (test) = 0.0004656326782424003
step 9500: loss (train) = 0.0005212834221310914, loss (test) = 0.00045894659706391394
step 10000: loss (train) = 0.0005115617532283068, loss (test) = 0.00044287487980909646
country Brazil Canada China France Germany India Israel \
year
2001 -0.002144 0.000075 -0.022060 -0.008991 -0.014243 -0.018933 0.030185
2002 -0.017341 -0.024670 -0.052552 -0.010788 -0.005069 -0.024296 0.018412
2003 -0.002197 -0.022479 -0.095633 -0.013584 -0.002956 -0.080484 0.017084
2004 -0.052607 -0.042801 -0.144129 -0.035649 -0.023291 -0.093914 -0.025987
2005 -0.038266 -0.045562 -0.207391 -0.026568 -0.020576 -0.124260 -0.014739
2006 -0.057623 -0.050707 -0.270292 -0.036556 -0.048844 -0.150937 -0.027125
2007 -0.085434 -0.051373 -0.326426 -0.041486 -0.051449 -0.198268 -0.039535
2008 -0.102728 -0.050965 -0.331391 -0.032274 -0.046527 -0.200981 -0.032936
2009 -0.077138 -0.016878 -0.359433 -0.003720 0.008244 -0.278661 -0.024666
2010 -0.154166 -0.062304 -0.407067 -0.041994 -0.076402 -0.331562 -0.069956
2011 -0.153396 -0.066169 -0.437540 -0.045873 -0.102181 -0.348864 -0.077647
2012 -0.159459 -0.056279 -0.440933 -0.030385 -0.069733 -0.359078 -0.062049
2013 -0.174038 -0.059620 -0.435616 -0.027806 -0.066858 -0.393259 -0.082605
2014 -0.177628 -0.074853 -0.435659 -0.035818 -0.091849 -0.434230 -0.090312
2015 -0.151927 -0.072834 -0.453737 -0.042662 -0.105560 -0.471287 -0.099816
2016 -0.140150 -0.078354 -0.481240 -0.046119 -0.113200 -0.487618 -0.114316
country Japan Saudi Arabia United Kingdom United States
year
2001 -0.005066 0.050080 -0.019758 0.001910
2002 0.000141 0.076382 -0.027843 -0.014702
2003 -0.013398 -0.024966 -0.044373 -0.029728
2004 -0.029923 -0.027837 -0.045746 -0.043851
2005 -0.030170 -0.016146 -0.060815 -0.042710
2006 -0.032287 -0.004100 -0.065461 -0.045268
2007 -0.041139 -0.014459 -0.076835 -0.048820
2008 -0.025376 -0.061920 -0.060479 -0.040589
2009 0.017596 0.006841 -0.027039 -0.017508
2010 -0.055529 -0.043322 -0.063360 -0.054677
2011 -0.027475 -0.083412 -0.061569 -0.047642
2012 -0.040973 -0.067815 -0.058723 -0.051977
2013 -0.042360 -0.050125 -0.057920 -0.046485
2014 -0.044990 -0.074047 -0.074974 -0.064566
2015 -0.058768 -0.097769 -0.080282 -0.076972
2016 -0.065484 -0.090002 -0.089329 -0.077763
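The estimates for China are clearly off: the prediction falls further and further below the actual value, reaching about 48% too low by 2016 (India shows a similar gap). This is not surprising. In 2000 China's economy was still not very developed, and it only took off afterwards, so a model trained on data up to 2000 can be expected to underestimate it.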
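To put a number on "clearly off", one option is to average the absolute relative error per country and sort the result. The sketch below is only an illustration: it uses three countries and two years (2001 and 2016) copied from the table above rather than the full result, and the names `rel_err`, `mean_abs_err` and `worst` are made up for the example.

```python
import pandas as pd

# Relative errors copied from the output table above (2001 and 2016 only).
rel_err = pd.DataFrame(
    {'China': [-0.022060, -0.481240],
     'France': [-0.008991, -0.046119],
     'United States': [0.001910, -0.077763]},
    index=pd.Index(['2001', '2016'], name='year'))

# Mean absolute relative error per country, largest first.
mean_abs_err = rel_err.abs().mean().sort_values(ascending=False)
print(mean_abs_err)

# Country with the largest average error in this small sample.
worst = mean_abs_err.index[0]
print('largest mean absolute relative error:', worst)
```

In the full script the same two lines could be applied to the DataFrame passed to `display`, which would cover all 11 countries and all 16 test years.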