- 概览
- 安装
- 教程
- 算法接口文档
- 简易高效的并行接口
- APIS
- FREQUENTLY ASKED QUESTIONS
- EVOKIT
- 其他
- parl.algorithms.paddle.policy_gradient
- parl.algorithms.paddle.dqn
- parl.algorithms.paddle.ddpg
- parl.algorithms.paddle.ddqn
- parl.algorithms.paddle.oac
- parl.algorithms.paddle.a2c
- parl.algorithms.paddle.qmix
- parl.algorithms.paddle.td3
- parl.algorithms.paddle.sac
- parl.algorithms.paddle.ppo
- parl.algorithms.paddle.maddpg
- parl.core.paddle.model
- parl.core.paddle.algorithm
- parl.remote.remote_decorator
- parl.core.paddle.agent
- parl.remote.client
自定义新算法
该教程的目标:
了解如何实现自己的算法。
概览
要构建新算法,您需要继承类 parl.Algorithm
,并实现两个基本函数: predict
和 learn
函数
__init__
Algorithms
更新Model
的参数,此构造函数需要继承parl.Model
和其中的一些函数 ,例如本示例中的self.model
。您还可以在此方法中设置一些超参数,例如learning_rate
,reward_decay
和action_dimension
,这些超参数可能在之后的步骤中使用。predict
这个函数定义如何去选择 actions。比如,你可以使用一个 policy model 去预测 action
learn
损失函数应该被定义在该函数中,该函数主要用于更新
self.model
的模型参数
示例: DQN
该示例演示了如何通过继承 parl.Algorithm
实现DQN算法
在 DQN(Algorithm)
类中,我们定义一下类函数:
__init__(self, model, gamma=None, lr=None)
我们在这个函数中定义 DQN的
self.model
和self.target_model
, 同时,我们在该函数中定义超参数gamma
以及lr
。 这些超参数在learn
函数中会被用到 。def __init__(self, model, gamma=None, lr=None): """ DQN algorithm Args: model (parl.Model): forward neural network representing the Q function. gamma (float): discounted factor for `accumulative` reward computation lr (float): learning rate. """ self.model = model self.target_model = copy.deepcopy(model) assert isinstance(gamma, float) assert isinstance(lr, float) self.gamma = gamma self.lr = lr self.mse_loss = paddle.nn.MSELoss(reduction='mean') self.optimizer = paddle.optimizer.Adam( learning_rate=lr, parameters=self.model.parameters())
predict(self, obs)
我们直接使用输入该函数的环境状态,并将该状态传输入
self.model
中,self.model
会输出预测的action value functiondef predict(self, obs): """ use self.model (Q function) to predict the action values """ return self.model.value(obs)
learn(self, obs, action, reward, next_obs, terminal)
learn
函数会根据当前的预测值和目标值输出当前的损失,并通过该损失进行反向传播更新self.model
中的参数def learn(self, obs, action, reward, next_obs, terminal): """ update the Q function (self.model) with DQN algorithm """ # Q pred_values = self.model.value(obs) action_dim = pred_values.shape[-1] action = paddle.squeeze(action, axis=-1) action_onehot = paddle.nn.functional.one_hot( action, num_classes=action_dim) pred_value = paddle.multiply(pred_values, action_onehot) pred_value = paddle.sum(pred_value, axis=1, keepdim=True) # target Q with paddle.no_grad(): max_v = self.target_model.value(next_obs).max(1, keepdim=True) target = reward + (1 - terminal) * self.gamma * max_v loss = self.mse_loss(pred_value, target) # optimize self.optimizer.clear_grad() loss.backward() self.optimizer.step() return loss
sync_target(self)
该函数同步
self.target_model
和self.model
中的参数def sync_target(self): self.model.sync_weights_to(self.target_model)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。
绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论