I have a tensor with shape torch.Size([3, 224, 225]). When I do tensor.mean([1,2]) I get tensor([0.6893, 0.5840, 0.4741]). What does [1,2] mean here?
CodePudding user response:
Operations that aggregate along dimensions, such as min, max, mean, sum, etc., take an argument specifying the dimension(s) along which to aggregate. It is common to use these operations across every dimension (i.e. get the mean of the entire tensor) or across a single dimension (e.g. for your tensor, tensor.mean(dim=2) or tensor.mean(2) returns the mean of the 225 elements for each of the 3 x 224 vectors, giving a result of shape [3, 224]).
PyTorch also allows these operations across a set of multiple dimensions, as in your case. tensor.mean([1,2]) takes the mean of the 224 x 225 elements for each index along the 0th (non-aggregated) dimension, which is why you get 3 values. Likewise, if your original tensor had shape a.shape = torch.Size([3,224,10,225]), a.mean([1,3]) would return a tensor of shape [3,10].
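To make the shapes concrete, here is a small sketch. It assumes random data in place of your actual tensor, and the variable names (x, a, per_channel, and so on) are only illustrative:

```python
import torch

# Stand-in for the tensor from the question (random values, same shape).
x = torch.rand(3, 224, 225)

overall = x.mean()            # reduce every dimension -> 0-d tensor
last_dim = x.mean(dim=2)      # reduce only dim 2      -> shape [3, 224]
per_channel = x.mean([1, 2])  # reduce dims 1 and 2    -> shape [3]

print(overall.shape, last_dim.shape, per_channel.shape)

# The 4-D case: dims 1 and 3 are reduced, dims 0 and 2 remain.
a = torch.rand(3, 224, 10, 225)
reduced = a.mean([1, 3])
print(reduced.shape)
```

The dimensions listed in the argument are the ones that disappear from the result; the remaining dimensions keep their original order.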
CodePudding user response:
The shape of your tensor is 3 across dimension 0, 224 across dimension 1 and 225 across dimension 2.
tensor.mean([1,2]) calculates the mean jointly over dimension 1 and dimension 2. That's why you are getting 3 values. Each plane spanned by dimensions 1 and 2, of size 224x225, is reduced to a single scalar. Since there are 3 such planes (one per index along dimension 0), you get 3 values back, each representing the mean of a whole plane of 224x225 values.
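One way to check this interpretation is to compute the per-plane means by hand and compare them. This is a minimal sketch on random data, not your actual tensor, and the names t, joint, per_plane and flattened are made up for the example:

```python
import torch

torch.manual_seed(0)
t = torch.rand(3, 224, 225)

# Mean over dims 1 and 2 in a single call.
joint = t.mean([1, 2])

# The same thing done plane by plane: t[i] is one 224x225 plane.
per_plane = torch.stack([t[i].mean() for i in range(t.shape[0])])

# Or flatten each plane into a vector of 224*225 values and reduce that.
flattened = t.reshape(t.shape[0], -1).mean(dim=1)

# All three agree up to floating-point rounding.
print(torch.allclose(joint, per_plane, atol=1e-6))
print(torch.allclose(joint, flattened, atol=1e-6))
```

The tolerance is there only because the three variants may sum the values in a different order, which can change the last few bits of a float32 result.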