How can I apply a linear transformation on sparse matrix in PyTorch?

Time:08-28

In PyTorch, nn.Linear applies a linear transformation to the incoming data:

y = WA + b

In this formula, W and b are the learnable parameters and A is my input data matrix. In my case, A is too large to fit in RAM as a dense matrix, so I store it as a sparse matrix. Is it possible to perform this operation on sparse matrices in PyTorch?

CodePudding user response:

This is possible in PyTorch using sparse matrix multiplication. In your case, I think you want something like:

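>>> import torch
>>> i = [[0, 1, 1],
...      [2, 0, 2]]
>>> v = [3., 4., 5.]  # floats, so that A can be multiplied with a float W
>>> A = torch.sparse_coo_tensor(i, v, (2, 3))
>>> A.to_dense()
tensor([[0., 0., 3.],
        [4., 0., 5.]])
>>> W = torch.randn(4, 2)  # example dense weight (assumed here; 4 is arbitrary)
# compute W@A by computing ((A.T)@(W.T)).T because,
# at the time of writing, the sparse matrix must come first in the matmul
>>> (A.t() @ W.t()).t()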
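
To get the full y = WA + b with learnable parameters, the same transpose trick can be wrapped in a small function. The following is only a minimal sketch: the function name sparse_linear, the output size of 4, and the column-shaped bias b are assumptions made for illustration and are not part of nn.Linear. It reuses the sparse A from the example above and checks the result against the dense computation.

```python
import torch

torch.manual_seed(0)

# Sparse input A of shape (2, 3), same entries as the example above.
indices = torch.tensor([[0, 1, 1],
                        [2, 0, 2]])
values = torch.tensor([3.0, 4.0, 5.0])
A = torch.sparse_coo_tensor(indices, values, (2, 3)).coalesce()

# Hypothetical learnable parameters: W maps the 2 rows of A to 4 output rows,
# and b is a column vector broadcast across the columns of W @ A.
out_features = 4
W = torch.nn.Parameter(torch.randn(out_features, 2))
b = torch.nn.Parameter(torch.zeros(out_features, 1))


def sparse_linear(W, A, b):
    """Compute W @ A + b while keeping the sparse matrix as the left operand."""
    # (A^T @ W^T)^T equals W @ A.
    WA = torch.sparse.mm(A.t(), W.t()).t()
    return WA + b


y = sparse_linear(W, A, b)

# Gradients flow back to the dense parameters W and b.
y.sum().backward()

# Dense reference, only feasible here because this A is tiny.
reference = W.detach() @ A.to_dense() + b.detach()
print(torch.allclose(y.detach(), reference))
```

The dense reference is there purely as a sanity check on a toy matrix; for an A that does not fit in memory you would skip it and keep A sparse throughout.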