如何简化图像标准化功能?

发布于 2025-01-17 06:22:20 字数 347 浏览 1 评论 0原文

我有一个函数可以计算数据集的平均值和标准差。 有没有更简单的方法来做到这一点?因为计算需要一段时间。

def get_mean_std(loader):
  sum = 0
  sum_sq_err = 0
  for data, _ in loader:
    sum += torch.mean(data, dim=[0,2,3]) 
    sum_sq_err += torch.mean(data**2, dim=[0,2,3])
  mean = sum/len(loader)
  std = (sum_sq_err/(len(loader)) - mean**2)**0.5
  return mean, std

I have a function to calculate the mean and standard deviation of my dataset.
Is there a simpler way to do this? As it takes a while to compute.

def get_mean_std(loader):
  sum = 0
  sum_sq_err = 0
  for data, _ in loader:
    sum += torch.mean(data, dim=[0,2,3]) 
    sum_sq_err += torch.mean(data**2, dim=[0,2,3])
  mean = sum/len(loader)
  std = (sum_sq_err/(len(loader)) - mean**2)**0.5
  return mean, std

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

日久见人心 2025-01-24 06:22:20

请注意,这种方法一般来说甚至是不正确的,因为集合的平均值通常不是某些子集的平均值的平均值(这是当所有子集具有相同长度时,但情况可能是也可能不是)这里)。

假设每个批次的大小相同,我要做的就是在循环中调用 torch.sum ,而不是已经将其累加到总和中,将其附加到列表中,然后减少通过 torch.sum + 之后的除法。请注意,torch.sum 实现了一种非常重要的算法,该算法通常比简单的迭代和更精确。

Note that this approach is not even correct in general, as the mean of a set is not the mean of the means of some subsets in general (it is when all the subsets have the same length, but that may or may not be the case here).

Provided that every batch is of the same size, what I would do is to call torch.sum in the loop, but rather than already accumulating it into a sum, appending it into a list, and then reducing it via torch.sum + a division afterwards. Note that torch.sum implements a highly non-trivial algorithm that is more precise in general than the naïve iterative sum.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文