Actor-Critic playing the CartPole game

The script below trains a small actor-critic agent on gym's CartPole-v1 and draws the cart and pole with pygame while it trains.

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys
import math

# Define the Actor network: maps a 4-dimensional state to probabilities over the 2 actions
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

# Define the Critic network: maps a state to a scalar value estimate V(s)
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 1)
        )

    def forward(self, x):
        return self.fc(x)
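
As a quick sanity check (illustrative only, not part of the training script), both networks accept a single 4-dimensional CartPole observation:

dummy_state = torch.zeros(4)       # CartPole observations have 4 components
print(Actor()(dummy_state))        # 2 action probabilities that sum to 1
print(Critic()(dummy_state))       # 1-element tensor holding the value estimate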

# One training step: update the critic with the TD error and the actor with the policy gradient
def train(actor, critic, actor_optimizer, critic_optimizer, state, action, reward, next_state, done):
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)
    if done:
        next_value = 0  # no bootstrapping past a terminal state
    else:
        next_value = critic(next_state).detach()  # bootstrap value, detached so the target acts as a constant
    
    # Critic loss: squared TD error against the bootstrapped target (discount factor 0.99)
    value = critic(state)
    expected_value = reward + 0.99 * next_value
    critic_loss = (value - expected_value).pow(2).mean()
    
    # Actor loss
    probs = actor(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)
    advantage = (expected_value - value).detach()  # TD error as advantage
    actor_loss = -log_prob * advantage
    
    # Update networks
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()
    
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()
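
For reference, the update above is the standard one-step TD actor-critic rule with discount factor $\gamma = 0.99$: the TD error $\delta = r + \gamma V(s') - V(s)$ (with $V(s') = 0$ at terminal states) serves both as the regression target for the critic, via $L_{\text{critic}} = \delta^2$, and as the advantage estimate for the actor, via $L_{\text{actor}} = -\log \pi(a \mid s)\,\delta$; it is detached so the actor update does not backpropagate into the critic.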

# Set up the environment, networks, and optimizers
env = gym.make('CartPole-v1')
actor = Actor()
critic = Critic()
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.01)

pygame.init()
screen = pygame.display.set_mode((600, 400))  # window used to draw the cart and pole
clock = pygame.time.Clock()

# Training loop
for episode in range(10000):
    state, _ = env.reset()  # gym >= 0.26 returns (observation, info)
    done = False
    step = 0
    while not done:
        step += 1
        state_tensor = torch.tensor(state, dtype=torch.float)
        probs = actor(state_tensor)            # action probabilities from the current policy
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()          # sample an action (0 = push left, 1 = push right)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated  # count the 500-step time limit as the end of the episode
        
        train(actor, critic, actor_optimizer, critic_optimizer, state, action, reward, next_state, done)
        state = next_state
        
        # Pygame visualization
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # Draw the cart (blue rectangle) and the pole (red line)
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)  # map the cart position onto the 600-pixel-wide window
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pole_angle = float(state[2])
        pole_tip = (cart_x + 25 + int(50 * math.sin(pole_angle)),
                    300 - int(50 * math.cos(pole_angle)))  # positive angle tilts the pole to the right
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), pole_tip, 5)
        pygame.display.flip()
        clock.tick(200)

    print(f"第{episode}回合,玩{step}次掛了")

 
