Hi, I am trying to implement a total variation function for tensors, or more precisely, for multichannel images. For the Total Variation above (in the picture), I found source code like this:
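```python
def compute_total_variation_loss(img, weight):
    # squared differences between vertically adjacent pixels (dim 2)
    tv_h = ((img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2)).sum()
    # squared differences between horizontally adjacent pixels (dim 3)
    tv_w = ((img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)
```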
Since I am a beginner in Python, I don't understand how the indices correspond to i and j in the image. I also want to add total variation over c (besides i and j), but I don't know which index refers to c.
Or, more concisely: how do I write the following equation in Python? [equation image not available]
Answer:
This function assumes batched images, so `img` is a 4-dimensional tensor of shape `(B, C, H, W)`, where `B` is the number of images in the batch, `C` the number of color channels, `H` the height and `W` the width.
So `img[0, 1, 2, 3]` is the pixel `(2, 3)` (row 2, column 3) of the second color channel (green in RGB) in the first image. In particular, the channel index is dimension 1, the second index.
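A quick way to see which index is which is to build a small tensor and inspect it. The sketch below uses a made-up batch of 2 RGB images of 4x5 pixels filled with `torch.arange`, so every value equals its flat position in memory; the sizes are arbitrary and chosen only for illustration.

```python
import torch

# Hypothetical batch: 2 images, 3 color channels, 4 rows (H), 5 columns (W).
img = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)

# dim 0: batch, dim 1: channel (c), dim 2: row (i), dim 3: column (j)
B, C, H, W = img.shape
print(B, C, H, W)  # 2 3 4 5

# Row 2, column 3 of the green channel (index 1) in the first image (index 0).
pixel = img[0, 1, 2, 3]
# arange fills memory in row-major order: ((0*3 + 1)*4 + 2)*5 + 3 = 33
print(pixel.item())  # 33.0
```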
In Python (and NumPy and PyTorch), a slice of elements can be selected with the notation `i:j`, meaning that the elements `i, i+1, i+2, ..., j-1` are selected. In your example, `:` means all elements, `1:` means all elements but the first, and `:-1` means all elements but the last (negative indices count backward from the end). Please refer to tutorials on "slicing in NumPy".
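The slicing rules can be tried out on a 1-D tensor; the values below are arbitrary and only for illustration. The last line shows why subtracting two shifted slices gives differences between neighbours.

```python
import torch

x = torch.tensor([10, 20, 30, 40, 50])

print(x[1:3])   # elements 1 and 2: tensor([20, 30])
print(x[:])     # all elements: tensor([10, 20, 30, 40, 50])
print(x[1:])    # all but the first: tensor([20, 30, 40, 50])
print(x[:-1])   # all but the last: tensor([10, 20, 30, 40])
print(x[-1])    # negative index counts from the end: tensor(50)

# Element k of x[1:] is x[k+1] and element k of x[:-1] is x[k],
# so the difference is x[k+1] - x[k] for every neighbouring pair.
print(x[1:] - x[:-1])  # tensor([10, 10, 10, 10])
```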
So `img[:,:,1:,:] - img[:,:,:-1,:]` is the (batch of) images minus themselves shifted by one pixel vertically, or, in your notation, `X(i+1, j, k) - X(i, j, k)`. Likewise, `tv_w` uses a horizontal shift along dimension 3. The differences are then squared (`.pow(2)`) and summed (`.sum()`). Note that the sum also runs over the batch here, so you get the total variation of the whole batch, not of each image.
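A minimal sketch, assuming you want one value per image and, for the channel question, that the equation in the picture adds a squared difference between neighbouring channels analogous to the i and j terms (the picture is not available, so that part is a guess). Summing over `dim=(1, 2, 3)` instead of over everything keeps the batch dimension. All function names here are hypothetical; `total_variation_loops` writes the same sum out index by index, only as a check on the sliced version.

```python
import torch


def total_variation_per_image(img: torch.Tensor) -> torch.Tensor:
    """Squared total variation of each image in a (B, C, H, W) batch; returns shape (B,)."""
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum(dim=(1, 2, 3))  # neighbours along H (i)
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum(dim=(1, 2, 3))  # neighbours along W (j)
    return tv_h + tv_w


def total_variation_with_channels(img: torch.Tensor) -> torch.Tensor:
    """Assumed extension: also add squared differences between neighbouring channels (dim 1, c)."""
    tv_c = (img[:, 1:, :, :] - img[:, :-1, :, :]).pow(2).sum(dim=(1, 2, 3))
    return total_variation_per_image(img) + tv_c


def total_variation_loops(x: torch.Tensor) -> float:
    """Explicit-loop reference for ONE image of shape (C, H, W), written index by index."""
    C, H, W = x.shape
    total = 0.0
    for c in range(C):
        for i in range(H):
            for j in range(W):
                if i + 1 < H:  # X(i+1, j) - X(i, j)
                    total += (x[c, i + 1, j] - x[c, i, j]).item() ** 2
                if j + 1 < W:  # X(i, j+1) - X(i, j)
                    total += (x[c, i, j + 1] - x[c, i, j]).item() ** 2
    return total


img = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])  # one image, one channel, 2x2 pixels
print(total_variation_per_image(img))  # vertical 2^2 + 2^2 = 8, horizontal 1^2 + 1^2 = 2: tensor([10.])
print(total_variation_loops(img[0]))   # 10.0
```

Calling `.sum()` on the per-image result gives back the batch total that `compute_total_variation_loss` returns with `weight=1`.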