Hi, I am trying to implement a total variation function for tensors or, more accurately, multichannel images. I found that for the Total Variation above (in the picture), there is source code like this:
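def compute_total_variation_loss(img, weight):
    # squared differences between vertically adjacent pixels
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    # squared differences between horizontally adjacent pixels
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)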
Since I am a beginner in Python, I did not understand how the indices correspond to i and j in the image. I also want to add total variation along c (besides i and j), but I don't know which index refers to c.
Or, to be more concise, how do I write the following equation in Python? (equation image)
This function assumes batched images, so img is a 4-dimensional tensor of shape (B, C, H, W), where B is the number of images in the batch, C the number of color channels, H the height and W the width. The channel index c is therefore the second dimension (index 1). For example, img[0, 1, 2, 3] is the pixel (2, 3) of the second color channel (green in RGB) in the first image.
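To make the (B, C, H, W) layout concrete, here is a small sketch. The batch size, image size and pixel values are made up for illustration; only the indexing pattern matters.

```python
import torch

# Hypothetical batch: 2 images, 3 channels (RGB), 4 rows, 5 columns.
# The values are just a running counter so each pixel is identifiable.
img = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)

B, C, H, W = img.shape  # 2, 3, 4, 5

# Pixel at row 2, column 3 of channel 1 (green) in image 0.
pixel = img[0, 1, 2, 3]

# Dimension 1 is the channel axis, i.e. the "c" index from the question.
green_plane = img[0, 1]  # shape (H, W)
```

Indexing with fewer than four indices, as in `img[0, 1]`, keeps the remaining dimensions, which is why `green_plane` is a 2-D tensor.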
In Python (and NumPy and PyTorch), a slice of elements can be selected with the notation i:j, meaning that the elements i, i + 1, i + 2, ..., j - 1 are selected. In your example, : means all elements, 1: means all elements but the first, and :-1 means all elements but the last (negative indices count from the end). Please refer to tutorials on "slicing in NumPy".
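The slicing rules can be tried on a plain Python list first; the same notation applies along each dimension of a tensor. The list contents here are arbitrary example values.

```python
values = [10, 20, 30, 40, 50]

all_items = values[:]        # every element
without_first = values[1:]   # all elements but the first
without_last = values[:-1]   # all elements but the last
middle = values[1:3]         # elements at positions 1 and 2

# Pairing the two shifted views yields the differences between neighbors,
# which is exactly the trick used in the total variation code.
diffs = [b - a for a, b in zip(without_last, without_first)]
```

Both shifted views have one element fewer than the original, so they line up element by element.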
img[:,:,1:,:] - img[:,:,:-1,:] is equivalent to the (batch of) images minus themselves shifted by one pixel vertically or, in your notation, X(i + 1, j, k) - X(i, j, k). The resulting tensor is then squared (.pow(2)) and summed (.sum()). Note that the sum also runs over the batch dimension in this case, so you get the total variation of the whole batch, not of each image.
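If you want one value per image, and also a term along the channel axis, something like the following sketch could work. The function name and the include_channels flag are hypothetical. The equation image from the question is not available, so the channel term is only assumed to be the same squared-difference form applied along dimension 1.

```python
import torch

def total_variation_per_image(img, weight=1.0, include_channels=False):
    """Squared-difference total variation for each image of a (B, C, H, W) batch."""
    # Vertical neighbors: X(i + 1, j, c) - X(i, j, c), dimension 2.
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2).sum(dim=(1, 2, 3))
    # Horizontal neighbors: X(i, j + 1, c) - X(i, j, c), dimension 3.
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2).sum(dim=(1, 2, 3))
    total = tv_h + tv_w
    if include_channels:
        # Assumed channel term: X(i, j, c + 1) - X(i, j, c), dimension 1.
        tv_c = (img[:, 1:, :, :] - img[:, :-1, :, :]).pow(2).sum(dim=(1, 2, 3))
        total = total + tv_c
    return weight * total  # shape (B,), one value per image

# Usage: image 0 is all zeros; image 1 has channel 0 set to one.
batch = torch.zeros(2, 3, 4, 4)
batch[1, 0] = 1.0
tv = total_variation_per_image(batch, include_channels=True)
```

Summing over dimensions 1, 2 and 3 leaves the batch dimension intact; calling .sum() on the result recovers a single value for the whole batch, as in the original function.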