- 概览
- 安装
- 教程
- 算法接口文档
- 简易高效的并行接口
- APIS
- FREQUENTLY ASKED QUESTIONS
- EVOKIT
- 其他
- parl.algorithms.paddle.policy_gradient
- parl.algorithms.paddle.dqn
- parl.algorithms.paddle.ddpg
- parl.algorithms.paddle.ddqn
- parl.algorithms.paddle.oac
- parl.algorithms.paddle.a2c
- parl.algorithms.paddle.qmix
- parl.algorithms.paddle.td3
- parl.algorithms.paddle.sac
- parl.algorithms.paddle.ppo
- parl.algorithms.paddle.maddpg
- parl.core.paddle.model
- parl.core.paddle.algorithm
- parl.remote.remote_decorator
- parl.core.paddle.agent
- parl.remote.client
子模块说明
在上一个教程中,我们快速地展示了如果通过PARL的三个基础模块:Model
, Algorithm
, Agent
来搭建智能体和环境进行交互的。在这个教程中,我们将详细介绍每个模块的具体定位,以及使用规范。
Model
定义:
Model
用来定义前向(Forward)网络,这通常是一个策略网络(Policy Network)或者一个值函数网络(Value Function),输入是当前环境状态(State)。⚠️注意事项:用户得要继承
parl.Model
这个类来构建自己的Model。- 需要实现的函数:
forward
: 根据在初始化函数中声明的计算层来搭建前向网络。
备注:在PARL中,实现强化学习常用的target network很方便的,直接调用
copy.deepcopy
即可。示例:
import paddle import paddle.nn as nn import parl import copy class CartpoleModel(parl.Model): def __init__(self, obs_dim, act_dim): super(CartpoleModel, self).__init__() hid1_size = act_dim * 10 self.fc1 = nn.Linear(obs_dim, hid1_size) self.fc2 = nn.Linear(hid1_size, act_dim) self.tanh = nn.Tanh() self.softmax = nn.Softmax() def forward(self, x): out = self.tanh(self.fc1(x)) prob = self.softmax(self.fc2(out)) return prob if __name__ == '__main__: model = CartpoleModel() target_model = copy.deepcopy(model)
Algorithm
定义:
Algorithm
定义了具体的算法来更新前向网络(Model),也就是通过定义损失函数来更新Model
。一个Algorithm
包含至少一个Model
。⚠️注意事项:一般不自己开发,推荐直接import 仓库中已经实现好的算法。
- 需要实现的函数:
learn
: 根据训练数据(观测量和输入的reward),定义损失函数,用于更新Model
中的参数。predict
: 根据当前的观测量,给出动作概率分布或者Q函数的预估值。
示例:
model = CartpoleModel(act_dim=2) algorithm = parl.algorithms.PolicyGradient(model, lr=1e-3)
Agent
定义:
Agent
负责算法与环境的交互,在交互过程中把生成的数据提供给Algorithm
来更新模型(Model
),数据的预处理流程也一般定义在这里。⚠️注意事项:需要继承
parl.Agent
来使用,要在构造函数中调用父类的构造函数。- 需要实现的函数:
learn
: 根据训练数据(观测量和输入的reward),定义损失函数,用于更新Model
中的参数。predict
: 根据环境状态返回预测动作(action),一般用于评估和部署agent。sample
:根据环境状态返回动作(action),一般用于训练时候采样action进行探索。
示例:
class CartpoleAgent(parl.Agent): def __init__(self, algorithm): super(CartpoleAgent, self).__init__(algorithm) def sample(self, obs): obs = paddle.to_tensor(obs, dtype='float32') prob = self.alg.predict(obs) prob = prob.numpy() act = np.random.choice(len(prob), 1, p=prob)[0] return act def predict(self, obs): obs = paddle.to_tensor(obs, dtype='float32') prob = self.alg.predict(obs) act = int(prob.argmax()) return act def learn(self, obs, act, reward): act = np.expand_dims(act, axis=-1) reward = np.expand_dims(reward, axis=-1) obs = paddle.to_tensor(obs, dtype='float32') act = paddle.to_tensor(act, dtype='int32') reward = paddle.to_tensor(reward, dtype='float32') loss = self.alg.learn(obs, act, reward) return float(loss) if __name__ == '__main__': model = CartpoleModel() alg = parl.algorithms.PolicyGradient(model, lr=1e-3) agent = CartpoleAgent(alg, obs_dim=4, act_dim=2)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论