返回介绍

自定义新算法

发布于 2024-06-23 17:58:49 字数 3267 浏览 0 评论 0 收藏 0

该教程的目标:

  • 了解如何实现自己的算法。

概览

要构建新算法,您需要继承类 parl.Algorithm ,并实现两个基本函数: predictlearn

函数

  • __init__

    Algorithms 更新 Model 的参数,此构造函数需要继承 parl.Model 和其中的一些函数 ,例如本示例中的 self.model 。您还可以在此方法中设置一些超参数,例如 learning_ratereward_decayaction_dimension ,这些超参数可能在之后的步骤中使用。

  • predict

    这个函数定义如何去选择 actions。比如,你可以使用一个 policy model 去预测 action

  • learn

    损失函数应该被定义在该函数中,该函数主要用于更新 self.model 的模型参数

示例: DQN

该示例演示了如何通过继承 parl.Algorithm 实现DQN算法

DQN(Algorithm) 类中,我们定义一下类函数:

  • __init__(self, model, gamma=None, lr=None)

    我们在这个函数中定义 DQN的 self.modelself.target_model , 同时,我们在该函数中定义超参数 gamma 以及 lr 。 这些超参数在 learn 函数中会被用到 。

    def __init__(self, model, gamma=None, lr=None):
        """ DQN algorithm
    
        Args:
            model (parl.Model): forward neural network representing the Q function.
            gamma (float): discounted factor for `accumulative` reward computation
            lr (float): learning rate.
        """
        self.model = model
        self.target_model = copy.deepcopy(model)
    
        assert isinstance(gamma, float)
        assert isinstance(lr, float)
    
        self.gamma = gamma
        self.lr = lr
    
        self.mse_loss = paddle.nn.MSELoss(reduction='mean')
        self.optimizer = paddle.optimizer.Adam(
            learning_rate=lr, parameters=self.model.parameters())
    
  • predict(self, obs)

    我们直接使用输入该函数的环境状态,并将该状态传输入 self.model 中,self.model 会输出预测的action value function

    def predict(self, obs):
        """ use self.model (Q function) to predict the action values
        """
        return self.model.value(obs)
    
  • learn(self, obs, action, reward, next_obs, terminal)

    learn 函数会根据当前的预测值和目标值输出当前的损失,并通过该损失进行反向传播更新 self.model 中的参数

    def learn(self, obs, action, reward, next_obs, terminal):
        """ update the Q function (self.model) with DQN algorithm
        """
        # Q
        pred_values = self.model.value(obs)
        action_dim = pred_values.shape[-1]
        action = paddle.squeeze(action, axis=-1)
        action_onehot = paddle.nn.functional.one_hot(
            action, num_classes=action_dim)
        pred_value = paddle.multiply(pred_values, action_onehot)
        pred_value = paddle.sum(pred_value, axis=1, keepdim=True)
    
        # target Q
        with paddle.no_grad():
            max_v = self.target_model.value(next_obs).max(1, keepdim=True)
            target = reward + (1 - terminal) * self.gamma * max_v
    
        loss = self.mse_loss(pred_value, target)
    
        # optimize
        self.optimizer.clear_grad()
        loss.backward()
        self.optimizer.step()
    
        return loss
    
  • sync_target(self)

    该函数同步 self.target_modelself.model 中的参数

    def sync_target(self):
    
        self.model.sync_weights_to(self.target_model)
    

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文