I am looking to perform this calculation without using any for loops (vectorized), but I can't really seem to find a good solution. Maybe someone can help?
edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))
for i in range(0, len(nodes_a_embeds)): # A
    for u in range(0, len(nodes_b_embeds)): # B
        edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)
# OUT: edge_in: torch.Tensor with shape (|A|, |B|, 2*node_dim + 2*edge_dim)
# IN: edge_embeds: torch.Tensor with shape (|A|, |B|, 2*edge_dim)
# IN: nodes_a_embeds: torch.Tensor with shape (|A|, node_dim)
# IN: nodes_b_embeds: torch.Tensor with shape (|B|, node_dim)
Answer:
You can expand nodes_a_embeds and nodes_b_embeds so they share the leading (|A|, |B|) dimensions of edge_embeds, then concatenate all three along the last dimension:
nodes_a_embed = nodes_a_embeds[:, None].expand(-1, n_B, -1)  # [n_A, node_dim] => [n_A, n_B, node_dim]
nodes_b_embed = nodes_b_embeds[None].expand(n_A, -1, -1)     # [n_B, node_dim] => [n_A, n_B, node_dim]
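Note that expand only creates broadcasted views (no memory is copied); the single real allocation happens inside torch.cat when the result is materialized.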
Verification:
import torch
n_A = 100
n_B = 200
node_dim = 32
edge_dim = 32
edge_in = torch.randn(n_A, n_B, 2*node_dim + 2*edge_dim)
edge_embeds = torch.randn(n_A, n_B, 2*edge_dim)
nodes_a_embeds = torch.randn(n_A, node_dim)
nodes_b_embeds = torch.randn(n_B, node_dim)
# loop version (reference)
edge_in = torch.ones(len(edge_embeds), len(edge_embeds[0]), len(edge_embeds[0][0]) + 2*len(nodes_a_embeds[0]))
for i in range(0, len(nodes_a_embeds)): # A
    for u in range(0, len(nodes_b_embeds)): # B
        edge_in[i][u] = torch.cat([nodes_a_embeds[i], nodes_b_embeds[u], edge_embeds[i][u]], dim=0)
# vectorized version
edge_in_vectorized = torch.cat([
    nodes_a_embeds[:, None].expand(-1, n_B, -1),  # [n_A, n_B, node_dim]
    nodes_b_embeds[None].expand(n_A, -1, -1),     # [n_A, n_B, node_dim]
    edge_embeds,                                  # [n_A, n_B, 2*edge_dim]
], dim=-1)
print((edge_in_vectorized == edge_in).all()) # True
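If you'd rather not work out the expand arguments by hand, torch.broadcast_tensors produces the same expanded views in one call. A minimal alternative sketch under the same shapes as above (edge_in_alt is a name introduced here for illustration):
# alternative: let broadcasting derive the common [n_A, n_B, node_dim] shape
a_exp, b_exp = torch.broadcast_tensors(
    nodes_a_embeds[:, None, :],  # [n_A, 1, node_dim]
    nodes_b_embeds[None, :, :],  # [1, n_B, node_dim]
)
edge_in_alt = torch.cat([a_exp, b_exp, edge_embeds], dim=-1)
print((edge_in_alt == edge_in).all())  # True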