GAN的基本結構
GAN的主要結構包括一個生成器G(Generator)和一個判別器D(Discriminator)
GAN 利用“對抗過程”同時訓練兩個神經網絡,這兩個網絡互相博弈,直至達到一種理想的平衡狀態;可以把它們類比爲罪犯(生成器)和警察(鑑別器):一方不斷改進僞造手段,另一方不斷提高辨別能力。其中一個神經網絡叫做生成器網絡 G(z),它以隨機噪聲 z 作爲輸入,生成和已有數據集非常接近的數據,也就是學習真實數據的分佈;另一個神經網絡叫鑑別器網絡 D(x),它以生成的數據和真實數據作爲輸入,嘗試鑑別出哪些是生成的數據,哪些是真實數據。鑑別器的核心是實現二元分類,輸出的結果是輸入數據來自真實數據集(而不是生成的虛假數據)的概率。
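整個過程的目標函數從正式意義上可以寫爲:

$$\min_G \max_D V(D,G)=\mathbb{E}_{x\sim p_{\text{data}}(x)}\big[\log D(x)\big]+\mathbb{E}_{z\sim p_z(z)}\big[\log\big(1-D(G(z))\big)\big]$$

即鑑別器 D 要最大化 V(D,G),生成器 G 要最小化 V(D,G)。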
前面所說的 GAN 最終能達到一種理想的平衡狀態,是指生成器應該能模擬真實的數據,鑑別器輸出的概率應該爲 0.5, 即生成的數據和真實數據一致。也就是說,它不確定來自生成器的新數據是真實還是虛假,二者的概率相等(這樣熵最大)。
這裏,使用GAN生成正弦信號,下面給出代碼:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
# Uncomment both lines for reproducible runs:
# torch.manual_seed(1)
# np.random.seed(1)

# Hyper-parameters
BATCH_SIZE = 64
LR_G = 0.0001  # generator learning rate
LR_D = 0.0001  # discriminator learning rate
N_IDEAS = 8  # length of the random noise vector fed to the generator
ART_COMPONENTS = 15  # number of sample points in every curve
# BATCH_SIZE identical rows, each holding ART_COMPONENTS x-values evenly spaced over [-1, 1].
PAINT_POINTS = np.tile(np.linspace(-1, 1, ART_COMPONENTS), (BATCH_SIZE, 1))
def artist_works():  # painting from the famous artist (real target)
    """Return one batch of "real" curves: sin(pi * x) on PAINT_POINTS plus small noise.

    A single noise row of shape (1, ART_COMPONENTS) is drawn and broadcast over
    the batch, so every row of the returned batch is the same noisy curve.
    The result is a float32 tensor with the shape of PAINT_POINTS.
    """
    noise = 0.02 * np.random.randn(1, ART_COMPONENTS)
    real = np.sin(np.pi * PAINT_POINTS) + noise
    return torch.from_numpy(real).float()
# G = nn.Sequential( # Generator
# nn.Linear(N_IDEAS, 128), # random ideas (could from normal distribution)
# nn.ReLU(),
# nn.Linear(128, ART_COMPONENTS), # making a painting from these random ideas
# )
#
# D = nn.Sequential( # Discriminator
# nn.Linear(ART_COMPONENTS, 128), # receive art work either from the famous artist or a newbie like G
# nn.ReLU(),
# nn.Linear(128, 1),
# nn.Sigmoid(), # tell the probability that the art work is made by artist
# )
class Ge(nn.Module):
    """Generator: map a noise vector to a curve of ART_COMPONENTS points.

    Architecture: Linear(n_ideas, 128) -> ReLU -> Linear(128, art_components).

    Args:
        n_ideas: length of the input noise vector; defaults to the global N_IDEAS.
        art_components: number of output points; defaults to the global ART_COMPONENTS.
    """

    def __init__(self, n_ideas=None, art_components=None):
        super(Ge, self).__init__()
        # Resolve defaults at construction time so a plain Ge() keeps using the globals.
        if n_ideas is None:
            n_ideas = N_IDEAS
        if art_components is None:
            art_components = ART_COMPONENTS
        self.fc1 = nn.Linear(n_ideas, 128)
        self.fc2 = nn.Linear(128, art_components)

    def forward(self, x):
        """Return fc2(relu(fc1(x))); the output layer is linear (unbounded)."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)
class De(nn.Module):
    """Discriminator: map a curve of ART_COMPONENTS points to P(curve is real).

    Architecture: Linear(art_components, 128) -> ReLU -> Linear(128, 1) -> sigmoid.

    Args:
        art_components: input length; defaults to the global ART_COMPONENTS.
    """

    def __init__(self, art_components=None):
        super(De, self).__init__()
        # Resolve the default at construction time so a plain De() keeps using the global.
        if art_components is None:
            art_components = ART_COMPONENTS
        self.fc1 = nn.Linear(art_components, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in (0, 1)."""
        x = F.relu(self.fc1(x))
        # torch.sigmoid instead of the deprecated F.sigmoid alias.
        return torch.sigmoid(self.fc2(x))
G = Ge()
D = De()
opt_D = torch.optim.Adam(D.parameters(), lr=LR_D)
opt_G = torch.optim.Adam(G.parameters(), lr=LR_G)

plt.ion()  # interactive mode so the figure refreshes while training
# Plain Python floats: storing the loss tensors would keep every step's
# autograd graph alive for the whole run.
D_loss_history = []
G_loss_history = []

for step in range(10000):
    artist_paintings = artist_works()  # real painting from artist
    G_ideas = torch.randn(BATCH_SIZE, N_IDEAS)  # random ideas
    G_paintings = G(G_ideas)  # fake painting from G (random ideas)

    # Discriminator step: push D(real) towards 1 and D(fake) towards 0.
    # The fakes are detached so this backward pass does not reach G.
    prob_artist0 = D(artist_paintings)  # D tries to increase this prob
    prob_artist1 = D(G_paintings.detach())  # D tries to reduce this prob
    D_loss = -torch.mean(torch.log(prob_artist0) + torch.log(1. - prob_artist1))
    opt_D.zero_grad()
    D_loss.backward()
    opt_D.step()

    # Generator step: re-score the fakes with the updated D. Back-propagating
    # through the pre-step graph instead fails on PyTorch >= 1.5, because
    # opt_D.step() modified D's weights in place.
    prob_artist1 = D(G_paintings)
    G_loss = torch.mean(torch.log(1. - prob_artist1))
    opt_G.zero_grad()
    G_loss.backward()
    opt_G.step()

    D_loss_history.append(D_loss.item())
    G_loss_history.append(G_loss.item())

    if step % 1000 == 0:  # plotting
        plt.cla()
        plt.plot(PAINT_POINTS[0], G_paintings.detach().numpy()[0], c='r', lw=3, label='Generated painting', )
        plt.plot(PAINT_POINTS[0], np.sin(PAINT_POINTS[0] * np.pi), c='b', lw=3, label='upper bound')
        plt.text(-1, 0.75, 'D accuracy=%.2f (0.5 for D to converge)' % prob_artist0.detach().mean().item(),
                 fontdict={'size': 13})
        plt.text(-1, 0.5, 'D score= %.2f (-1.38 for G to converge)' % -D_loss.item(), fontdict={'size': 13})
        plt.ylim((-1, 1))
        plt.legend(loc='upper right', fontsize=10)
        plt.draw()
        plt.pause(0.01)

# Uncomment to keep the final figure open after training:
# plt.ioff()
# plt.show()
上面代碼中,artist_works() 函數用來產生作爲真實數據的正弦信號 sin(πx),並加上少量隨機噪聲:
def artist_works():  # painting from the famous artist (real target)
    """Return one batch of "real" curves: sin(pi * x) on PAINT_POINTS plus small noise.

    A single noise row of shape (1, ART_COMPONENTS) is drawn and broadcast over
    the batch, so every row of the returned batch is the same noisy curve.
    The result is a float32 tensor with the shape of PAINT_POINTS.
    """
    noise = 0.02 * np.random.randn(1, ART_COMPONENTS)
    real = np.sin(np.pi * PAINT_POINTS) + noise
    return torch.from_numpy(real).float()
下面這段代碼構建生成器與判別器網絡,兩者都用 PyTorch 實現。
class Ge(nn.Module):
    """Generator: map a noise vector to a curve of ART_COMPONENTS points.

    Architecture: Linear(n_ideas, 128) -> ReLU -> Linear(128, art_components).

    Args:
        n_ideas: length of the input noise vector; defaults to the global N_IDEAS.
        art_components: number of output points; defaults to the global ART_COMPONENTS.
    """

    def __init__(self, n_ideas=None, art_components=None):
        super(Ge, self).__init__()
        # Resolve defaults at construction time so a plain Ge() keeps using the globals.
        if n_ideas is None:
            n_ideas = N_IDEAS
        if art_components is None:
            art_components = ART_COMPONENTS
        self.fc1 = nn.Linear(n_ideas, 128)
        self.fc2 = nn.Linear(128, art_components)

    def forward(self, x):
        """Return fc2(relu(fc1(x))); the output layer is linear (unbounded)."""
        x = F.relu(self.fc1(x))
        return self.fc2(x)
class De(nn.Module):
    """Discriminator: map a curve of ART_COMPONENTS points to P(curve is real).

    Architecture: Linear(art_components, 128) -> ReLU -> Linear(128, 1) -> sigmoid.

    Args:
        art_components: input length; defaults to the global ART_COMPONENTS.
    """

    def __init__(self, art_components=None):
        super(De, self).__init__()
        # Resolve the default at construction time so a plain De() keeps using the global.
        if art_components is None:
            art_components = ART_COMPONENTS
        self.fc1 = nn.Linear(art_components, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """Return a (batch, 1) tensor of probabilities in (0, 1)."""
        x = F.relu(self.fc1(x))
        # torch.sigmoid instead of the deprecated F.sigmoid alias.
        return torch.sigmoid(self.fc2(x))
下面這段代碼爲生成器和判別器的損失函數:
# Discriminator loss: -[log D(real) + log(1 - D(fake))], averaged over the batch.
D_loss = -(torch.log(prob_artist0) + torch.log(1. - prob_artist1)).mean()
# Generator loss (minimax form): mean of log(1 - D(fake)), which G minimises.
G_loss = torch.log(1. - prob_artist1).mean()
實現效果如下:第一幅圖是訓練剛開始時由隨機噪聲輸入生成的曲線,第二幅圖是訓練到鑑別器輸出的概率接近 0.5 時生成的曲線,可以看出效果很好:
有了上面GAN的經驗,接下來介紹生成對抗模仿學習:
在這裏,整個工程由兩個文件組成:env_OppositeV4.py 用於構建環境,GAIL_OppositeV4.py 用於運行程序。
首先介紹env_OppositeV4.py代碼構建環境,先看一個構建的環境效果圖:
圖中紅色的部分爲起點,綠色部分爲終點,下面給出env_OppositeV4.py代碼:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import random
import cv2
class EnvOppositeV4(object):
    """Single-agent grid world split by a vertical wall that has two gaps.

    The map is ``size`` x ``size``.  Occupancy 1 marks a blocked cell: the
    border, the middle column ``(size - 1) // 2`` (except its cells in rows 1
    and ``size - 2``), and the agent's own cell.  The agent starts at
    ``[(size - 1) // 2, 1]`` and the goal is ``[(size - 1) // 2, size - 2]``.
    """

    # (action id, (row delta, col delta)); any other action id leaves the agent in place.
    _MOVES = ((0, (-1, 0)),  # up
              (1, (1, 0)),   # down
              (2, (0, -1)),  # left
              (3, (0, 1)))   # right

    def __init__(self, size):
        self.map_size = size
        mid = (size - 1) // 2
        self.raw_occupancy = np.zeros((size, size))
        self.raw_occupancy[0, :] = 1
        self.raw_occupancy[size - 1, :] = 1
        self.raw_occupancy[:, 0] = 1
        self.raw_occupancy[:, size - 1] = 1
        self.raw_occupancy[:, mid] = 1  # dividing wall ...
        self.raw_occupancy[1, mid] = 0  # ... with a gap near the top
        self.raw_occupancy[size - 2, mid] = 0  # ... and one near the bottom
        self.reset()

    def reset(self):
        """Restore the initial map and put the agent back on its start cell."""
        mid = (self.map_size - 1) // 2
        self.occupancy = self.raw_occupancy.copy()
        self.agt1_pos = [mid, 1]
        self.goal1_pos = [mid, self.map_size - 2]
        # The agent blocks its own cell; step() clears it when the agent leaves.
        self.occupancy[self.agt1_pos[0]][self.agt1_pos[1]] = 1

    def get_state(self):
        """Return the agent position divided by map_size as a (1, 2) float array."""
        return np.array([[self.agt1_pos[0] / self.map_size,
                          self.agt1_pos[1] / self.map_size]])

    def step(self, action_list):
        """Move the agent according to ``action_list[0]``.

        0/1/2/3 move up/down/left/right by one cell unless the target cell is
        blocked; other values leave the agent where it is.

        Returns:
            (reward, done): (5, True) if the agent is on the goal after the
            move, otherwise (0, False).
        """
        action = action_list[0]
        for action_id, (d_row, d_col) in self._MOVES:
            if action == action_id:
                row, col = self.agt1_pos
                new_row, new_col = row + d_row, col + d_col
                if self.occupancy[new_row][new_col] != 1:  # if can move
                    # Update in place so existing references to agt1_pos stay valid.
                    self.agt1_pos[0] = new_row
                    self.agt1_pos[1] = new_col
                    self.occupancy[row][col] = 0
                    self.occupancy[new_row][new_col] = 1
                break
        done = self.agt1_pos == self.goal1_pos
        reward = 5 if done else 0
        return reward, done

    def get_global_obs(self):
        """Return a (map_size, map_size, 3) RGB image with values in {0, 1}.

        Free cells are white and blocked cells black.  The goal is green and the
        agent red; the agent is drawn last, so it stays visible on the goal.
        Before this, no cell was ever green, so render()'s green branch never
        fired and the goal was not shown.
        """
        obs = np.zeros((self.map_size, self.map_size, 3))
        obs[self.occupancy == 0] = 1.0
        obs[self.goal1_pos[0], self.goal1_pos[1]] = (0.0, 1.0, 0.0)
        obs[self.agt1_pos[0], self.agt1_pos[1]] = (1.0, 0.0, 0.0)
        return obs

    def render(self):
        """Show the map, each cell enlarged to 30x30 pixels, in an OpenCV window for 100 ms."""
        obs = self.get_global_obs()
        enlarge = 30
        # RGB cell value in obs -> BGR fill colour for cv2; white cells keep the white background.
        fills = {
            (0.0, 0.0, 0.0): (0, 0, 0),
            (1.0, 0.0, 0.0): (0, 0, 255),
            (0.0, 1.0, 0.0): (0, 255, 0),
            (0.0, 0.0, 1.0): (255, 0, 0),
        }
        new_obs = np.ones((self.map_size * enlarge, self.map_size * enlarge, 3))
        for i in range(self.map_size):
            for j in range(self.map_size):
                colour = fills.get(tuple(obs[i, j]))
                if colour is not None:
                    top_left = (j * enlarge, i * enlarge)
                    bottom_right = (j * enlarge + enlarge, i * enlarge + enlarge)
                    cv2.rectangle(new_obs, top_left, bottom_right, colour, -1)
        cv2.imshow('image', new_obs)
        cv2.waitKey(100)
上面代碼中,下面這段 __init__ 負責生成環境的地圖,其實就是生成環境的矩形框:值爲 1 的格子是牆,之後賦予黑色;值爲 0 的格子可以通行,賦予白色,這樣就構建出了上面的圖。這裏也計算了 agent 的起始位置與目標位置。
def __init__(self, size):
    """Build the wall layout and place the agent on its start cell.

    Walls (value 1): the whole border plus the middle column
    mid = int((size - 1) / 2), which keeps two gaps at rows 1 and size - 2.
    The agent starts at [mid, 1] and the goal is at [mid, size - 2].
    """
    self.map_size = size
    mid = int((size - 1) / 2)
    grid = np.zeros((size, size))
    grid[[0, size - 1], :] = 1  # top and bottom rows
    grid[:, [0, size - 1]] = 1  # left and right columns
    grid[:, mid] = 1  # dividing wall
    grid[[1, size - 2], mid] = 0  # the two gaps in it
    self.raw_occupancy = grid
    self.occupancy = grid.copy()
    self.agt1_pos = [mid, 1]
    self.goal1_pos = [mid, size - 2]
    # The agent blocks its own cell in the live occupancy map.
    self.occupancy[mid][1] = 1
通過下面代碼把值爲 1 的地方賦予黑色,把值爲 0 的地方賦予白色,並把 agent 所在的格子標爲紅色,結果如下圖。
def get_global_obs(self):
    """Return a (map_size, map_size, 3) RGB image with values in {0, 1}.

    Free cells (occupancy 0) are white and blocked cells black.  The goal is
    green and the agent red; the agent is drawn last, so it stays visible when
    standing on the goal.  Drawing the goal fixes the previous behaviour, where
    no cell was ever green and the goal was invisible in render().
    """
    obs = np.zeros((self.map_size, self.map_size, 3))
    obs[np.asarray(self.occupancy) == 0] = 1.0
    obs[self.goal1_pos[0], self.goal1_pos[1]] = (0.0, 1.0, 0.0)
    obs[self.agt1_pos[0], self.agt1_pos[1]] = (1.0, 0.0, 0.0)
    return obs
下面的代碼把每個格子放大 30 倍,用 OpenCV 畫出並顯示整個地圖。
def render(self):
    """Show the map, each cell enlarged to 30x30 pixels, in an OpenCV window for 100 ms."""
    obs = self.get_global_obs()
    enlarge = 30
    # RGB cell value in obs -> BGR fill colour for cv2; white cells keep the white background.
    fills = {
        (0.0, 0.0, 0.0): (0, 0, 0),
        (1.0, 0.0, 0.0): (0, 0, 255),
        (0.0, 1.0, 0.0): (0, 255, 0),
        (0.0, 0.0, 1.0): (255, 0, 0),
    }
    new_obs = np.ones((self.map_size * enlarge, self.map_size * enlarge, 3))
    for i in range(self.map_size):
        for j in range(self.map_size):
            colour = fills.get(tuple(obs[i, j]))
            if colour is not None:
                top_left = (j * enlarge, i * enlarge)
                bottom_right = (j * enlarge + enlarge, i * enlarge + enlarge)
                cv2.rectangle(new_obs, top_left, bottom_right, colour, -1)
    cv2.imshow('image', new_obs)
    cv2.waitKey(100)
下面這段代碼描述 agent 的動作與 reward:動作 0、1、2、3 分別表示向上、下、左、右移動一格(目標格被佔用時不動),其他動作(如 4)表示原地不動;agent 到達目標位置時獲得獎勵 5,並且 done 爲 True。
def step(self, action_list):
    """Move the agent according to ``action_list[0]``.

    0/1/2/3 move up/down/left/right by one cell unless the target cell is
    blocked (occupancy 1); any other value leaves the agent in place.

    Returns:
        (reward, done): (5, True) if the agent ends on the goal, else (0, False).
    """
    moves = ((0, (-1, 0)),  # up
             (1, (1, 0)),   # down
             (2, (0, -1)),  # left
             (3, (0, 1)))   # right
    action = action_list[0]
    for action_id, (d_row, d_col) in moves:
        if action == action_id:
            row, col = self.agt1_pos
            target_row, target_col = row + d_row, col + d_col
            if self.occupancy[target_row][target_col] != 1:  # if can move
                # agt1_pos is updated in place, as callers may hold a reference to it.
                self.agt1_pos[0] = target_row
                self.agt1_pos[1] = target_col
                self.occupancy[row][col] = 0
                self.occupancy[target_row][target_col] = 1
            break
    done = self.agt1_pos == self.goal1_pos
    reward = 5 if done else 0
    return reward, done
到這裏,agent運行環境已經介紹完成。
下面給出GAIL_OppositeV4.py代碼:
from torch.distributions.categorical import Categorical
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from env_OppositeV4 import EnvOppositeV4
import numpy as np
import csv
from collections import deque
import os
class Actor(nn.Module):
    """Policy network: 2-D state -> probabilities over N_action discrete actions.

    Architecture: Linear(2, 64) -> ReLU -> Linear(64, 32) -> ReLU ->
    Linear(32, N_action) -> softmax over dim 1.
    """

    def __init__(self, N_action):
        super().__init__()
        self.N_action = N_action
        self.fc1 = nn.Linear(2, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, self.N_action)

    def get_action(self, h):
        """Sample one action from the policy for state tensor ``h``.

        The softmax runs over dim 1 and the result is squeezed on dim 0, so ``h``
        is presumably a single state shaped (1, 2) -- confirm at call sites.

        Returns:
            (sampled action as a Python int, the action-probability tensor,
             log-probability of the sampled action).
        """
        hidden = F.relu(self.fc1(h))
        hidden = F.relu(self.fc2(hidden))
        probs = F.softmax(self.fc3(hidden), dim=1)
        dist = Categorical(probs.squeeze(0))
        action = dist.sample()
        return action.item(), probs, dist.log_prob(action)
class Discriminator(nn.Module):
    """Score (state, one-hot action) pairs: output near 1 means "looks like the expert".

    Architecture: Linear(s_dim + N_action, 128) -> ReLU -> Linear(128, 64) ->
    ReLU -> Linear(64, 1) -> sigmoid.
    """

    def __init__(self, s_dim, N_action):
        super().__init__()
        self.s_dim = s_dim
        self.N_action = N_action
        self.fc1 = nn.Linear(self.s_dim + self.N_action, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, state, action):
        """Concatenate state and action along dim 1 and return a (batch, 1) probability."""
        hidden = torch.cat((state, action), dim=1)
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        return torch.sigmoid(self.fc3(hidden))
class GAIL(object):
    """Generative adversarial imitation learning for a discrete-action agent.

    ``disc1`` is trained with BCE to output 1 for expert (state, one-hot action)
    pairs and 0 for pairs produced by the policy ``actor1``.  The policy is then
    updated with a REINFORCE-style step whose reward is the discriminator's mean
    score over the policy's own trajectory.
    """

    def __init__(self, s_dim, N_action):
        self.s_dim = s_dim
        self.N_action = N_action
        self.actor1 = Actor(self.N_action)
        self.disc1 = Discriminator(self.s_dim, self.N_action)
        self.d1_optimizer = torch.optim.Adam(self.disc1.parameters(), lr=1e-3)
        self.a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)
        self.loss_fn = torch.nn.MSELoss()  # not used by the methods below
        self.adv_loss_fn = torch.nn.BCELoss()
        self.gamma = 0.9  # not used by the methods below

    def get_action(self, obs1):
        """Sample an action for numpy state ``obs1``; returns (action, probs, log_prob)."""
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs1).float())
        return action1, pi_a1, log_prob1

    def int_to_tensor(self, action):
        """Return a (1, N_action) one-hot row for integer ``action``."""
        temp = torch.zeros(1, self.N_action)
        temp[0, action] = 1
        return temp

    def _to_batch(self, s_list, a_list):
        """Stack per-step numpy states and integer actions into batch tensors.

        Each state row and each one-hot action row is concatenated with a single
        ``torch.cat``.  Growing the tensor inside a loop re-copied it every step
        (quadratic in episode length).
        """
        states = torch.cat([torch.from_numpy(s).float() for s in s_list], dim=0)
        actions = torch.cat([self.int_to_tensor(a) for a in a_list], dim=0)
        return states, actions

    def train_D(self, s1_list, a1_list, e_s1_list, e_a1_list):
        """One discriminator step: expert pairs labelled 1, policy pairs labelled 0."""
        p_s1, p_a1 = self._to_batch(s1_list, a1_list)
        e_s1, e_a1 = self._to_batch(e_s1_list, e_a1_list)
        e1_pred = self.disc1(e_s1, e_a1)
        p1_pred = self.disc1(p_s1, p_a1)
        loss = (self.adv_loss_fn(e1_pred, torch.ones(len(e_s1_list), 1))
                + self.adv_loss_fn(p1_pred, torch.zeros(len(s1_list), 1)))
        self.d1_optimizer.zero_grad()
        loss.backward()
        self.d1_optimizer.step()

    def train_G(self, s1_list, a1_list, log_pi_a1_list, r1_list, e_s1_list, e_a1_list):
        """One policy-gradient step using the discriminator's mean score as reward.

        Every step's log-probability is weighted by the same scalar reward and
        the loss is averaged over the episode length.  ``r1_list``,
        ``e_s1_list`` and ``e_a1_list`` are accepted for interface compatibility
        but are not used.
        """
        T = len(s1_list)
        p_s1, p_a1 = self._to_batch(s1_list, a1_list)
        # The reward is a constant for the policy update.  Computing it without a
        # graph keeps a1_loss.backward() from also pushing gradients into disc1.
        with torch.no_grad():
            fake_reward = self.disc1(p_s1, p_a1).mean()
        log_probs = torch.stack([log_pi_a1_list[t] for t in range(T)]).reshape(-1)
        a1_loss = -(fake_reward * log_probs).sum() / T
        self.a1_optimizer.zero_grad()
        a1_loss.backward()
        self.a1_optimizer.step()
class REINFORCE(object):
    """Monte-Carlo policy gradient (REINFORCE), used here to train the expert policy.

    Args:
        N_action: number of discrete actions.
        gamma: discount factor for the returns (previously hard-coded as 0.95).
    """

    def __init__(self, N_action, gamma=0.95):
        self.N_action = N_action
        self.gamma = gamma
        self.actor1 = Actor(self.N_action)
        # Built once.  Re-creating Adam on every train() call discarded its moment
        # estimates, so each update behaved like Adam's very first step.
        self.a1_optimizer = torch.optim.Adam(self.actor1.parameters(), lr=1e-3)

    def get_action(self, obs):
        """Sample an action for numpy state ``obs``; returns (action, probs, log_prob)."""
        action1, pi_a1, log_prob1 = self.actor1.get_action(torch.from_numpy(obs).float())
        return action1, pi_a1, log_prob1

    def train(self, a1_list, pi_a1_list, r_list):
        """One REINFORCE update from a finished episode.

        Args:
            a1_list: integer action taken at each step.
            pi_a1_list: action-probability tensors returned by get_action; each is
                indexed as ``[0, action]``.
            r_list: reward received at each step.
        """
        T = len(r_list)
        if T == 0:  # nothing to learn from (the old code raised IndexError here)
            return
        # Discounted returns G_t = r_t + gamma * G_{t+1}, computed backwards.
        returns = torch.zeros(T)
        running = 0.0
        for k in range(T - 1, -1, -1):
            running = r_list[k] + self.gamma * running
            returns[k] = running
        chosen_probs = torch.stack([pi_a1_list[t][0, a1_list[t]] for t in range(T)])
        a1_loss = -(returns * torch.log(chosen_probs)).sum() / T
        self.a1_optimizer.zero_grad()
        a1_loss.backward()
        self.a1_optimizer.step()

    def save_model(self):
        """Save the actor's parameters (its state_dict) to 'V4_actor.pkl'."""
        torch.save(self.actor1.state_dict(), 'V4_actor.pkl')

    def load_model(self):
        """Load parameters from 'V4_actor.pkl' into the existing actor.

        Loading in place keeps ``self.a1_optimizer`` attached to the live
        parameters.  A file written by the old whole-module ``torch.save`` is
        converted if the installed torch can unpickle it.  PyTorch 2.6+ loads
        with weights_only=True by default and rejects such files; delete the
        file and retrain in that case.
        """
        state = torch.load('V4_actor.pkl')
        if isinstance(state, nn.Module):  # legacy whole-module checkpoint
            state = state.state_dict()
        self.actor1.load_state_dict(state)
if __name__ == '__main__':
    torch.set_num_threads(1)
    env = EnvOppositeV4(9)
    max_epi_iter = 100
    max_MC_iter = 100

    # train expert policy by REINFORCE algorithm (or reuse a saved one)
    agent = REINFORCE(N_action=5)
    if os.path.exists('./V4_actor.pkl'):
        agent.load_model()
    else:
        print('無保存模型,將從頭開始訓練!')
        for epi_iter in range(max_epi_iter):
            env.reset()
            a1_list = []
            pi_a1_list = []
            r_list = []
            acc_r = 0
            for MC_iter in range(max_MC_iter):
                env.render()
                state = env.get_state()
                action1, pi_a1, log_prob1 = agent.get_action(state)
                a1_list.append(action1)
                pi_a1_list.append(pi_a1)
                reward, done = env.step([action1, 0])
                acc_r = acc_r + reward
                r_list.append(reward)
                if done:
                    break
            # MC_iter is zero-based, so the episode lasted MC_iter + 1 steps
            # (dividing by MC_iter was off by one and could divide by zero).
            print('Train expert, Episode', epi_iter, 'average reward', acc_r / (MC_iter + 1))
            # Only episodes that reached the goal are used for the update.
            if done:
                agent.train(a1_list, pi_a1_list, r_list)
    agent.save_model()

    # record expert policy
    exp_s_list = []
    exp_a_list = []
    env.reset()
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        exp_s_list.append(state)
        exp_a_list.append(action1)
        reward, done = env.step([action1, 0])
        print('step', MC_iter, 'agent 1 at', state, 'agent 1 action', action1, 'reward', reward, 'done', done)
        if done:
            break

    # generative adversarial imitation learning from [exp_s_list, exp_a_list]
    agent = GAIL(s_dim=2, N_action=5)
    for epi_iter in range(max_epi_iter):
        env.reset()
        s1_list = []
        a1_list = []
        r1_list = []
        log_pi_a1_list = []
        acc_r = 0
        for MC_iter in range(max_MC_iter):
            # env.render()
            state = env.get_state()
            action1, pi_a1, log_prob1 = agent.get_action(state)
            s1_list.append(state)
            a1_list.append(action1)
            log_pi_a1_list.append(log_prob1)
            reward, done = env.step([action1, 0])
            acc_r = acc_r + reward
            r1_list.append(reward)
            if done:
                break
        print('Imitate by GAIL, Episode', epi_iter, 'average reward', acc_r / (MC_iter + 1))
        # train Discriminator
        agent.train_D(s1_list, a1_list, exp_s_list, exp_a_list)
        # train Generator
        agent.train_G(s1_list, a1_list, log_pi_a1_list, r1_list, exp_s_list, exp_a_list)

    # compare the expert trajectory with the learnt policy's trajectory
    print('expert trajectory')
    for i, (exp_state, exp_action) in enumerate(zip(exp_s_list, exp_a_list)):
        print('step', i, 'agent 1 at', exp_state, 'agent 1 action', exp_action)
    print('learnt trajectory')
    env.reset()
    # Separate lists: appending to exp_s_list/exp_a_list and indexing them by
    # MC_iter printed the expert's own steps instead of the learnt ones.
    learnt_s_list = []
    learnt_a_list = []
    for MC_iter in range(max_MC_iter):
        # env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        learnt_s_list.append(state)
        learnt_a_list.append(action1)
        reward, done = env.step([action1, 0])
        print('step', MC_iter, 'agent 1 at', state, 'agent 1 action', action1)
        if done:
            break
運行結果爲:
expert trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
step 61 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 62 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 63 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 64 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 65 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 66 agent 1 at [[0.77777778 0.33333333]] agent 1 action 0
step 67 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 68 agent 1 at [[0.77777778 0.33333333]] agent 1 action 3
step 69 agent 1 at [[0.77777778 0.44444444]] agent 1 action 3
step 70 agent 1 at [[0.77777778 0.55555556]] agent 1 action 0
step 71 agent 1 at [[0.66666667 0.55555556]] agent 1 action 0
step 72 agent 1 at [[0.55555556 0.55555556]] agent 1 action 0
step 73 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 74 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 75 agent 1 at [[0.44444444 0.55555556]] agent 1 action 4
step 76 agent 1 at [[0.44444444 0.55555556]] agent 1 action 0
step 77 agent 1 at [[0.33333333 0.55555556]] agent 1 action 1
step 78 agent 1 at [[0.44444444 0.55555556]] agent 1 action 3
step 79 agent 1 at [[0.44444444 0.66666667]] agent 1 action 0
step 80 agent 1 at [[0.33333333 0.66666667]] agent 1 action 3
step 81 agent 1 at [[0.33333333 0.77777778]] agent 1 action 1
learnt trajectory
step 0 agent 1 at [[0.44444444 0.11111111]] agent 1 action 1
step 1 agent 1 at [[0.55555556 0.11111111]] agent 1 action 4
step 2 agent 1 at [[0.55555556 0.11111111]] agent 1 action 3
step 3 agent 1 at [[0.55555556 0.22222222]] agent 1 action 1
step 4 agent 1 at [[0.66666667 0.22222222]] agent 1 action 0
step 5 agent 1 at [[0.55555556 0.22222222]] agent 1 action 0
step 6 agent 1 at [[0.44444444 0.22222222]] agent 1 action 3
step 7 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 8 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 9 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 10 agent 1 at [[0.33333333 0.33333333]] agent 1 action 4
step 11 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 12 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 13 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 14 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 15 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 16 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 17 agent 1 at [[0.55555556 0.33333333]] agent 1 action 2
step 18 agent 1 at [[0.55555556 0.22222222]] agent 1 action 3
step 19 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 20 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 21 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 22 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 23 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 24 agent 1 at [[0.44444444 0.33333333]] agent 1 action 4
step 25 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 26 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 27 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 28 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 29 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 30 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 31 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 32 agent 1 at [[0.44444444 0.33333333]] agent 1 action 0
step 33 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 34 agent 1 at [[0.22222222 0.33333333]] agent 1 action 2
step 35 agent 1 at [[0.22222222 0.22222222]] agent 1 action 3
step 36 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 37 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 38 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 39 agent 1 at [[0.33333333 0.33333333]] agent 1 action 0
step 40 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 41 agent 1 at [[0.22222222 0.33333333]] agent 1 action 3
step 42 agent 1 at [[0.22222222 0.33333333]] agent 1 action 1
step 43 agent 1 at [[0.33333333 0.33333333]] agent 1 action 3
step 44 agent 1 at [[0.33333333 0.33333333]] agent 1 action 1
step 45 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 46 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 47 agent 1 at [[0.55555556 0.33333333]] agent 1 action 3
step 48 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 49 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 50 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 51 agent 1 at [[0.66666667 0.33333333]] agent 1 action 0
step 52 agent 1 at [[0.55555556 0.33333333]] agent 1 action 0
step 53 agent 1 at [[0.44444444 0.33333333]] agent 1 action 3
step 54 agent 1 at [[0.44444444 0.33333333]] agent 1 action 1
step 55 agent 1 at [[0.55555556 0.33333333]] agent 1 action 1
step 56 agent 1 at [[0.66666667 0.33333333]] agent 1 action 3
step 57 agent 1 at [[0.66666667 0.33333333]] agent 1 action 4
step 58 agent 1 at [[0.66666667 0.33333333]] agent 1 action 1
step 59 agent 1 at [[0.77777778 0.33333333]] agent 1 action 1
step 60 agent 1 at [[0.77777778 0.33333333]] agent 1 action 4
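上面打印出的 learnt trajectory 與 expert trajectory 逐步完全相同,但這並不能說明學到的軌跡與專家軌跡一樣。上面這份結果中的 learnt trajectory 是這樣打印出來的:代碼把學到的策略產生的狀態和動作追加到 exp_s_list/exp_a_list 的末尾,卻又按下標 MC_iter 從這兩個列表的開頭取值打印。所以打印出來的其實是專家軌跡的前 61 步。學到的策略在第 60 步就到達終點、結束了回合,而專家此時還在 [0.777, 0.333] 處,尚未到達終點。要真正比較兩者,應當用單獨的列表記錄學到的軌跡。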
好了,現在來介紹裏面的細節部分:
對於我們這個自己構建的環境,我們沒有專家軌跡怎麼辦呢?那就自己來製作專家軌跡。
這裏,使用下面代碼進行樣本收集:
# Collect one episode of (action, action-probabilities, reward) per iteration
# with the current expert policy.
for epi_iter in range(max_epi_iter):
    env.reset()
    a1_list = []
    pi_a1_list = []
    r_list = []
    acc_r = 0
    for MC_iter in range(max_MC_iter):
        env.render()
        state = env.get_state()
        action1, pi_a1, log_prob1 = agent.get_action(state)
        a1_list.append(action1)
        pi_a1_list.append(pi_a1)  # probabilities later indexed by agent.train
        reward, done = env.step([action1, 0])
        acc_r = acc_r + reward
        r_list.append(reward)
        # Stop at the goal, as the full program does.  Without this break the
        # agent keeps moving after done, and the collected episode no longer
        # ends at the goal.
        if done:
            break
下面這段代碼表示:只有當 agent 到達綠色的目標點(done 爲 True)時,才用這一回合採集的數據訓練網絡、更新參數。
# After an episode: update the expert policy only if the episode reached the
# goal (done is True on its last step).
if done:
    agent.train(a1_list, pi_a1_list, r_list)