使用 Stable Baselines3 为自定义 Gym 环境创建模型时出错

发布于 2025-02-09 20:56:14 字数 2930 浏览 3 评论 0原文

我用 pygame 编写了俄罗斯方块,现在想用 Stable Baselines3 训练一个能玩这个游戏的智能体。为此我为游戏创建了一个 Gym 环境,其中 observation_space 是游戏场地的二维数组:空格子为 0,被方块占据的格子为 1:

from src.figure import Direction
from tetris import Tetris, GameStatus
from gym import Env
from gym.spaces import Discrete, MultiBinary
import pygame

class TetrisEnv(Env):
    def __init__(self, field_size) -> None:
        self.height, self.width = field_size
        self.observation_space = MultiBinary([field_size[0], field_size[1]])
        self.action_space = Discrete(6)
        self.game = Tetris(self.width, self.height)
        self.score = 0
        self.episodes = 10000
        self.clock = pygame.time.Clock()

    def step(self, action):
        self.episodes -= 1
        match action:
            case 0:
                pass
            case 1:
                self.game.move(Direction.LEFT)
            case 2:
                self.game.move(Direction.RIGHT)
            case 3:
                self.game.drop_figure()
            case 4:
                self.game.rotate(Direction.LEFT)
            case 5:
                self.game.rotate(Direction.RIGHT)
    
        status = self.game.move(Direction.DOWN)

        reward = self.game.score - self.score + (1 if status != GameStatus.GAME_OVER else 0)
        self.score = self.game.score
        done = self.game.state == GameStatus.GAME_OVER or self.episodes <= 0
        return self.game.field, reward, done, {}

    def render(self):
        self.game.draw()
        self.clock.tick(4)

    def reset(self):
        self.game = Tetris(self.width, self.height)
        self.score = 0
        self.episodes = 10000
        return self.game.field

当我尝试创建模型时,我会收到以下错误:

from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

# Build the environment inside the factory. DummyVecEnv then owns its own
# instance, instead of a lambda that closes over the name `env` while `env`
# is being rebound to the vectorised wrapper.
env = DummyVecEnv([lambda: TetrisEnv((20, 10))])
# The observation is a 2-D 0/1 board, not an image. CnnPolicy expects image
# observations (a 3-D uint8 Box in [0, 255]), so use MlpPolicy, which
# flattens the board.
model = A2C('MlpPolicy', env, verbose=1)

错误:

Traceback (most recent call last):
    model = A2C('CnnPolicy', env, verbose=1)
  File ".../lib/python3.10/site-packages/stable_baselines3/a2c/a2c.py", line 115, in __init__
    self._setup_model()
  File ".../lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 111, in _setup_model
    self.rollout_buffer = buffer_cls(
  File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 342, in __init__
    super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
  File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 49, in __init__
    self.obs_shape = get_obs_shape(observation_space)
  File ".../lib/python3.10/site-packages/stable_baselines3/common/preprocessing.py", line 153, in get_obs_shape
    return (int(observation_space.n),)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'

我在这里缺少什么?

I coded Tetris using pygame, and now I am trying to create an agent that can play it using Stable Baselines3. Therefore I created a Gym environment for the game, where the observation_space is a 2D array of the field containing a 0 on every empty cell and a 1 on every cell that is occupied by a figure:

from src.figure import Direction
from tetris import Tetris, GameStatus
from gym import Env
from gym.spaces import Discrete, MultiBinary
import pygame

class TetrisEnv(Env):
    def __init__(self, field_size) -> None:
        self.height, self.width = field_size
        self.observation_space = MultiBinary([field_size[0], field_size[1]])
        self.action_space = Discrete(6)
        self.game = Tetris(self.width, self.height)
        self.score = 0
        self.episodes = 10000
        self.clock = pygame.time.Clock()

    def step(self, action):
        self.episodes -= 1
        match action:
            case 0:
                pass
            case 1:
                self.game.move(Direction.LEFT)
            case 2:
                self.game.move(Direction.RIGHT)
            case 3:
                self.game.drop_figure()
            case 4:
                self.game.rotate(Direction.LEFT)
            case 5:
                self.game.rotate(Direction.RIGHT)
    
        status = self.game.move(Direction.DOWN)

        reward = self.game.score - self.score + (1 if status != GameStatus.GAME_OVER else 0)
        self.score = self.game.score
        done = self.game.state == GameStatus.GAME_OVER or self.episodes <= 0
        return self.game.field, reward, done, {}

    def render(self):
        self.game.draw()
        self.clock.tick(4)

    def reset(self):
        self.game = Tetris(self.width, self.height)
        self.score = 0
        self.episodes = 10000
        return self.game.field

When I now try to create a model I get the following error:

from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

# Build the environment inside the factory. DummyVecEnv then owns its own
# instance, instead of a lambda that closes over the name `env` while `env`
# is being rebound to the vectorised wrapper.
env = DummyVecEnv([lambda: TetrisEnv((20, 10))])
# The observation is a 2-D 0/1 board, not an image. CnnPolicy expects image
# observations (a 3-D uint8 Box in [0, 255]), so use MlpPolicy, which
# flattens the board.
model = A2C('MlpPolicy', env, verbose=1)

Error:

Traceback (most recent call last):
    model = A2C('CnnPolicy', env, verbose=1)
  File ".../lib/python3.10/site-packages/stable_baselines3/a2c/a2c.py", line 115, in __init__
    self._setup_model()
  File ".../lib/python3.10/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 111, in _setup_model
    self.rollout_buffer = buffer_cls(
  File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 342, in __init__
    super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs)
  File ".../lib/python3.10/site-packages/stable_baselines3/common/buffers.py", line 49, in __init__
    self.obs_shape = get_obs_shape(observation_space)
  File ".../lib/python3.10/site-packages/stable_baselines3/common/preprocessing.py", line 153, in get_obs_shape
    return (int(observation_space.n),)
TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'

What am I missing here?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文