import sys

import gym
import pygame
import torch
import torch.nn as nn
import torch.optim as optim


class Actor(nn.Module):
    """Policy network: maps a 4-dim CartPole state to 2 action probabilities."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1),
        )

    def forward(self, x):
        return self.fc(x)


class Critic(nn.Module):
    """Value network: maps a 4-dim CartPole state to a scalar state value."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 1),
        )

    def forward(self, x):
        return self.fc(x)


def train(actor, critic, actor_optimizer, critic_optimizer,
          state, action, reward, next_state, done, gamma=0.99):
    """Perform one actor-critic update from a single transition.

    Args:
        actor: policy network producing action probabilities.
        critic: value network producing a scalar state value.
        actor_optimizer / critic_optimizer: optimizers for each network.
        state, action, reward, next_state: the observed transition
            (state/next_state are array-like 4-dim observations).
        done: True if the episode ended at this transition.
        gamma: discount factor for the TD(0) target (default 0.99,
            matching the original hard-coded value).
    """
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)

    # Bootstrap from the critic only if the episode continues; detach so the
    # target is treated as a constant.
    next_value = 0.0 if done else critic(next_state).detach()

    # Critic: mean-squared TD(0) error against the bootstrapped target.
    value = critic(state)
    expected_value = reward + gamma * next_value
    critic_loss = (value - expected_value).pow(2).mean()

    # Actor: policy gradient, using the detached TD error as the advantage.
    probs = actor(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)
    advantage = (expected_value - value).detach()
    actor_loss = -log_prob * advantage

    # Update networks (critic first, as in the original; the two losses
    # live on separate graphs because the advantage is detached).
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()

    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()


def main():
    """Train an actor-critic agent on CartPole-v1 with pygame visualization."""
    env = gym.make('CartPole-v1')
    actor = Actor()
    critic = Critic()
    actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
    critic_optimizer = optim.Adam(critic.parameters(), lr=0.01)

    pygame.init()
    screen = pygame.display.set_mode((600, 400))
    clock = pygame.time.Clock()

    for episode in range(10000):
        state, _ = env.reset()  # gym >= 0.26: reset returns (obs, info)
        done = False
        step = 0
        while not done:
            step += 1
            state_tensor = torch.tensor(state, dtype=torch.float)
            probs = actor(state_tensor)
            dist = torch.distributions.Categorical(probs)
            action = dist.sample().item()
            # BUG FIX: the original bound only `terminated` to done and
            # discarded `truncated`, so a time-limit truncation (500 steps
            # in CartPole-v1) never ended the episode loop and the critic
            # bootstrapped past the episode boundary.
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            train(actor, critic, actor_optimizer, critic_optimizer,
                  state, action, reward, next_state, done)
            state = next_state

            # Pygame event handling: allow the window to be closed.
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    pygame.quit()
                    sys.exit()

            # Draw the cart (rect) and pole (line tilted by the pole angle).
            screen.fill((255, 255, 255))
            cart_x = int(state[0] * 100 + 300)
            pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
            pygame.draw.line(
                screen, (255, 0, 0),
                (cart_x + 25, 300),
                (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))),
                 300 - int(50 * torch.cos(torch.tensor(state[2])))),
                5,
            )
            pygame.display.flip()
            clock.tick(200)

        print(f"第{episode}回合,玩{step}次掛了")


if __name__ == "__main__":
    main()