What does [1,2] mean in .mean([1,2]) for a tensor?

Time:01-31

I have a tensor with shape torch.Size([3, 224, 225]). When I call tensor.mean([1,2]), I get tensor([0.6893, 0.5840, 0.4741]). What does [1,2] mean here?

CodePudding user response:

Operations that aggregate along dimensions, such as min, max, mean, and sum, take an argument specifying the dimension along which to aggregate. It is common to apply these operations either across every dimension (e.g. tensor.mean() returns the mean of the entire tensor) or along a single dimension (e.g. tensor.mean(dim=2) or tensor.mean(2) returns the mean of the 225 elements in each of the 3 x 224 vectors, giving a result of shape [3, 224]).
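A minimal sketch of the whole-tensor and single-dimension cases, assuming PyTorch is installed. The tensor x is a hypothetical random tensor with the shape from the question, not the asker's actual data.

```python
import torch

# Hypothetical tensor with the same shape as in the question (random values).
x = torch.rand(3, 224, 225)

# No dim argument: average over every element, giving a 0-dimensional tensor.
full = x.mean()
print(full.shape)  # torch.Size([])

# dim=2 only: average each 225-element vector, leaving dims 0 and 1.
per_row = x.mean(dim=2)  # equivalent to x.mean(2)
print(per_row.shape)  # torch.Size([3, 224])

# Each entry is the mean of one vector along dimension 2.
print(torch.allclose(per_row[0, 0], x[0, 0, :].mean()))  # True
```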

PyTorch also allows some of these operations, such as mean and sum, to aggregate across a set of multiple dimensions, as in your case. Here it means taking the mean of the 224 x 225 elements for each index along dimension 0 (the non-aggregated dimension), so the result has shape [3]. Likewise, if your original tensor had shape a.shape = torch.Size([3, 224, 10, 225]), a.mean([1,3]) would return a tensor of shape [3, 10].
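The shape arithmetic for multi-dimension reductions can be checked directly. This is a sketch using random hypothetical tensors; keepdim is a standard argument of torch.mean that keeps each reduced dimension as size 1 instead of dropping it.

```python
import torch

# Shape from the question: reducing dims 1 and 2 leaves only dim 0.
x = torch.rand(3, 224, 225)
m = x.mean([1, 2])  # a tuple also works: x.mean(dim=(1, 2))
print(m.shape)  # torch.Size([3])

# Shape from the 4-D example: reducing dims 1 and 3 leaves dims 0 and 2.
a = torch.rand(3, 224, 10, 225)
b = a.mean([1, 3])
print(b.shape)  # torch.Size([3, 10])

# keepdim=True keeps the reduced dims as size 1, which is handy for broadcasting.
k = a.mean([1, 3], keepdim=True)
print(k.shape)  # torch.Size([3, 1, 10, 1])
```

With keepdim=True, an expression like a - k broadcasts back to the original shape, e.g. to subtract a per-channel mean.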

CodePudding user response:

Your tensor has size 3 along dimension 0, 224 along dimension 1, and 225 along dimension 2.

I would say that tensor.mean([1,2]) calculates the mean across dimensions 1 and 2 together. That's why you get 3 values: each 224x225 plane spanned by dimensions 1 and 2 is reduced to a single scalar. Since there are 3 such planes (one per index along dimension 0), you get 3 values back, each representing the mean of all 224x225 = 50,400 values in its plane.
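To see the "one scalar per plane" interpretation concretely, the sketch below compares mean([1,2]) with two explicit per-plane computations. It uses a random hypothetical tensor in float64 (an assumption made only so that differences in summation order cannot cause float32 rounding mismatches).

```python
import torch

torch.manual_seed(0)
# Hypothetical data with the question's shape; float64 avoids rounding noise.
x = torch.rand(3, 224, 225, dtype=torch.float64)

# Reduce dims 1 and 2 together: one value per plane along dim 0.
m = x.mean([1, 2])

# Flatten each 224x225 plane into a vector of 50,400 values and average it.
flat = x.reshape(3, -1).mean(dim=1)

# Average each plane explicitly, one index of dim 0 at a time.
per_plane = torch.stack([x[i].mean() for i in range(x.shape[0])])

print(torch.allclose(m, flat), torch.allclose(m, per_plane))  # True True
```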
