如何无循环的火炬张量总和?

发布于 2025-01-25 17:47:18 字数 250 浏览 2 评论 0原文

我有一系列垃圾箱的边界,我需要在这些垃圾箱内获得一笔价值。 现在看起来如下:

output = torch.zeros((16, 10)) #10 corresponds to the number of bins

for l in range(10):
   output[:,l] = data[:, bin_edges[l]:bin_edges[l+1]].sum(axis=-1)

是否有可能避免循环并改善性能?

I've got an array of bins' borders and I need to get a sum of values inside these bins.
Now it looks as follows:

output = torch.zeros((16, 10)) #10 corresponds to the number of bins

for l in range(10):
   output[:,l] = data[:, bin_edges[l]:bin_edges[l+1]].sum(axis=-1)

Is it possible to avoid loops and improve the performance?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

雪落纷纷 2025-02-01 17:47:18

通常,要通过矢量化优化代码,您希望构建一个单个大张量,您可以在该张量中计算单个操作中的结果。
但是在这里,您的垃圾箱的长度可能不同,因此您无法从中构建张量。

不过,这是时间序列处理中的通常情况,因此Pytorch有一些公用事业可以克服此问题,例如

使用该实用程序,我能够稍微优化该功能,但是差异取决于数据形状以及垃圾箱的数量和长度,有时性能甚至会降低。

请注意,pad_sequence假设您想从数据的第一个维度制作垃圾箱,并且从最后一个DIM制作垃圾箱,因此,如果您可以相应地重新组织数据,则优化会更好。

代码

实现

from itertools import pairwise
import random
import torch
from torch.nn.utils.rnn import pad_sequence


def bins_sum(x, edges):
    """ Your function (generalized a bit) """
    edges = [0, *edges, x.shape[-1]]
    bins = enumerate(pairwise(edges))
    num_bins = len(edges) - 1
    output = torch.zeros(*(x.shape[:-1]), num_bins)

    for bin_idx, (start, end) in bins:
        output[..., bin_idx] = x[..., start:end].sum(axis=-1)
    return output


def bins_sum_opti(x, edges):
    """ Trying to optimize using torch.nn.utils.rnn """
    x = x.movedim(-1, 0)
    edges = [0, *edges, x.shape[0]]
    xbins = [x[start:end] for start, end in pairwise(edges)]
    xbins_padded = pad_sequence(xbins)
    return xbins_padded.sum(dim=0).movedim(0, -1)


def get_data_bin_edges(data_shape, num_edges):
    data = torch.rand(*data_shape)
    bin_edges = sorted(random.sample(range(3, data_shape[-1] - 3), k=num_edges))
    return data, bin_edges

结果

断言,这两个功能都是等效的:

data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)

res1 = bins_sum(data, bin_edges)
res2 = bins_sum_opti(data, bin_edges)

assert torch.allclose(res1, res2)

不同形状和边缘的时间:

>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
35.8 µs ± 531 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.6 µs ± 546 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
67.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
41.1 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
43 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
33 µs ± 314 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
90.5 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
48.1 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Normally to optimize code by vectorization you would like to construct a single big tensor on which you compute the result in a single operation.
But here your bins might have different lengths, so you can't construct a tensor from that.

Though, that's a usual case in time-series processing, so PyTorch has some utilities to overcome this issue, such as torch.nn.utils.rnn.pad_sequence.

Using that utility I was able to optimize the function a bit, but the difference depends on the data shape and the number and length of bins, and sometimes performance even decreases.

Please note that pad_sequence assumes that you want to make bins from the first dimension of your data, and you make bins from the last dim, so the optimization would be better if you can reorganize your data accordingly.

Code

Implementations

from itertools import pairwise
import random
import torch
from torch.nn.utils.rnn import pad_sequence


def bins_sum(x, edges):
    """ Your function (generalized a bit) """
    edges = [0, *edges, x.shape[-1]]
    bins = enumerate(pairwise(edges))
    num_bins = len(edges) - 1
    output = torch.zeros(*(x.shape[:-1]), num_bins)

    for bin_idx, (start, end) in bins:
        output[..., bin_idx] = x[..., start:end].sum(axis=-1)
    return output


def bins_sum_opti(x, edges):
    """ Trying to optimize using torch.nn.utils.rnn """
    x = x.movedim(-1, 0)
    edges = [0, *edges, x.shape[0]]
    xbins = [x[start:end] for start, end in pairwise(edges)]
    xbins_padded = pad_sequence(xbins)
    return xbins_padded.sum(dim=0).movedim(0, -1)


def get_data_bin_edges(data_shape, num_edges):
    data = torch.rand(*data_shape)
    bin_edges = sorted(random.sample(range(3, data_shape[-1] - 3), k=num_edges))
    return data, bin_edges

Results

Assert that both functions are equivalent:

data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)

res1 = bins_sum(data, bin_edges)
res2 = bins_sum_opti(data, bin_edges)

assert torch.allclose(res1, res2)

Time for different shapes and edges:

>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
35.8 µs ± 531 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.6 µs ± 546 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
67.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
41.1 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
43 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
33 µs ± 314 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
90.5 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
48.1 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文