Vectorizing torch tensor instead of using for loop

Time:07-04

I am looking to do this calculation without any for loops (vectorized), but I can't seem to find a good solution. Maybe someone can help?

    edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))

    for i in range(0, len(nodes_a_embeds)): # A
      for u in range(0, len(nodes_b_embeds)): # B
        edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)

    # OUT: edge_in: torch.Tensor with shape (|A|, |B|, 2*node_dim + 2*edge_dim)

    # IN: edge_embeds: torch.Tensor with shape (|A|, |B|, 2 x edge_dim) 
    # IN: nodes_a_embeds: torch.Tensor with shape (|A|, node_dim)
    # IN: nodes_b_embeds: torch.Tensor with shape (|B|, node_dim)

CodePudding user response:

You can expand nodes_a_embeds and nodes_b_embeds so that their first two dimensions match those of edge_embeds, then concatenate all three along the last dimension:

  • nodes_a_embed = nodes_a_embeds[:, None].expand(-1, n_B, -1): [n_A, node_dim] => [n_A, n_B, node_dim]
  • nodes_b_embed = nodes_b_embeds[None].expand(n_A, -1, -1): [n_B, node_dim] => [n_A, n_B, node_dim]
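To see what the two expand calls above do, here is a minimal sketch with small, made-up sizes. The sizes and the names `a`, `b`, `a_exp` and `b_exp` are assumptions for illustration only and do not come from the question. It also checks that `expand` returns a view: the broadcast dimension has stride 0, so no data is copied.

```python
import torch

n_A, n_B, node_dim = 3, 4, 2  # small illustrative sizes (assumed)

a = torch.arange(n_A * node_dim, dtype=torch.float32).reshape(n_A, node_dim)
b = torch.arange(n_B * node_dim, dtype=torch.float32).reshape(n_B, node_dim)

# [n_A, node_dim] -> [n_A, 1, node_dim] -> [n_A, n_B, node_dim]
a_exp = a[:, None].expand(-1, n_B, -1)
# [n_B, node_dim] -> [1, n_B, node_dim] -> [n_A, n_B, node_dim]
b_exp = b[None].expand(n_A, -1, -1)

print(a_exp.shape, b_exp.shape)

# Entry (i, u) of a_exp is row i of a; entry (i, u) of b_exp is row u of b.
print(torch.equal(a_exp[1, 3], a[1]), torch.equal(b_exp[1, 3], b[3]))

# The expanded dimension has stride 0, i.e. the same memory is reused.
print(a_exp.stride(), b_exp.stride())
```

Because the expanded tensors are views, the only new memory is allocated by the final `torch.cat`.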

Verification:

    import torch

    n_A = 100
    n_B = 200
    node_dim = 32
    edge_dim = 32

    # random inputs with the shapes given in the question
    edge_embeds = torch.randn(n_A, n_B, 2*edge_dim)
    nodes_a_embeds = torch.randn(n_A, node_dim)
    nodes_b_embeds = torch.randn(n_B, node_dim)

    # loop version from the question (used as the reference result)
    edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))

    for i in range(0, len(nodes_a_embeds)): # A
        for u in range(0, len(nodes_b_embeds)): # B
            edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)

    # vectorized version
    edge_in_vectorized = torch.cat([
        nodes_a_embeds[:, None].expand(-1, n_B, -1),   # [n_A, n_B, node_dim]
        nodes_b_embeds[None].expand(n_A, -1, -1),      # [n_A, n_B, node_dim]
        edge_embeds], dim=-1)                          # [n_A, n_B, 2*edge_dim]

    print((edge_in_vectorized == edge_in).all())    # tensor(True)
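If the vectorized expression is needed in more than one place, it could be wrapped in a small helper that reads |A| and |B| from the tensor shapes instead of relying on global variables. The function name `build_edge_inputs` and the test sizes below are hypothetical, chosen only for this sketch.

```python
import torch

def build_edge_inputs(nodes_a_embeds, nodes_b_embeds, edge_embeds):
    """Concatenate [a_i, b_u, e_iu] for every pair (i, u) without Python loops.

    Hypothetical helper; shapes follow the question:
      nodes_a_embeds: (|A|, node_dim)
      nodes_b_embeds: (|B|, node_dim)
      edge_embeds:    (|A|, |B|, 2*edge_dim)
    """
    n_A, n_B = edge_embeds.shape[:2]
    return torch.cat([
        nodes_a_embeds[:, None].expand(-1, n_B, -1),
        nodes_b_embeds[None].expand(n_A, -1, -1),
        edge_embeds,
    ], dim=-1)

# Usage with small, made-up sizes: |A| = 5, |B| = 7, node_dim = 8, 2*edge_dim = 6
nodes_a = torch.randn(5, 8)
nodes_b = torch.randn(7, 8)
edges = torch.randn(5, 7, 6)

out = build_edge_inputs(nodes_a, nodes_b, edges)
print(out.shape)  # torch.Size([5, 7, 22])
```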