LSTM 預測各國人均GDP

利用 pandas_datareader 從世界銀行(World Bank)下載的人均GDP數據(指標 NY.GDP.PCAP.KD),配合 PyTorch 中的 LSTM 簡單地做一下預測:用1970年到2000年的數據訓練,再對2001年到2016年進行預測,並與實際數據比較。

from pandas_datareader import wb
import torch.nn
import torch
import torch.optim
import csv
from IPython.display import display
import pandas as pd 
import numpy
import matplotlib.pyplot as plt
class Net(torch.nn.Module):
    """Single-layer LSTM followed by a linear head producing one value per step.

    ``forward`` expects a 2-D tensor laid out as (seq_len, batch): the LSTM is
    built with the default ``batch_first=False``, so dim 0 is time. A trailing
    feature axis of size 1 is added before the LSTM. For that reason
    ``input_size`` must be 1 for ``forward`` to work.

    Args:
        input_size: number of features per time step fed to the LSTM.
        hidden_size: size of the LSTM hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        # Registration order (rnn, then fc) is kept so parameter init order is unchanged.
        self.rnn = torch.nn.LSTM(input_size, hidden_size)
        self.fc = torch.nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map a (seq_len, batch) tensor to a (seq_len, batch) tensor of predictions."""
        features = x.unsqueeze(2)            # (seq_len, batch) -> (seq_len, batch, 1)
        hidden_states, _ = self.rnn(features)
        projected = self.fc(hidden_states)   # (seq_len, batch, 1)
        return projected.select(2, 0)        # drop the size-1 output axis

# Country codes passed to the World Bank download below.
countries = ['BR', 'CA', 'CN', 'FR', 'DE', 'IN', 'IL', 'JP', 'SA', 'GB', 'US']
# GDP per capita (indicator NY.GDP.PCAP.KD) for 1970-2016.
# NOTE(review): this performs a network request to the World Bank API.
dat = wb.download(indicator='NY.GDP.PCAP.KD', country=countries, start=1970, end=2016)
# Reshape to one row per year and one column per country.
df = dat.unstack().T
df.index = df.index.droplevel(0)  # drop the indicator level, keep the year
year_num, sample_num = df.shape

countries = df.columns  # country names as returned by the download
years = df.index

net = Net(input_size=1, hidden_size=5)

print(net)

# Normalise each country's series by its own year-2000 value, so every
# series equals 1.0 in 2000 and all countries share a comparable scale.
df_scaled = df / df.loc['2000']

# Year labels are strings, so these comparisons are lexicographic; that is
# correct here because every label is a four-digit year (1970-2016).
train_seq_len = int(((years >= '1971') & (years <= '2000')).sum())
test_seq_len = int((years > '2000').sum())

print('訓練集長度 = {},測試集長度={}'.format(train_seq_len, test_seq_len))

# One-step-ahead setup: the value for year t is the input and the value for
# year t+1 is the target. Both tensors are (num_years - 1, num_countries);
# Net treats dim 0 as time and dim 1 (countries) as the batch.
inputs = torch.tensor(df_scaled.iloc[:-1].values, dtype=torch.float32)
labels = torch.tensor(df_scaled.iloc[1:].values, dtype=torch.float32)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(net.parameters())
for step in range(10001):
    # Update with the loss computed in the previous iteration. Step 0 only
    # measures the untrained model, so 10000 updates are made in total and
    # the loss printed at step k is the loss after k updates.
    if step:
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    # The LSTM runs over the whole sequence, but it is unidirectional, so the
    # outputs for training years depend only on inputs up to those years.
    preds = net(inputs)

    # First train_seq_len outputs: targets for the training years (1971-2000).
    train_preds = preds[:train_seq_len]
    train_labels = labels[:train_seq_len]
    train_loss = criterion(train_preds, train_labels)

    # Last test_seq_len outputs: targets for 2001-2016. This must be a slice
    # (note the colon); plain indexing [-test_seq_len] selects a single year.
    # Detached because the test loss is only reported, never backpropagated.
    test_preds = preds[-test_seq_len:].detach()
    test_labels = labels[-test_seq_len:]
    test_loss = criterion(test_preds, test_labels)

    if step % 500 == 0:
        print('第{}次迭代:loss(訓練集)={},loss(測試集)={}'.format(step, train_loss.item(), test_loss.item()))

# Run the trained model once more without tracking gradients, undo the
# year-2000 scaling, and show the relative error for 2001-2016.
with torch.no_grad():
    preds = net(inputs)
df_pred_scaled = pd.DataFrame(preds.numpy(), index=years[1:], columns=df.columns)
df_pred = df_pred_scaled * df.loc['2000']
display((df_pred.loc['2001':] - df['2001':]) / df['2001':])

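輸出(注:以下「測試集」loss 由 `preds[-test_seq_len]` / `labels[-test_seq_len]` 計算;由於缺少冒號,實際只比較了2001年一年的數據,而不是2001–2016年全部數據)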

Net(
  (rnn): LSTM(1, 5)
  (fc): Linear(in_features=5, out_features=1, bias=True)
)
訓練集長度 = 30,測試集長度=16
第0次迭代:loss(訓練集)=0.5748428106307983,loss(測試集)=0.9173784852027893
第500次迭代:loss(訓練集)=0.051534395664930344,loss(測試集)=0.017740309238433838
第1000次迭代:loss(訓練集)=0.012391290627419949,loss(測試集)=0.0012724784901365638
第1500次迭代:loss(訓練集)=0.004916149191558361,loss(測試集)=0.000621090061031282
第2000次迭代:loss(訓練集)=0.003049308666959405,loss(測試集)=0.0010105886030942202
第2500次迭代:loss(訓練集)=0.0023313837591558695,loss(測試集)=0.0013889336260035634
第3000次迭代:loss(訓練集)=0.001878467039205134,loss(測試集)=0.001620301860384643
第3500次迭代:loss(訓練集)=0.00158506422303617,loss(測試集)=0.0016249381005764008
第4000次迭代:loss(訓練集)=0.0013669944601133466,loss(測試集)=0.0014133700169622898
第4500次迭代:loss(訓練集)=0.001150963595137,loss(測試集)=0.0011018009390681982
第5000次迭代:loss(訓練集)=0.0009576826705597341,loss(測試集)=0.0008135449606925249
第5500次迭代:loss(訓練集)=0.0008130817441269755,loss(測試集)=0.0006490239757113159
第6000次迭代:loss(訓練集)=0.0007054363377392292,loss(測試集)=0.0005351279396563768
第6500次迭代:loss(訓練集)=0.0006444878526963294,loss(測試集)=0.0004590977623593062
第7000次迭代:loss(訓練集)=0.0006134260329417884,loss(測試集)=0.00043717006337828934
第7500次迭代:loss(訓練集)=0.0005939950933679938,loss(測試集)=0.0004407825763337314
第8000次迭代:loss(訓練集)=0.0005697330925613642,loss(測試集)=0.00044946340494789183
第8500次迭代:loss(訓練集)=0.0005507085588760674,loss(測試集)=0.0004614938225131482
第9000次迭代:loss(訓練集)=0.0005342827062122524,loss(測試集)=0.0004656326782424003
第9500次迭代:loss(訓練集)=0.0005212834221310914,loss(測試集)=0.00045894659706391394
第10000次迭代:loss(訓練集)=0.0005115617532283068,loss(測試集)=0.00044287487980909646
country    Brazil    Canada     China    France   Germany     India    Israel  \
year
2001    -0.002144  0.000075 -0.022060 -0.008991 -0.014243 -0.018933  0.030185
2002    -0.017341 -0.024670 -0.052552 -0.010788 -0.005069 -0.024296  0.018412
2003    -0.002197 -0.022479 -0.095633 -0.013584 -0.002956 -0.080484  0.017084
2004    -0.052607 -0.042801 -0.144129 -0.035649 -0.023291 -0.093914 -0.025987
2005    -0.038266 -0.045562 -0.207391 -0.026568 -0.020576 -0.124260 -0.014739
2006    -0.057623 -0.050707 -0.270292 -0.036556 -0.048844 -0.150937 -0.027125
2007    -0.085434 -0.051373 -0.326426 -0.041486 -0.051449 -0.198268 -0.039535
2008    -0.102728 -0.050965 -0.331391 -0.032274 -0.046527 -0.200981 -0.032936
2009    -0.077138 -0.016878 -0.359433 -0.003720  0.008244 -0.278661 -0.024666
2010    -0.154166 -0.062304 -0.407067 -0.041994 -0.076402 -0.331562 -0.069956
2011    -0.153396 -0.066169 -0.437540 -0.045873 -0.102181 -0.348864 -0.077647
2012    -0.159459 -0.056279 -0.440933 -0.030385 -0.069733 -0.359078 -0.062049
2013    -0.174038 -0.059620 -0.435616 -0.027806 -0.066858 -0.393259 -0.082605
2014    -0.177628 -0.074853 -0.435659 -0.035818 -0.091849 -0.434230 -0.090312
2015    -0.151927 -0.072834 -0.453737 -0.042662 -0.105560 -0.471287 -0.099816
2016    -0.140150 -0.078354 -0.481240 -0.046119 -0.113200 -0.487618 -0.114316

country     Japan  Saudi Arabia  United Kingdom  United States
year
2001    -0.005066      0.050080       -0.019758       0.001910
2002     0.000141      0.076382       -0.027843      -0.014702
2003    -0.013398     -0.024966       -0.044373      -0.029728
2004    -0.029923     -0.027837       -0.045746      -0.043851
2005    -0.030170     -0.016146       -0.060815      -0.042710
2006    -0.032287     -0.004100       -0.065461      -0.045268
2007    -0.041139     -0.014459       -0.076835      -0.048820
2008    -0.025376     -0.061920       -0.060479      -0.040589
2009     0.017596      0.006841       -0.027039      -0.017508
2010    -0.055529     -0.043322       -0.063360      -0.054677
2011    -0.027475     -0.083412       -0.061569      -0.047642
2012    -0.040973     -0.067815       -0.058723      -0.051977
2013    -0.042360     -0.050125       -0.057920      -0.046485
2014    -0.044990     -0.074047       -0.074974      -0.064566
2015    -0.058768     -0.097769       -0.080282      -0.076972
2016    -0.065484     -0.090002       -0.089329      -0.077763

中國(China)的預測誤差明顯偏大:到2016年,預測值比實際值低約48%;印度(India)也類似,約低49%。這並不意外:模型只用1970年到2000年的數據訓練,而中國經濟在2000年以後才開始高速騰飛,訓練數據中沒有出現過這樣的增長,所以預測不準。

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章