RuntimeError:目前只有用户显式创建的张量(图叶)支持深度复制协议
这是关于DeepCopy的错误,我应该如何做。 错误: target_encoder = copy.deepcopy(self.online_encoder)
runtimeerror:仅用户(图形)明确创建的张量
class Model(nn.Module):
def __init__(
self,
model, # byol
projection_size=256,
pred_size = 256,
projection_hidden_size=4096,
moving_average_decay=0.99,
use_momentum=True,
):
super(SSL, self).__init__()
self.online_encoder = Pre_model(model) # 256
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def forward(self, x):
anchors = x['anchor'].cuda(non_blocking = True)
neighbors = x['neighbor'].cuda(non_blocking = True)
online_anchor_proj = self.online_encoder(anchors)
online_neighbor_proj = self.online_encoder(neighbors)
with torch.no_grad():
target_online = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_anchor_proj= target_online(anchors)
target_neighbor_proj = target_online(neighbors)
Here is the error about deepcopy, how should I do it.
error:
target_encoder = copy.deepcopy(self.online_encoder)
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
class Model(nn.Module):
def __init__(
self,
model, # byol
projection_size=256,
pred_size = 256,
projection_hidden_size=4096,
moving_average_decay=0.99,
use_momentum=True,
):
super(SSL, self).__init__()
self.online_encoder = Pre_model(model) # 256
self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder
def forward(self, x):
anchors = x['anchor'].cuda(non_blocking = True)
neighbors = x['neighbor'].cuda(non_blocking = True)
online_anchor_proj = self.online_encoder(anchors)
online_neighbor_proj = self.online_encoder(neighbors)
with torch.no_grad():
target_online = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_anchor_proj= target_online(anchors)
target_neighbor_proj = target_online(neighbors)
如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

绑定邮箱获取回复消息
由于您还没有绑定你的真实邮箱,如果其他用户或者作者回复了您的评论,将不能在第一时间通知您!
发布评论