# Total Variation Regularization for Tensors in Python

*[Image: total variation formula]*

Hi, I am trying to implement a total variation function for tensors, or more precisely, for multichannel images. For the total variation above (in the picture), I found source code like this:

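```python
def compute_total_variation_loss(img, weight):
    # img is a batch of images: a tensor of shape (B, C, H, W)
    # Squared differences between vertically neighboring pixels (axis 2, H)
    tv_h = ((img[:, :, 1:, :] - img[:, :, :-1, :]).pow(2)).sum()
    # Squared differences between horizontally neighboring pixels (axis 3, W)
    tv_w = ((img[:, :, :, 1:] - img[:, :, :, :-1]).pow(2)).sum()
    return weight * (tv_h + tv_w)
```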

Since I am a beginner in Python, I didn’t understand how the indices refer to i and j in the image. I also want to add total variation along c (besides i and j), but I don’t know which index refers to c.

Or, more concisely: how do I write the following equation in Python? *[Image: equation]*

This function assumes batched images, so `img` is a 4-dimensional tensor of shape `(B, C, H, W)` (`B` is the number of images in the batch, `C` the number of color channels, `H` the height and `W` the width).
So `img[0, 1, 2, 3]` is the pixel `(2, 3)` (row 2, column 3) of the second color channel (green in RGB) in the first image.
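A small sketch of the `(B, C, H, W)` layout, using NumPy (PyTorch tensors index the same way). The array is hypothetical and each element simply holds its own flat position, so the printed value shows which element was picked:

```python
import numpy as np

# Hypothetical batch: B=2 images, C=3 channels (RGB), H=4 rows, W=5 columns.
# Each element's value is its position in the flattened array.
img = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)

print(img.shape)        # (2, 3, 4, 5), i.e. (B, C, H, W)
print(img[0, 1, 2, 3])  # 33: row 2, column 3 of channel 1 (green) in image 0
print(img[:, 1].shape)  # (2, 4, 5): the green channel of every image in the batch
```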
In Python (and in NumPy and PyTorch), a slice of elements can be selected with the notation `i:j`, which selects the elements `i, i + 1, i + 2, ..., j - 1`. In your example, `:` means all elements, `1:` means all elements but the first, and `:-1` means all elements but the last (negative indices count backward from the end). Please refer to tutorials on “slicing in NumPy”.
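These slicing rules can be checked on a short, made-up 1-D NumPy array. Subtracting the two shifted slices gives the differences between neighbors, which is what the total variation code does along one axis:

```python
import numpy as np

x = np.array([10, 11, 13, 16, 20])  # hypothetical 1-D signal

print(x[-1])           # 20: negative indices count from the end
print(x[1:])           # all elements but the first: [11 13 16 20]
print(x[:-1])          # all elements but the last:  [10 11 13 16]
print(x[1:] - x[:-1])  # neighbor differences x[n + 1] - x[n]: [1 2 3 4]
```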
So `img[:,:,1:,:] - img[:,:,:-1,:]` is equivalent to the (batch of) images minus themselves shifted by one pixel vertically, or, in your notation, `X(i + 1, j, k) - X(i, j, k)`. The resulting tensor is then squared (`.pow(2)`) and summed (`.sum()`). Note that in this case the sum also runs over the batch (and over the channels), so you get the total variation of the whole batch, not of each image.
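Since axis 1 is `C`, the channel term follows the same pattern with the shift moved to axis 1. Below is a minimal sketch, assuming the equation in the question uses squared forward differences like `compute_total_variation_loss` and that the third index `k` of `X(i, j, k)` is the channel `c`. The function name `total_variation_chw` and the sample input are hypothetical. It uses `** 2` instead of `.pow(2)` so that the same code runs on NumPy arrays and on PyTorch tensors:

```python
import numpy as np


def total_variation_chw(img, weight=1.0):
    """Hypothetical helper: squared-difference total variation of a
    (B, C, H, W) batch along H, W and C, summed over the whole batch."""
    # Vertical neighbors (axis 2, H): X(i + 1, j, k) - X(i, j, k)
    tv_h = ((img[:, :, 1:, :] - img[:, :, :-1, :]) ** 2).sum()
    # Horizontal neighbors (axis 3, W): X(i, j + 1, k) - X(i, j, k)
    tv_w = ((img[:, :, :, 1:] - img[:, :, :, :-1]) ** 2).sum()
    # Neighboring channels (axis 1, C): X(i, j, k + 1) - X(i, j, k)
    tv_c = ((img[:, 1:, :, :] - img[:, :-1, :, :]) ** 2).sum()
    return weight * (tv_h + tv_w + tv_c)


# Hypothetical input: 1 image, 2 channels, 2x2 pixels, a single non-zero value.
img = np.zeros((1, 2, 2, 2))
img[0, 1, 0, 0] = 1.0
print(total_variation_chw(img))  # 3.0: one unit jump along each of H, W and C
```

Because every `.sum()` here reduces over all axes, the result is still one number for the whole batch. To get one value per image, reduce each term over the non-batch axes instead, e.g. `.sum(axis=(1, 2, 3))` in NumPy or `.sum(dim=(1, 2, 3))` in PyTorch.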