返回介绍

快速入门:解决 CartPole 问题

发布于 2024-06-23 17:58:49 字数 6034 浏览 0 评论 0 收藏 0

本教程的目标

  • 简单了解PARL

  • 熟悉PARL构建智能体过程中需要用到的子模块。

本教程会使用 示例 中的代码来解释任何通过PARL构建智能体解决经典的Cartpole问题。

CartPole 介绍

CartPole又叫倒立摆。小车上放了一根杆,杆会因重力而倒下。为了不让杆倒下,我们要通过移动小车,来保持其是直立的。如下图所示。 在每一个时间步,模型的输入是一个4维的向量,表示当前小车和杆的状态,模型输出的信号用于控制小车往左或者右移动。当杆没有倒下的时候,每个时间步,环境会给1分的奖励;当杆倒下后,环境不会给任何的奖励,游戏结束。

https://www.wenjiangs.com/wp-content/uploads/2024/docimg5/performance.gif

Model

主要定义前向网络,这通常是一个策略网络(Policy Network)或者一个值函数网络(Value Function),输入是当前环境状态(State)。

首先,我们构建一个包含2个全连接层的前向网络。

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import parl

class CartpoleModel(parl.Model):
    def __init__(self, obs_dim, act_dim):
        super(CartpoleModel, self).__init__()
        hid1_size = act_dim * 10
        self.fc1 = nn.Linear(obs_dim, hid1_size)
        self.fc2 = nn.Linear(hid1_size, act_dim)

    def forward(self, x):
        out = paddle.tanh(self.fc1(x))
        prob = F.softmax(self.fc2(out), axis=-1)
        return prob

定义前向网络的主要三个步骤:

  • 继承 parl.Model

  • 构造函数 __init__ 中声明要用到的中间层

  • forward 函数中搭建网络

在上面的代码中,我们先在构造函数中声明了两个全连接层,然后在 forward 函数中定义了网络的前向计算方式:输入一个状态,然后经过一层FC和tanh激活函数,一层FC和softmax激活函数,得到了每个action的概率分布预测。

Algorithm

定义了具体的算法来更新前向网络( Model )中的参数,也就是通过定义损失函数以及使用optimizer更新 Model 。一个 Algorithm 包含至少一个 Model 。在这个教程中,我们将使用经典的PolicyGradient算法来解决问题。由于PARL仓库已经实现了这个算法,我们只需要直接import来使用即可。

model = CartpoleModel(act_dim=2)
algorithm = parl.algorithms.PolicyGradient(model, lr=1e-3)

在实例化了 Model 之后,我们把它传给 Algorithm

Agent

负责算法与环境的交互,在交互过程中把生成的数据提供给 Algorithm 来更新模型( Model ),也就是数据和算法的交互一般定义在这里。

我们得要继承 parl.Agent 这个类来实现自己的 Agent ,下面先把 Agent 的代码抛出来,再按照函数解释。

class CartpoleAgent(parl.Agent):

    def __init__(self, algorithm):

        super(CartpoleAgent, self).__init__(algorithm)

    def sample(self, obs):

        obs = paddle.to_tensor(obs, dtype='float32')
        prob = self.alg.predict(obs)
        prob = prob.numpy()
        act = np.random.choice(len(prob), 1, p=prob)[0]

        return act

    def predict(self, obs):

        obs = paddle.to_tensor(obs, dtype='float32')
        prob = self.alg.predict(obs)
        act = int(prob.argmax())

        return act

    def learn(self, obs, act, reward):

        act = np.expand_dims(act, axis=-1)
        reward = np.expand_dims(reward, axis=-1)
        obs = paddle.to_tensor(obs, dtype='float32')
        act = paddle.to_tensor(act, dtype='int32')
        reward = paddle.to_tensor(reward, dtype='float32')

        loss = self.alg.learn(obs, act, reward)

        return float(loss)

一般情况下,用户必须实现以下几个函数:

  • __init__ :把前面定义好的algorithm传进来,作为agent的一个成员变量,用于后续的数据交互。需要注意的是,这里必须得要初始化父类:super(CartpoleAgent, self).__init__(algorithm)

  • predict :根据环境状态返回预测动作(action),一般用于评估和部署agent。

  • sample :根据环境状态返回动作(action),一般用于训练时候采样action进行探索。

开始训练

首先,我们来定一个智能体。逐步定义 model | algorithm | agent ,然后得到一个可以和环境进行交互的智能体。

model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
agent = CartpoleAgent(alg)

然后我们用这个agent和环境进行交互,训练模型,1000个episode之后,agent就可以很好地解决Cartpole问题,拿到满分(200)。

def run_train_episode(env, agent):

    obs_list, action_list, reward_list = [], [], []
    obs = env.reset()

    while True:
        obs_list.append(obs)
        action = agent.sample(obs)
        action_list.append(action)

        obs, reward, done, info = env.step(action)
        reward_list.append(reward)

        if done:
            break

    return obs_list, action_list, reward_list

env = gym.make("CartPole-v0")

for i in range(1000):
      obs_list, action_list, reward_list = run_train_episode(env, agent)

      if i % 10 == 0:
          logger.info("Episode {}, Reward Sum {}.".format(i, sum(reward_list)))

      batch_obs = np.array(obs_list)
      batch_action = np.array(action_list)
      batch_reward = calc_reward_to_go(reward_list)

      agent.learn(batch_obs, batch_action, batch_reward)

      if (i + 1) % 100 == 0:
          _, _, reward_list = run_episode(env, agent, train_or_test='test')
          total_reward = np.sum(reward_list)

          logger.info('Test reward: {}'.format(total_reward))

总结

https://www.wenjiangs.com/wp-content/uploads/2024/docimg5/performance.gifhttps://www.wenjiangs.com/wp-content/uploads/2024/docimg5/quickstart.png

在这个教程,我们展示了如何一步步地构建强化学习智能体,用于解决经典的Cartpole问题。

完整的训练代码可以在这个 文件夹 中找到。赶紧执行下面的代码运行尝试下吧:)

# Install dependencies
pip install paddlepaddle

pip install gym
git clone https://github.com/PaddlePaddle/PARL.git
cd PARL
pip install .

# Train model
cd examples/QuickStart/
python train.py

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文