pytorch的平均计算图

发布于 2025-02-11 15:45:20 字数 1281 浏览 2 评论 0原文

我是在做适当的事情,因为我应对以下情况:在培训期间,我需要在Pytorch的远期功能中进行循环。此后,我的代码示例:

def forward(self, input_sinogram, sos):
    [variables declaration...]
    # stack = torch.zeros(batch_size, self.nb_elements * self.nb_elements, 
    for tx_id in range(0, self.nb_elements, self.decimation_factor):
      [variables declaration...]
      ima = self.netFeaturesExtractor(ima)
      stack[:, id_stack, :, :] = ima
      id_stack += 1

这样做,Pytorch计算为每次迭代构建计算图,并填充内存。如果迭代太多,它会占用太多的内存。因此,我尝试了以下实现:

def forward(self, input_sinogram, sos):
    [variables declaration...]
    # stack = torch.zeros(batch_size, self.nb_elements * self.nb_elements, 
    for tx_id in range(0, self.nb_elements, self.decimation_factor):
       if tx_id == self.id_no_frz:
          self.auto_grad(True)
       else:
          self.auto_grad(False)
       [variables declaration...]
       ima = self.netFeaturesExtractor(ima)
       stack[:, id_stack, :, :] = ima
       id_stack += 1

def auto_grad(self, freeze):
    """ Freeze network."""

    for param in self.netFeaturesExtractor.parameters():
      param.requires_grad = freeze

我的想法是仅针对某些迭代构建计算图。实际上,它行不通,该模型在几个时期后会收敛。

我想知道是否有可能平均每次迭代的计算图?这样,反向传播将更快,并考虑所有数据。这也将避免记忆问题。我在这个主题上没有找到任何东西。

感谢您的帮助!

I'am writting to due because I cope with the following situation: during training, I need a for loop in pytorch's forward function. Hereafter a sample of my code:

def forward(self, input_sinogram, sos):
    [variables declaration...]
    # stack = torch.zeros(batch_size, self.nb_elements * self.nb_elements, 
    for tx_id in range(0, self.nb_elements, self.decimation_factor):
      [variables declaration...]
      ima = self.netFeaturesExtractor(ima)
      stack[:, id_stack, :, :] = ima
      id_stack += 1

Doing that, pytorch computed build the computational graph for each iteration, and fill the memory. If there are too many iterations, it takes up too much memory. Thus I tried the following implementation:

def forward(self, input_sinogram, sos):
    [variables declaration...]
    # stack = torch.zeros(batch_size, self.nb_elements * self.nb_elements, 
    for tx_id in range(0, self.nb_elements, self.decimation_factor):
       if tx_id == self.id_no_frz:
          self.auto_grad(True)
       else:
          self.auto_grad(False)
       [variables declaration...]
       ima = self.netFeaturesExtractor(ima)
       stack[:, id_stack, :, :] = ima
       id_stack += 1

def auto_grad(self, freeze):
    """ Freeze network."""

    for param in self.netFeaturesExtractor.parameters():
      param.requires_grad = freeze

My idea was to build a computational graph only for some iterations. In practice it doesn't work, the model converges after a few epochs.

I a wonder if it is possible to average the computational graph of each iteration? That way, the backpropagation would be faster and take all the data into account. It would also avoid memory issues. I didn't find anything on this topic.

