How to sum torch tensors without looping?

Time:05-04

I have an array of bin edges, and I need the sum of the values inside each bin. Currently my code looks like this:

output = torch.zeros((16, 10))  # 10 corresponds to the number of bins

for l in range(10):
    output[:, l] = data[:, bin_edges[l]:bin_edges[l + 1]].sum(axis=-1)

Is it possible to avoid loops and improve the performance?

Answer:

Normally, to optimize code through vectorization, you build a single large tensor and compute the result in one operation. Here, however, your bins may have different lengths, so you can't stack them into a single tensor directly.
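To illustrate the "single big tensor" idea: if every bin happened to have the same width (an assumption that does not hold for arbitrary bin_edges), a reshape would expose the bins as an extra dimension and one sum would replace the loop. A minimal sketch with hypothetical sizes (16 rows, 40 columns, 10 bins of width 4):

```python
import torch

# Hypothetical data: 16 rows, 40 columns, split into 10 bins of equal width 4.
data = torch.rand(16, 40)
num_bins, width = 10, 4

# With equal-width bins, reshape the last dim (40) into (10, 4) so that each
# bin becomes a row of the new trailing dimension, then sum that dimension.
binned = data.reshape(16, num_bins, width).sum(dim=-1)  # shape (16, 10)

# Reference: the explicit loop from the question, with edges 0, 4, ..., 40.
bin_edges = list(range(0, 41, width))
looped = torch.zeros(16, num_bins)
for l in range(num_bins):
    looped[:, l] = data[:, bin_edges[l]:bin_edges[l + 1]].sum(dim=-1)

assert torch.allclose(binned, looped)
```

This only works because every bin has the same length; with unequal bins there is no such reshape, which is the problem the rest of this answer addresses.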

This is a common situation in time-series processing, though, and PyTorch provides utilities to handle it, such as torch.nn.utils.rnn.pad_sequence.

Using that utility, I was able to speed up the function somewhat, but the gain depends on the data shape and on the number and length of the bins, and in some cases performance even gets worse.

Note that pad_sequence expects the sequences (here, the bins) to run along the first dimension of each tensor, whereas you build bins along the last dimension. The optimization would likely work better if you can reorganize your data so that the binned dimension comes first.
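To make the padding step concrete, here is a small sketch of what pad_sequence does with three hypothetical bins of different lengths (the tensors a, b and c are made up for illustration). With the default batch_first=False and padding value 0, the output has shape (max_len, num_sequences, *rest), so summing over dim 0 yields one sum per bin:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Three hypothetical "bins" of lengths 3, 1 and 2 along dim 0, each with 2 features.
a = torch.ones(3, 2)
b = torch.ones(1, 2) * 2
c = torch.ones(2, 2) * 3

# pad_sequence zero-pads each tensor along its first dimension to the longest
# length and stacks them along a new dim 1: shape (max_len, num_bins, 2).
padded = pad_sequence([a, b, c])
print(padded.shape)  # torch.Size([3, 3, 2])

# Padding is zero, so summing over dim 0 gives each bin's sum (one row per bin).
sums = padded.sum(dim=0)
print(sums)
```

This is the same mechanism bins_sum_opti relies on, after moving the binned dimension to the front with movedim.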

Code

Implementations

from itertools import pairwise
import random
import torch
from torch.nn.utils.rnn import pad_sequence


def bins_sum(x, edges):
    """ Your function (generalized a bit) """
    edges = [0, *edges, x.shape[-1]]
    bins = enumerate(pairwise(edges))
    num_bins = len(edges) - 1
    output = torch.zeros(*(x.shape[:-1]), num_bins)

    for bin_idx, (start, end) in bins:
        output[..., bin_idx] = x[..., start:end].sum(axis=-1)
    return output


def bins_sum_opti(x, edges):
    """ Trying to optimize using torch.nn.utils.rnn """
    x = x.movedim(-1, 0)
    edges = [0, *edges, x.shape[0]]
    xbins = [x[start:end] for start, end in pairwise(edges)]
    xbins_padded = pad_sequence(xbins)
    return xbins_padded.sum(dim=0).movedim(0, -1)


def get_data_bin_edges(data_shape, num_edges):
    data = torch.rand(*data_shape)
    bin_edges = sorted(random.sample(range(3, data_shape[-1] - 3), k=num_edges))
    return data, bin_edges

Results

Assert that both functions are equivalent:

data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)

res1 = bins_sum(data, bin_edges)
res2 = bins_sum_opti(data, bin_edges)

assert torch.allclose(res1, res2)

Timings for different data shapes and numbers of edges:

>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
35.8 µs ± 531 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
27.6 µs ± 546 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
67.4 µs ± 1.12 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
41.1 µs ± 107 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=3)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
43 µs ± 195 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
33 µs ± 314 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
>>> data, bin_edges = get_data_bin_edges(data_shape=(10, 20, 30), num_edges=7)
>>> %timeit bins_sum(data, bin_edges)
>>> %timeit bins_sum_opti(data, bin_edges)
90.5 µs ± 583 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
48.1 µs ± 134 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
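A different fully vectorized technique, not the one benchmarked above, avoids padding altogether: assign every column its bin index with torch.bucketize, then accumulate all columns into their bins at once with Tensor.index_add_. The sketch below assumes edges are the sorted interior edges, as in bins_sum; the name bins_sum_index_add is hypothetical, and its speed relative to the two functions above has not been measured here.

```python
import torch


def bins_sum_index_add(x, edges):
    """Sketch: sum x along its last dim into bins via per-column bin ids.

    `edges` are the sorted interior bin edges, as in bins_sum.
    """
    n = x.shape[-1]
    boundaries = torch.tensor(edges, dtype=torch.long, device=x.device)
    # With right=True, column j maps to the number of edges <= j,
    # which is exactly the index of the bin [start, end) containing j.
    bin_ids = torch.bucketize(torch.arange(n, device=x.device), boundaries, right=True)
    out = torch.zeros(*x.shape[:-1], len(edges) + 1, dtype=x.dtype, device=x.device)
    # Add each column of x into the output column given by its bin id.
    return out.index_add_(x.dim() - 1, bin_ids, x)


data = torch.rand(10, 20)
result = bins_sum_index_add(data, [3, 8, 15])  # shape (10, 4)
```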

Answer:

Let's do an exercise:

>>> import torch
>>> v = torch.rand((16, 10))
>>> v.shape
torch.Size([16, 10])
>>> v
tensor([[5.2187e-01, 9.3389e-01, 7.7927e-01, 5.3492e-01, 6.0553e-02, 6.1225e-01,
         4.5279e-01, 9.4369e-01, 7.9510e-01, 1.6277e-01],
        [4.9732e-01, 6.9812e-01, 7.6995e-01, 1.8056e-01, 7.0234e-01, 7.4067e-01,
         3.1072e-01, 1.0065e-01, 9.9970e-01, 6.4164e-01],
        [3.9077e-01, 2.8244e-01, 8.3243e-01, 5.0246e-01, 4.5566e-01, 5.0806e-01,
         6.5971e-01, 8.1989e-01, 5.1700e-01, 4.2624e-01],
        [3.3785e-01, 3.8474e-01, 4.8132e-02, 9.3515e-01, 1.3121e-01, 4.8328e-01,
         8.2264e-01, 3.3031e-01, 4.1098e-01, 2.8584e-01],
        [3.7890e-01, 5.0190e-02, 4.5153e-02, 6.7827e-02, 9.7752e-01, 5.7650e-01,
         2.5389e-01, 9.2771e-02, 9.6251e-01, 9.6351e-01],
        [8.3147e-01, 5.9653e-01, 5.6390e-01, 8.1196e-01, 8.5893e-01, 8.1458e-02,
         4.9451e-01, 7.1894e-01, 9.5709e-01, 1.4245e-01],
        [7.4089e-01, 1.9195e-01, 6.6763e-01, 7.7315e-01, 7.1233e-01, 5.4214e-01,
         1.1453e-01, 4.0559e-01, 9.7582e-01, 9.0780e-01],
        [9.9809e-02, 9.4316e-01, 5.4488e-01, 6.0931e-01, 7.9540e-01, 3.0222e-01,
         9.5903e-01, 9.6852e-01, 5.6619e-01, 2.8973e-01],
        [2.3243e-01, 9.3186e-01, 5.1603e-01, 9.0458e-02, 3.8374e-01, 1.5755e-01,
         3.9596e-01, 5.0016e-01, 3.0587e-01, 9.1070e-01],
        [8.7504e-01, 2.7716e-01, 8.2512e-01, 1.9982e-01, 4.2982e-01, 3.0412e-01,
         1.0716e-01, 8.8398e-02, 4.3348e-01, 7.3884e-01],
        [5.0195e-01, 4.5738e-01, 3.0292e-01, 6.6867e-01, 5.8694e-01, 9.7289e-01,
         2.9898e-01, 9.7225e-01, 1.2917e-01, 8.0031e-01],
        [5.2466e-01, 6.0844e-01, 2.6522e-01, 8.5786e-01, 3.6592e-01, 3.6974e-01,
         9.5623e-01, 7.7282e-02, 5.8547e-01, 8.7895e-01],
        [6.6620e-01, 1.4502e-01, 2.9290e-01, 2.6731e-01, 1.2170e-01, 6.2980e-01,
         9.3782e-01, 7.9795e-01, 8.7459e-01, 9.8554e-02],
        [7.2040e-01, 5.6500e-01, 4.6514e-01, 2.6318e-01, 3.1107e-01, 4.1578e-01,
         1.4852e-01, 4.3629e-01, 8.1342e-04, 6.8361e-01],
        [3.3129e-01, 5.7071e-02, 5.1649e-01, 2.4868e-02, 2.5514e-01, 6.2073e-02,
         6.2700e-01, 6.9716e-01, 3.7102e-01, 6.3859e-01],
        [3.1294e-01, 6.3655e-01, 9.8621e-01, 4.6491e-01, 8.2948e-01, 6.1694e-02,
         2.8140e-01, 5.6612e-01, 2.0409e-01, 8.5010e-01]])

Calling sum() with no arguments adds up all the values in the matrix:

>>> v.sum()
tensor(81.0404)
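The full sum is a 0-dimensional tensor, and it equals reducing both dimensions explicitly. A small sketch (with fresh random data, so the values differ from the session above):

```python
import torch

v = torch.rand((16, 10))

# With no arguments, sum() reduces over every dimension and returns a 0-dim tensor.
total = v.sum()
print(total.shape)  # torch.Size([])

# Same result when both dimensions are reduced explicitly with a tuple of dims...
assert torch.allclose(total, v.sum(dim=(0, 1)))
# ...or when the per-column sums are summed afterwards.
assert torch.allclose(total, v.sum(0).sum())
```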

Summing over dimension 0 gives the sum of each "column":

>>> v.sum(0)
tensor([7.9638, 7.7595, 8.4214, 7.2524, 7.9778, 6.8202, 7.8209, 8.5160, 9.0889,
        9.4196])

>>> v.sum(0).shape
torch.Size([10])

Summing over dimension 1 gives the sum of each "row":

>>> v.sum(1)
tensor([5.7971, 5.6417, 5.3947, 4.1701, 4.3688, 6.0572, 6.0318, 6.0783, 4.4247,
        4.2790, 5.6915, 5.4898, 4.8318, 4.0098, 3.5807, 5.1935])

>>> v.sum(1).shape
torch.Size([16])

Most likely, the sum you want is one of the above.


This is probably not what you expect, but now for some more fun: there are other "dimensions" you can sum over here.

You can also sum over dimensions -1 and -2, which yield the same values as v.sum(1) and v.sum(0), respectively.

>>> v.sum(-1)
tensor([5.7971, 5.6417, 5.3947, 4.1701, 4.3688, 6.0572, 6.0318, 6.0783, 4.4247,
        4.2790, 5.6915, 5.4898, 4.8318, 4.0098, 3.5807, 5.1935])

>>> v.sum(-2)
tensor([7.9638, 7.7595, 8.4214, 7.2524, 7.9778, 6.8202, 7.8209, 8.5160, 9.0889,
        9.4196])
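The equivalence can be checked directly; a quick sketch on a fresh random tensor of the same shape:

```python
import torch

v = torch.rand((16, 10))

# For a 2-D tensor, dim=-1 is the last dimension (1) and dim=-2 is the first (0).
assert torch.allclose(v.sum(-1), v.sum(1))  # per-row sums, shape (16,)
assert torch.allclose(v.sum(-2), v.sum(0))  # per-column sums, shape (10,)
```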

Q: But why -1 and -2?

This is easy to miss in the documentation

One thing the torch.sum documentation does not mention explicitly: you can sum across the last dimension by using -1, across the second-to-last dimension with -2, and so on.

Cf. https://stackoverflow.com/a/57040945/610569
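The rule follows Python's negative indexing: a negative dim d refers to dimension d + x.dim(). A sketch on a hypothetical 3-D tensor:

```python
import torch

x = torch.rand(2, 3, 4)

# Every negative dim d in [-x.dim(), -1] names the same axis as d + x.dim().
for d in range(-x.dim(), 0):
    assert torch.allclose(x.sum(d), x.sum(d + x.dim()))

print(x.sum(-1).shape)  # torch.Size([2, 3]): last dim removed
print(x.sum(-3).shape)  # torch.Size([3, 4]): first dim removed
```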


By the way, if you want more than the simple sum of each bin, for example a cumulative sum, see torch.cumsum: https://pytorch.org/docs/stable/generated/torch.cumsum.html
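torch.cumsum can also remove the loop from the original question: after a prefix sum along the last dimension, the sum over a bin [start, end) is simply prefix[end] - prefix[start], so all bins come out of one indexing operation. A minimal sketch, assuming bin_edges contains every boundary including 0 and data.shape[-1] (as the question's loop implies); the name bins_sum_cumsum is hypothetical. With float32 and very long rows, prefix sums can accumulate some rounding error.

```python
import torch
import torch.nn.functional as F


def bins_sum_cumsum(data, bin_edges):
    """Sum data[..., bin_edges[l]:bin_edges[l + 1]] for every l, without a loop.

    Assumes bin_edges is sorted and includes both 0 and data.shape[-1].
    """
    # Prefix sums with a leading zero: prefix[..., k] == data[..., :k].sum(-1).
    prefix = F.pad(data.cumsum(dim=-1), (1, 0))
    idx = torch.as_tensor(bin_edges, device=data.device)
    # Sum over [start, end) is prefix[end] - prefix[start], for all bins at once.
    return prefix[..., idx[1:]] - prefix[..., idx[:-1]]


data = torch.rand(16, 40)
bin_edges = [0, 3, 8, 12, 20, 25, 27, 30, 33, 37, 40]  # 11 edges -> 10 bins
output = bins_sum_cumsum(data, bin_edges)  # shape (16, 10)
```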
