I'm new to PyTorch and tensors in general, so I could use some guidance :) I'll do my best to ask a correct question, but I may use some terms incorrectly here and there. Feel free to correct any of this :)
Say I have a tensor of shape (n, 3, 3), i.e. n matrices of size 3x3. Every cell of these matrices contains either 0 or 1.
What's the best (fastest, easiest?) way to do a bitwise OR across all of these matrices?
For example, if I have 3 matrices:
0 0 1
0 0 0
1 0 0
--
1 0 0
0 0 0
1 0 1
--
0 1 1
0 1 0
1 0 1
I want the final result to be:
1 1 1
0 1 0
1 0 1
Answer:
Sum the matrices along the first dimension and check whether each element of the sum is greater than 0:
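```python
import torch

tensor = torch.tensor([[[0, 0, 1],
                        [0, 0, 0],
                        [1, 0, 0]],
                       [[1, 0, 0],
                        [0, 0, 0],
                        [1, 0, 1]],
                       [[0, 1, 1],
                        [0, 1, 0],
                        [1, 0, 1]]])
# Sum over the first dimension (the n matrices): a cell that is 1 in at least
# one matrix ends up > 0, which is exactly the OR of all the matrices.
tensor2 = torch.sum(tensor, dim=0) > 0
# Convert the boolean result back to 0/1 integers.
tensor2 = tensor2.to(torch.uint8)
```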
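Because every entry is 0 or 1, OR-ing a cell across the n matrices is the same as asking whether any of them is 1 in that cell, which is also the element-wise maximum. The sketch below (assuming the question's three matrices stored as uint8) uses `torch.any` and `torch.amax` along dim 0, which express that reduction directly instead of going through an intermediate sum:

```python
import torch

# The three 3x3 matrices from the question, stacked into shape (3, 3, 3).
tensor = torch.tensor([[[0, 0, 1],
                        [0, 0, 0],
                        [1, 0, 0]],
                       [[1, 0, 0],
                        [0, 0, 0],
                        [1, 0, 1]],
                       [[0, 1, 1],
                        [0, 1, 0],
                        [1, 0, 1]]], dtype=torch.uint8)

# A cell is set if any of the matrices has a nonzero value there.
any_result = torch.any(tensor, dim=0).to(torch.uint8)

# For 0/1 data, the element-wise maximum over dim 0 is the same as OR.
max_result = torch.amax(tensor, dim=0)

# Both give the result from the question: [[1, 1, 1], [0, 1, 0], [1, 0, 1]]
print(any_result)
print(max_result)
```

Both are single vectorized reductions, so they avoid any Python-level loop over n.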
Answer:
Another option is the torch.bitwise_or() function, which takes two tensors and performs a bitwise OR element-wise. To combine all the matrices in your tensor, loop over the first dimension and OR each matrix into a running result. This is straightforward, but because the loop runs in Python it is typically slower than a single vectorized reduction when n is large.
```python
import torch

# Create a random tensor of shape (n, 3, 3) containing 0s and 1s
n = 10
tensor = torch.randint(0, 2, (n, 3, 3), dtype=torch.uint8)

# Start from all zeros, the identity element for OR
result = torch.zeros((3, 3), dtype=torch.uint8)

# Iterate through the first dimension of the tensor
for i in range(n):
    # OR each matrix into the running result
    result = torch.bitwise_or(result, tensor[i])

print(result)
```
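Alternatively, you can use functools.reduce() from the Python standard library (PyTorch does not provide a general-purpose torch.reduce(tensor, dim, fn) function). This is more compact, but not more efficient than the explicit loop, since it still iterates over the matrices in Python:

```python
import functools

result = functools.reduce(torch.bitwise_or, tensor)
```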
Both methods above give you a single (3, 3) matrix that is the OR of all the submatrices in the original tensor.
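To check that the approaches in this thread agree, here is a quick sketch that runs the explicit `torch.bitwise_or` loop, a `functools.reduce` variant, the sum-and-compare trick, and `torch.any` on the same random input. It assumes 0/1 values stored as uint8 and fixes the random seed only so the check is reproducible; `n = 10` is an arbitrary choice.

```python
import functools

import torch

torch.manual_seed(0)  # fixed seed so the check is reproducible

# Random input: n binary 3x3 matrices stored as uint8.
n = 10
tensor = torch.randint(0, 2, (n, 3, 3), dtype=torch.uint8)

# 1. Explicit loop with torch.bitwise_or, starting from zeros (the OR identity).
loop_result = torch.zeros((3, 3), dtype=torch.uint8)
for m in tensor:  # iterating a tensor yields slices along dim 0
    loop_result = torch.bitwise_or(loop_result, m)

# 2. functools.reduce over the first dimension (still a Python-level loop).
reduce_result = functools.reduce(torch.bitwise_or, tensor)

# 3. Sum along dim 0 (integer sums are promoted to int64) and test for > 0.
sum_result = (tensor.sum(dim=0) > 0).to(torch.uint8)

# 4. Vectorized "any" along dim 0.
any_result = torch.any(tensor, dim=0).to(torch.uint8)

for r in (reduce_result, sum_result, any_result):
    assert torch.equal(loop_result, r)
print("all methods agree")
```

For binary data the vectorized reductions (3 and 4) are usually the better default; the loop-based versions are mainly useful if you need a true bitwise OR on values other than 0 and 1.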