RuntimeError:目前只有用户显式创建的张量(图叶)支持深度复制协议

发布于 2025-01-17 12:15:29 字数 1405 浏览 3 评论 0原文

这是关于DeepCopy的错误,我应该如何做。 错误: target_encoder = copy.deepcopy(self.online_encoder)

runtimeerror:仅用户(图形)明确创建的张量

class Model(nn.Module):
    def __init__(
            self,
            model, # byol
            projection_size=256,
            pred_size = 256,
            projection_hidden_size=4096,
            moving_average_decay=0.99,
            use_momentum=True,
    ):
        super(SSL, self).__init__()

        self.online_encoder = Pre_model(model) # 256

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)


    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder


    def forward(self, x):
        anchors = x['anchor'].cuda(non_blocking = True)
        neighbors = x['neighbor'].cuda(non_blocking = True)

        online_anchor_proj = self.online_encoder(anchors) 
        online_neighbor_proj = self.online_encoder(neighbors)


        with torch.no_grad():
            target_online = self._get_target_encoder() if self.use_momentum else self.online_encoder
            target_anchor_proj= target_online(anchors)
            target_neighbor_proj = target_online(neighbors)

Here is the error about deepcopy, how should I do it.
error:
target_encoder = copy.deepcopy(self.online_encoder)

RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment

class Model(nn.Module):
    def __init__(
            self,
            model, # byol
            projection_size=256,
            pred_size = 256,
            projection_hidden_size=4096,
            moving_average_decay=0.99,
            use_momentum=True,
    ):
        super(SSL, self).__init__()

        self.online_encoder = Pre_model(model) # 256

        self.use_momentum = use_momentum
        self.target_encoder = None
        self.target_ema_updater = EMA(moving_average_decay)


    def _get_target_encoder(self):
        target_encoder = copy.deepcopy(self.online_encoder)
        set_requires_grad(target_encoder, False)
        return target_encoder


    def forward(self, x):
        anchors = x['anchor'].cuda(non_blocking = True)
        neighbors = x['neighbor'].cuda(non_blocking = True)

        online_anchor_proj = self.online_encoder(anchors) 
        online_neighbor_proj = self.online_encoder(neighbors)


        with torch.no_grad():
            target_online = self._get_target_encoder() if self.use_momentum else self.online_encoder
            target_anchor_proj= target_online(anchors)
            target_neighbor_proj = target_online(neighbors)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文