What is the most efficient way to make a method that is able to process single and multi-dimensional tensors?

Time:09-10

I was using PyTorch and realized that for a linear layer you can pass not only 1-D tensors but also multi-dimensional tensors, as long as the last dimension matches. (Related question: Multi dimensional inputs in pytorch Linear method?)

I tried looping over each item, but is that what PyTorch does? I'm having trouble seeing how you would write the loops for an arbitrary number of dimensions; maybe recursion, but that seems messy. What is the most efficient way to implement this?

CodePudding user response:

I tried looping over each item, but is that what PyTorch does?

The short answer is yes, loops are used, but it's more complicated than you probably think. If input is a 2-D tensor (a matrix), then the output of a linear operation is computed as input @ weight.T + bias using an external BLAS library's GEMM routine. Otherwise it uses torch.matmul(input, weight.T) + bias, which relies on broadcasting semantics to compute a batched version of the operation. Broadcasting is a semantic, not an implementation, so how the broadcasting is performed is backend-dependent. Ultimately, some form of looping combined with parallel processing is used in most of these implementations.
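To see the two formulas side by side (this is just a sanity check in plain PyTorch, not a peek at the internals), you can compare nn.functional.linear against the manual broadcasted computation:

import torch
import torch.nn.functional as F

x = torch.randn(4, 5, 16)      # [..., in_features]: only the last dim has to match
weight = torch.randn(32, 16)   # [out_features, in_features]
bias = torch.randn(32)         # [out_features]

out = F.linear(x, weight, bias)              # shape [4, 5, 32]
manual = torch.matmul(x, weight.T) + bias    # broadcasted matmul + bias

print(out.shape)                               # torch.Size([4, 5, 32])
print(torch.allclose(out, manual, atol=1e-6))  # True (up to float rounding)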

To go a little deeper, let's take a look at the PyTorch implementation of the linear layer. This quickly leads down some rabbit holes, since PyTorch uses different backend libraries to perform linear algebra operations efficiently on the available hardware (libraries like oneAPI, Intel MKL, or MAGMA), but perhaps understanding some of the details can help.

Starting at the C++ entrypoint to nn.functional.linear:

Tensor linear(const Tensor& input, const Tensor& weight, const Tensor& bias) {
  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, bias);
  }

  if (input.dim() == 2 && bias.defined()) {
    // Fused op is marginally faster.
    return at::addmm(bias, input, weight.t());
  }
  auto output = at::matmul(input, weight.t());
  if (bias.defined()) {
    output.add_(bias);
  }
  return output;
}

There are three cases here.

  1. input.is_mkldnn(). This condition holds if the input tensor is in the MKL-DNN format (Tensor.to_mkldnn) and makes PyTorch use the at::mkldnn_linear function, which in turn calls into ideep, which in turn calls the oneDNN library (previously known as Intel MKL-DNN), which ultimately selects a specific general matrix-matrix multiplication (GEMM) routine depending on platform and data types. The simplest implementation is the reference implementation, and there we can see that a parallel-for loop is used (the anonymous function it runs contains a quadruply nested for-loop). In practice the reference implementation probably isn't used; instead, you would most likely be running the x86-optimized version, compiled with the Xbyak JIT assembler to produce highly optimized code. I'm not going to pretend to follow all the details of the optimized code, but efficient GEMM is a heavily studied topic that I only have passing knowledge of. (A naive Python sketch of this nested-loop structure is shown after this list.)

  2. input.dim() == 2 && bias.defined(). This condition means that input is a 2-D tensor (shape [B, M]) and bias is defined. In this case PyTorch uses the addmm function, which efficiently computes the output as input @ weight.T + bias, where @ is matrix multiplication (a short addmm example is shown after this list). There are multiple implementations of addmm registered in PyTorch depending on the tensor types involved. The dense-CPU-specific version is here; it eventually calls into an external BLAS library's GEMM subroutine. The backend used is likely Intel MKL, but you can check with print(torch.__config__.parallel_info()). Whichever BLAS implementation is used, it is certainly a highly optimized matrix multiplication, similar to the oneDNN implementation, probably using multi-threading and carefully optimized kernels.

  3. If neither of the previous two conditions is met, then PyTorch uses the torch.matmul function, which performs a broadcasted version of input @ weight.T where input has shape [..., M], producing a tensor of shape [..., N]. As with addmm, there are multiple implementations of this function depending on the tensor types, but an external library that uses parallelization and optimized matrix-multiplication subroutines is ultimately invoked. After the broadcasted matrix multiplication, a broadcasted add_ operation is used to add the bias term (if bias is defined). The flatten-and-reshape sketch after this list shows an equivalent way to compute this case without any explicit Python-level looping.
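To make case 1's "loops under the hood" concrete, here is a deliberately naive Python sketch of what a reference linear kernel computes. It is only an illustration of the nested-loop structure (the function name is mine); oneDNN's actual reference kernel is C++ with a parallel_for over the outer loops and an extra loop level, and the optimized kernels are JIT-generated assembly.

import torch

def naive_linear(x, weight, bias):
    # Reference-style nested loops. x: [B, M], weight: [N, M], bias: [N] -> out: [B, N]
    B, M = x.shape
    N, _ = weight.shape
    out = torch.empty(B, N)
    for b in range(B):                 # each input row
        for n in range(N):             # each output feature
            acc = float(bias[n])
            for m in range(M):         # reduction over input features
                acc += float(x[b, m]) * float(weight[n, m])
            out[b, n] = acc
    return out

x = torch.randn(3, 8); w = torch.randn(4, 8); b = torch.randn(4)
print(torch.allclose(naive_linear(x, w, b),
                     torch.nn.functional.linear(x, w, b), atol=1e-5))  # True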
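Case 2's fused path can also be reproduced from Python: for a 2-D input, torch.addmm(bias, input, weight.t()) computes bias + input @ weight.T in a single call and matches nn.functional.linear:

import torch
import torch.nn.functional as F

x = torch.randn(8, 16)    # 2-D input: [batch, in_features]
w = torch.randn(32, 16)
b = torch.randn(32)

fused = torch.addmm(b, x, w.t())                 # bias + x @ w.T in one fused op
print(torch.allclose(fused, F.linear(x, w, b)))  # True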
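Finally, to answer the original question directly: no Python-level loops or recursion are needed for N-dimensional inputs. A common trick, equivalent in effect to the broadcasted matmul in case 3, is to flatten all leading dimensions into one batch dimension, do a single 2-D matrix multiply, and reshape back. The helper below is a hypothetical illustration (linear_nd is my own name), not PyTorch's implementation:

import torch
import torch.nn.functional as F

def linear_nd(x, weight, bias=None):
    # x: [..., in_features], weight: [out_features, in_features], bias: [out_features]
    lead_shape = x.shape[:-1]                # everything except in_features
    x2d = x.reshape(-1, x.shape[-1])         # collapse leading dims into one batch dim
    out2d = x2d @ weight.t()                 # a single 2-D GEMM handles every "row" at once
    if bias is not None:
        out2d = out2d + bias
    return out2d.reshape(*lead_shape, weight.shape[0])

x = torch.randn(2, 3, 4, 16)
w = torch.randn(32, 16)
b = torch.randn(32)
print(torch.allclose(linear_nd(x, w, b), F.linear(x, w, b), atol=1e-6))  # True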
