使用 Stable Baselines3 算法为自定义 Gym 环境创建模型时出错
我用 pygame 编写了俄罗斯方块,现在想创建一个能够使用 Stable Baselines3 来玩这个游戏的智能体(agent)。为此,我为游戏创建了一个 Gym 环境,其中 observation_space 是游戏区域的二维数组:每个空格子为 0,每个被方块占据的格子为 1:
from src.figure import Direction
from tetris import Tetris, GameStatus
from gym import Env
from gym.spaces import Discrete, MultiBinary
import pygame
class TetrisEnv(Env):
def __init__(self, field_size) -> None:
self.height, self.width = field_size
self.observation_space = MultiBinary([field_size[0], field_size[1]])
self.action_space = Discrete(6)
self.game = Tetris(self.width, self.height)
self.score = 0
self.episodes = 10000
self.clock = pygame.time.Clock()
def step(self, action):
self.episodes -= 1
match action:
case 0:
pass
case 1:
self.game.move(Direction.LEFT)
case 2:
self.game.move(Direction.RIGHT)
case 3:
self.game.drop_figure()
case 4:
self.game.rotate(Direction.LEFT)
case 5:
self.game.rotate(Direction.RIGHT)
status = self.game.move(Direction.DOWN)
reward = self.game.score - self.score + (1 if status != GameStatus.GAME_OVER else 0)
self.score = self.game.score
done = self.game.state == GameStatus.GAME_OVER or self.episodes <= 0
return self.game.field, reward, done, {}
def render(self):
self.game.draw()
self.clock.tick(4)
def reset(self):
self.game = Tetris(self.width, self.height)
self.score = 0
self.episodes = 10000
return self.game.field
当我尝试创建模型时,我会收到以下错误:
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
# Build the environment; TetrisEnv unpacks field_size as (height, width).
env = TetrisEnv((20, 10))
# Wrap it in a DummyVecEnv holding one environment; the factory lambda
# returns whatever `env` is bound to at the moment it is called.
env = DummyVecEnv([lambda: env])
# Create the A2C model with the CNN policy -- this is the call that raises
# the TypeError shown in the traceback.
model = A2C('CnnPolicy', env, verbose=1)
错误:
Traceback (most recent call last):
model = A2C('CnnPolicy', env, verbose=1)
File ".../lib/python3.10/site-packages/stable_baselines3/a2c/a2c.py", line 115, in __init__
self._setup_model()
File ".../lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 111, in _setup_model
self.rollout_buffer = buffer_cls(
File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 342, in __init__
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 49, in __init__
self.obs_shape = get_obs_shape(observation_space)
File ".../lib/python3.10/site-packages/stable_baselines3/common/preprocessing.py", line 153, in get_obs_shape
return (int(observation_space.n),)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'
我在这里缺少什么?
I coded Tetris using pygame and now I am trying to create an agent that is able to play it using Stable Baselines3. Therefore, I created a gym environment for the game, where the observation_space is a 2D array of the field containing a 0 on every empty field and a 1 on every field that is occupied by a figure:
from src.figure import Direction
from tetris import Tetris, GameStatus
from gym import Env
from gym.spaces import Discrete, MultiBinary
import pygame
class TetrisEnv(Env):
def __init__(self, field_size) -> None:
self.height, self.width = field_size
self.observation_space = MultiBinary([field_size[0], field_size[1]])
self.action_space = Discrete(6)
self.game = Tetris(self.width, self.height)
self.score = 0
self.episodes = 10000
self.clock = pygame.time.Clock()
def step(self, action):
self.episodes -= 1
match action:
case 0:
pass
case 1:
self.game.move(Direction.LEFT)
case 2:
self.game.move(Direction.RIGHT)
case 3:
self.game.drop_figure()
case 4:
self.game.rotate(Direction.LEFT)
case 5:
self.game.rotate(Direction.RIGHT)
status = self.game.move(Direction.DOWN)
reward = self.game.score - self.score + (1 if status != GameStatus.GAME_OVER else 0)
self.score = self.game.score
done = self.game.state == GameStatus.GAME_OVER or self.episodes <= 0
return self.game.field, reward, done, {}
def render(self):
self.game.draw()
self.clock.tick(4)
def reset(self):
self.game = Tetris(self.width, self.height)
self.score = 0
self.episodes = 10000
return self.game.field
When I now try to create a model I get the following error:
from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv
# Build the environment; TetrisEnv unpacks field_size as (height, width).
env = TetrisEnv((20, 10))
# Wrap it in a DummyVecEnv holding one environment; the factory lambda
# returns whatever `env` is bound to at the moment it is called.
env = DummyVecEnv([lambda: env])
# Create the A2C model with the CNN policy -- this is the call that raises
# the TypeError shown in the traceback.
model = A2C('CnnPolicy', env, verbose=1)
Error:
Traceback (most recent call last):
model = A2C('CnnPolicy', env, verbose=1)
File ".../lib/python3.10/site-packages/stable_baselines3/a2c/a2c.py", line 115, in __init__
self._setup_model()
File ".../lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 111, in _setup_model
self.rollout_buffer = buffer_cls(
File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 342, in __init__
super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 49, in __init__
self.obs_shape = get_obs_shape(observation_space)
File ".../lib/python3.10/site-packages/stable_baselines3/common/preprocessing.py", line 153, in get_obs_shape
return (int(observation_space.n),)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'
What am I missing here?
如果你对这篇内容有疑问,欢迎到本站社区发帖提问、参与讨论,获取更多帮助,或者扫描二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论