Thank you for your help!

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(2

挥剑断情 2025-02-18 15:45:20

如果认为Pytorch中的“计算图”有些混乱。实际上,这是一个正式对一组操作员进行推断并通过它进行反向传播过程的概念。将其视为从输入和参数执行的操作流,以输出模型的不同结果。

该图并未保存在内存中的内存中,在特定位置。取而代之的是,它是在动态上构建的,在每个操作pytorch都会缓存一些张量并在这些张量节点上引入回调功能,以便以后对其进行渐变计算。该回调函数的形式为torch.tensor.grad_fn属性。该设计的目的是使其扩展,而不必在推断之前依靠预先计算的图。

即使您确实可以访问这种图形,我也不确定您的“图累积”过程对您来说是否清楚。请记住,在这一点上,该模型的节点尚未执行传播操作。 您会用这些节点做什么?

您只能对梯度或权重进行此类操作:

  • 梯度积累允许重量在几种推理上积累梯度更新这些权重并清除梯度缓存之前的迭代。

  • 重量平均这将进入模型nodembling 在多个培训中平均模型的重量值。

If think there is some confusion as to what a "computation graph" is in PyTorch. This in fact an concept to formalise the process of doing inferences on a set of operators and backpropagating through it. Think of it as the flow of operations that have been performed from the inputs and parameters to output the different results of your model.

This graph is not saved in memory in memory at a, specific location. Instead it is built on the fly, after every operation PyTorch caches some tensors and introduces callback functions on those tensor nodes in order to perform gradient computation on them later on. That callback function is in the form of the torch.Tensor.grad_fn attribute. The point of this design is to make it scale and not have to rely on a precomputed graph before doing inference.

Even if you did have access to this kind of graph I am not sure the "graph accumulation" process is clear to you. Keep in mind that at this point, no propagation operation has been performed on the nodes of that model. What would you be doing with those nodes?

You can only do such operations on gradients or weights themselves:

  • gradient accumulation is allowing the weights to accumulate gradients over several inference iterations before updating those weights and clearing the gradient caches.

  • weight average this comes into model ensembling whereby you are averaging weight values of your model across multiple trainings.

定格我的天空 2025-02-18 15:45:20

根据pytorch doc ://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#computational-groppational-groppational-groph“ rel =“ nofollow noreferrer”> this ,示例中导致内存问题的原因可能是示例中的内存问题,可能是向前通行证中的tensors tensors tensor ,而不是恕我直言,而是所谓的“计算图”。根据我的经验,您无法直接在Pytorch中访问计算图。在 @Ivan的建议中,梯度积累可能与您的情况最相关,其中使用此模式:

optimizer.zero_grad()
loss = net(batch1)
loss.backward()
loss = net(batch2)
loss.backward()
optimizer.step()

这在数学上等同于(例如,答案):

optimizer.zero_grad()
loss1 = net(batch1)
loss2 = net(batch2)
loss = loss1 + loss2
loss.backward()
optimizer.step()

但是更好,因为通过调用backward()在每次向前传递后,向前传球中的加速张量被释放,从而避免了内存问题。如果我的理解是正确的,那么这正是您“平均计算图”的含义,因为将损失累加(又称平均),每个损失附加了梯度。

但是,如果您要做的事情比仅增加损失更复杂的事情,那么解决方案可能会有所不同。经验法则是尽快致电向后(),例如,两到三个正向通行证后最多都会致电。在编写此答案时,我面临着类似的问题,可以简化计算以下损失的向后:

损失= log sigmoid(net(x1) + net(x2) + ... + net(x_n))

其中n可能非常大。如果我进行了n向前传递,添加输出,应用的logsigmoid,并且dod backward(),则会出现错误的错误。相反,请注意,链条规则

d/dp损失= d/dz(logsigmoid(z))(d/dp net(x1) + d/dp net(x2) + ... + d/dp net(x_n))

>

其中p 是网络参数,z = net(x1) + net(x2) + ... + net(x_n)。因此,我首先积累了如上所述的梯度,最后,将梯度乘以d/dz(logSigmoid(z))= 1 -sigmoid(z),给出正确的d /DP损失

最后,有 torch.utils.utils.checkpoint.checkpoint.checkpoint.checkpointPytorch,通过交易来计算内存来减轻内存问题。您可能会发现它有用。但是,在我的用例中,这没有那么有用。

According to pytorch doc this and this, what is causing memory issue in your example might be the cached tensors during forward pass, instead of, IMHO, the so-called "computation graph". From my experience, you don't have access to the computation graph directly in pytorch. Among @Ivan's suggestions, gradient accumulation could be most related to your case, wherein this pattern is used:

optimizer.zero_grad()
loss = net(batch1)
loss.backward()
loss = net(batch2)
loss.backward()
optimizer.step()

This is mathematically equivalent to (e.g.,this answer):

optimizer.zero_grad()
loss1 = net(batch1)
loss2 = net(batch2)
loss = loss1 + loss2
loss.backward()
optimizer.step()

but is better, in that by calling backward() after every forward pass, the cached tensors in the forward pass are freed, thus avoiding memory issue. If my understanding is correct, this is exactly what you mean by "averaging the computation graph", since adding up the losses also adding up (a.k.a. averaging) the gradient attached to each loss.

However, if you are doing something more sophisticated than merely adding up the losses, the solution might differ. The rule of thumb is to call backward() as soon as possible, e.g. at most after two or three forward passes. While writing this answer, I'm facing a similar problem, which can be simplified to computing the backward of the following loss:

loss = log sigmoid(net(x1) + net(x2) + ... + net(x_N))

where N is potentially very large. If I did N forward passes, added up the outputs, applied logsigmoid, and did backward(), there would be out-of-memory error. Instead, notice that by chain rule:

d/dp loss = d/dz(logsigmoid(z)) (d/dp net(x1) + d/dp net(x2) + ... + d/dp net(x_N))

where p is the network parameters and z = net(x1) + net(x2) + ... + net(x_N). Therefore, I first accumulate gradient as mentioned above, and in the end, multiply the gradient by d/dz(logsigmoid(z)) = 1 - sigmoid(z), giving the correct d/dp loss.

Finally, there's torch.utils.checkpoint.checkpoint in pytorch, which alleviates memory issue by trading compute for memory. You may find it useful. In my use case it's not that useful, though.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文