返回介绍

parl.Model

发布于 2024-06-23 17:58:49 字数 3760 浏览 0 评论 0 收藏 0

class Model(name_scope=None, dtype='float32')[源代码]
alias: parl.Model alias: parl.core.paddle.agent.Model Model is a base class of PARL for the neural network.

A Model is usually a policy or Q-value function, which predicts an action or an estimate according to the environmental observation.

To use the PaddlePaddle2.0 backend model, user needs to call

super(Model, self).__init__() at the beginning of __init__ function.

Model supports duplicating a Model instance in a pythonic way: copied_model = copy.deepcopy(model)

Example:

import parl
import paddle.nn as nn

class Policy(parl.Model):
    def __init__(self):
        super(Policy, self).__init__()
        self.fc = nn.Linear(input_dim=100, output_dim=32)

    def policy(self, obs):
        out = self.fc(obs)
        return out

policy = Policy()
copied_policy = copy.deepcopy(policy)
变量:

model_id (str) – each model instance has its unique model_id.

Public Functions:
  • sync_weights_to: synchronize parameters of the current model

to another model. - get_weights: return a list containing all the parameters of the current model. - set_weights: copy parameters from set_weights() to the model. - forward: define the computations of a neural network. Should be overridden by all subclasses.

get_weights()[源代码]

Returns a Python dict containing parameters of current model.

返回:

a Python dict containing the parameters of current model.

set_weights(weights)[源代码]

Copy parameters from set_weights() to the model.

参数:

weights (dict) – a Python dict containing the parameters.

sync_weights_to(target_model, decay=0.0)[源代码]

Synchronize parameters of current model to another model.

target_model_weights = decay * target_model_weights
  • (1 - decay) * current_model_weights

参数:
  • target_model (parl.Model) – an instance of Model that has the same neural network architecture as the current model.

  • decay (float) – the rate of decline in copying parameters. 0 if no parameters decay when synchronizing the parameters.

Example:

import copy
# create a model that has the same neural network structures.
target_model = copy.deepcopy(model)

# after initilizing the parameters ...
model.sync_weights_to(target_mdodel)

备注

Before calling sync_weights_to, parameters of the model must have been initialized.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文