Total Variation Regularization for Tensors in Python


Hi, I am trying to implement a total variation function for tensors, or more precisely, for multichannel images. For the Total Variation above (in the picture), I found source code like this:

def compute_total_variation_loss(img, weight):
    # img is a torch.Tensor; sum of squared differences between neighbouring pixels
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)

Since I am a beginner in Python, I didn't understand how the indices correspond to i and j in the image. I also want to add a total variation term along c (besides i and j), but I don't know which index refers to c.

Or, more concisely: how do I write the following equation in Python? [equation image]


This function assumes batched images, so img is a 4-dimensional tensor of shape (B, C, H, W), where B is the number of images in the batch, C the number of color channels, H the height and W the width.
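To make the (B, C, H, W) layout concrete, here is a small sketch. It uses NumPy, whose indexing behaves the same as PyTorch's here; the sizes (2 images, 3 channels, 4×5 pixels) are arbitrary assumptions for illustration. With PyTorch, `torch.rand(2, 3, 4, 5)` would give a tensor of the same shape.

```python
import numpy as np

# Hypothetical batch: 2 RGB images, each 4 pixels high and 5 pixels wide,
# stored in (B, C, H, W) order.
img = np.random.rand(2, 3, 4, 5)

B, C, H, W = img.shape
print(B, C, H, W)      # 2 3 4 5

# Axis numbers: 0 -> batch, 1 -> channel, 2 -> height, 3 -> width
print(img[0].shape)    # (3, 4, 5): one whole image
print(img[0, 1].shape) # (4, 5): one channel of one image
```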

So img[0, 1, 2, 3] is the pixel at row 2, column 3 (counting from 0) of the second channel (green in RGB) of the first image.
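A quick way to check this indexing is to fill a hypothetical (2, 3, 4, 5) array with consecutive integers, so every entry is unique and its value reveals its position. This sketch uses NumPy; a PyTorch tensor built with `torch.arange(...).reshape(...)` indexes the same way.

```python
import numpy as np

# Values 0..119 laid out in (B, C, H, W) = (2, 3, 4, 5) row-major order
img = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

# Image 0, channel 1, row 2, column 3
pixel = img[0, 1, 2, 3]

# Same element reached one axis at a time
assert pixel == img[0][1][2][3]

# Flat position: b*(C*H*W) + c*(H*W) + h*W + w = 0 + 1*20 + 2*5 + 3
print(pixel)  # 33
```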

In Python (and NumPy and PyTorch), a slice of elements can be selected with the notation i:j, which selects the elements i, i + 1, i + 2, ..., j - 1. In your example, : means all elements, 1: means all elements but the first, and :-1 means all elements but the last (negative indices count backward from the end). Please refer to tutorials on “slicing in NumPy”.
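Here is a one-dimensional sketch of these slices on a small NumPy array with made-up values. The last line shows why `x[1:] - x[:-1]` is the trick used in the function: it subtracts each element from its right-hand neighbour, giving x(i + 1) - x(i) for every valid i.

```python
import numpy as np

x = np.array([10, 11, 13, 16, 20])

print(x[1:4])         # [11 13 16]     elements 1, 2, 3 (end index 4 excluded)
print(x[1:])          # [11 13 16 20]  all but the first
print(x[:-1])         # [10 11 13 16]  all but the last
print(x[-1])          # 20             negative index counts from the end

# Neighbour differences x[i + 1] - x[i]; the result is one element shorter
print(x[1:] - x[:-1]) # [1 2 3 4]
```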

So img[:,:,1:,:] - img[:,:,:-1,:] is the (batch of) images minus themselves shifted by one pixel vertically, or, in your notation, X(i + 1, j, k) - X(i, j, k). The resulting tensor is then squared (.pow(2)) and summed (.sum()). Note that in this case the sum also runs over the batch, so you get the total variation of the whole batch, not of each image.
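For the term along c, the same pattern applies to the channel axis, which is dimension 1 in (B, C, H, W): img[:,1:,:,:] - img[:,:-1,:,:]. Below is a sketch that assumes the equation in the image is a sum of squared differences along i, j and c, in the same style as the original function (the image isn't available, so this is an assumption). `total_variation_3d` and its `per_image` flag are hypothetical names. It uses only slicing, `** 2` and `.sum(...)`, so it should run unchanged on PyTorch tensors; here it is exercised with NumPy.

```python
import numpy as np

def total_variation_3d(img, weight, per_image=False):
    """Squared-difference total variation along channels, height and width.

    img: array or tensor of shape (B, C, H, W).
    Hypothetical helper: assumes squared (not absolute) differences.
    """
    tv_c = (img[:, 1:, :, :] - img[:, :-1, :, :]) ** 2  # along c (dim 1)
    tv_h = (img[:, :, 1:, :] - img[:, :, :-1, :]) ** 2  # vertical neighbours (dim 2)
    tv_w = (img[:, :, :, 1:] - img[:, :, :, :-1]) ** 2  # horizontal neighbours (dim 3)
    if per_image:
        # Sum over (C, H, W) only: one value per image, shape (B,)
        dims = (1, 2, 3)
        return weight * (tv_c.sum(dims) + tv_h.sum(dims) + tv_w.sum(dims))
    # Sum over everything, batch included: a single scalar
    return weight * (tv_c.sum() + tv_h.sum() + tv_w.sum())

# Usage on a hypothetical batch of 2 images with 3 channels of 4x5 pixels
img = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
total = total_variation_3d(img, 1.0)
per_img = total_variation_3d(img, 1.0, per_image=True)
print(total)    # 34346.0
print(per_img)  # [17173. 17173.]
```

Each squared-difference array is one entry shorter along its own axis, so the three have different shapes and are summed separately before being added, just as the original function does with tv_h and tv_w. Passing per_image=True sums over dimensions (1, 2, 3) only, which gives the total variation of each image rather than of the whole batch.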