Understanding PyTorch gather function


PyTorch is one of the main libraries when it comes to Deep Learning in Python. Tensors, the key data structure of the framework, can be manipulated in multiple ways. One of the available methods is the gather function.

The following is the signature according to its documentation:

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

The important parameters are the first three:

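  1. input: the source tensor to gather values from.

  2. dim: the dimension (axis) along which values are selected.

  3. index: an int64 tensor with the same number of dimensions as input, holding the positions to pick. The output has the same shape as index.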
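The snippet below is a small illustration (the tensors source and idx are made up for the example). It passes the three parameters by name and confirms that the output takes the shape of index.

```python
import torch

# A small 2-D input and an index tensor with the same number of dimensions.
source = torch.tensor([[10, 11, 12],
                       [20, 21, 22]])
idx = torch.tensor([[2, 0],
                    [1, 1]])  # integer literals default to torch.int64

# Pass the three parameters by name: input, dim and index.
picked = torch.gather(input=source, dim=1, index=idx)

print(picked.tolist())  # [[12, 10], [21, 21]]
print(picked.shape)     # torch.Size([2, 2]), the same shape as idx
```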

The idea is to create a new tensor composed of elements of input, collected along one of its dimensions according to the values in index.
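To make the rule concrete, here is a plain-Python sketch (no PyTorch needed) of what gather computes for a 2-D input. gather_2d is a hypothetical helper written only for illustration. It follows the indexing rule given in the PyTorch documentation: out[i][j] = input[index[i][j]][j] when dim is 0, and out[i][j] = input[i][index[i][j]] when dim is 1.

```python
def gather_2d(matrix, dim, index):
    """Hypothetical reference for torch.gather on 2-D nested lists.

    The output has the same shape as `index`; only the coordinate along
    `dim` is replaced by the value stored in `index`.
    """
    if dim not in (0, 1):
        raise ValueError("dim must be 0 or 1 for a 2-D input")
    out = []
    for i, index_row in enumerate(index):
        out_row = []
        for j, k in enumerate(index_row):
            if dim == 0:
                out_row.append(matrix[k][j])  # row comes from index, column is j
            else:
                out_row.append(matrix[i][k])  # row is i, column comes from index
        out.append(out_row)
    return out


M = [[1, 2, 3], [4, 7, 18], [19, 9, 23]]
print(gather_2d(M, 1, [[1], [1], [2]]))  # [[2], [7], [23]]
print(gather_2d(M, 0, [[1, 1, 2]]))      # [[4, 7, 23]]
```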

Let's start with a simple example. Suppose we have a tensor T and we need to create a new tensor containing its first and third elements. We can use the gather function passing T as input, 0 as dimension (since T has just 1 dimension) and a tensor with values 0 and 2 as index.

import torch

T = torch.randn(5)            # 1-D tensor with 5 random values
index = torch.tensor([0, 2])  # positions of the first and third elements
dim = 0
out = torch.gather(T, dim, index)

out will be a tensor of shape [2] (the same as index) containing the first and third elements of T. Notice that we can also call gather directly as a method of the tensor:

out = T.gather(dim,index)
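Because torch.randn produces different values on every run, the sketch below swaps in a fixed tensor so the result is predictable. It checks that the function and method forms agree and that, for a 1-D input, gathering along dim 0 gives the same result as indexing T with index directly.

```python
import torch

# A fixed 1-D tensor instead of torch.randn, so the output is predictable.
T = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0])
index = torch.tensor([0, 2])  # first and third elements
dim = 0

out_function = torch.gather(T, dim, index)  # functional form
out_method = T.gather(dim, index)           # method form

print(out_function.tolist())                # [10.0, 30.0]
print(torch.equal(out_function, out_method))  # True
# For a 1-D input, gather along dim 0 matches integer-array indexing.
print(torch.equal(out_method, T[index]))      # True
```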

An interesting use case of the gather function is selecting one specific element from each row of a matrix. This occurs, for instance, in Double Q-Learning, where the target network's output is gathered at the actions selected (via argmax) from the base network's output.
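A minimal sketch of that Double Q-Learning step, assuming random tensors stand in for the Q-values the two networks would output on a batch of next states. The names q_online, q_target, best_actions and next_q are illustrative, not part of any library, and the reward and discount terms of the full target are left out.

```python
import torch

torch.manual_seed(0)
batch_size, num_actions = 4, 3

# Hypothetical stand-ins for network outputs on the next states:
# q_online from the base (online) network, q_target from the target network.
q_online = torch.randn(batch_size, num_actions)
q_target = torch.randn(batch_size, num_actions)

# The base network selects the best action per row; keepdim=True keeps the
# index 2-D ([batch_size, 1]), matching the number of dimensions of q_target.
best_actions = q_online.argmax(dim=1, keepdim=True)

# The target network's value for that action is gathered from each row,
# then squeezed to shape [batch_size].
next_q = q_target.gather(1, best_actions).squeeze(1)
print(next_q.shape)  # torch.Size([4])
```

In a full training loop, next_q would then be combined with the rewards and the discount factor to form the regression target.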

Suppose we have a matrix M composed as follows:

M = torch.tensor([[1,2,3], [4,7,18], [19,9,23]])

We need to extract one element per row, at the following column indexes:

indexes = torch.tensor([1,1,2])

If we call gather on dimension 1, so that the selection happens within each row (across its columns), we get an error:

M.gather(1, indexes)
# RuntimeError: Index tensor must have the same number of dimensions as input tensor

This happens because M is 2-D while indexes is 1-D: gather requires index to have the same number of dimensions as input, and it does not broadcast it. If we reshape indexes into a [3,1] column, everything works as expected:

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])
indexes = torch.tensor([1, 1, 2]).view(-1, 1)  # shape [3, 1]
dimension = 1
out = M.gather(dimension, indexes)  # tensor([[ 2], [ 7], [23]])

out contains, for each row of M, the element at the specified index. Notice that it has shape [3,1], the same as indexes.
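The sketch below shows two related idioms. The variable names other than M and indexes are illustrative. unsqueeze(1) builds the same [3,1] column index as view(-1, 1), and squeeze(1) turns the [3,1] result back into a flat vector. The result is also compared with advanced indexing, which gives the same values in this one-element-per-row case.

```python
import torch

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])
indexes = torch.tensor([1, 1, 2])

# unsqueeze(1) builds the same [3, 1] column as view(-1, 1).
column_index = indexes.unsqueeze(1)
print(torch.equal(column_index, indexes.view(-1, 1)))  # True

# gather returns a [3, 1] column; squeeze(1) drops the size-1 dimension.
picked = M.gather(1, column_index).squeeze(1)
print(picked.tolist())  # [2, 7, 23]

# Same values via advanced indexing: row i is paired with column indexes[i].
rows = torch.arange(M.size(0))
picked_fancy = M[rows, indexes]
print(torch.equal(picked, picked_fancy))  # True
```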

What happens if we set 0 as dimension? The gathering then occurs along the columns: for each column, the index selects a row. Notice that we also need to reshape indexes accordingly: its shape must be [1,3] rather than [3,1] as before.

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])
indexes = torch.tensor([1, 1, 2]).view(1, -1)  # shape [1, 3]
dimension = 0
out = M.gather(dimension, indexes)  # tensor([[ 4,  7, 23]])

The following image roughly shows the process described above.

[Figure: Gathering on columns]
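Gathering on dim 0 is not limited to a single row of indexes. In the sketch below (the index tensor is made up for illustration), index has two rows, so each column of M is sampled twice. The list comprehension recomputes every entry with the rule out[i][j] = M[index[i][j]][j] as a cross-check.

```python
import torch

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])

# With dim=0 each index entry chooses a row, while the column is the
# entry's own position j: out[i][j] = M[index[i][j]][j].
index = torch.tensor([[0, 2, 1],
                      [2, 0, 0]])
out = M.gather(0, index)
print(out.tolist())  # [[1, 9, 18], [19, 2, 3]]

# Cross-check every entry against the rule above.
expected = [[M[int(index[i, j]), j].item() for j in range(index.size(1))]
            for i in range(index.size(0))]
print(out.tolist() == expected)  # True
```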

We can get even fancier and build more complex tensors by gathering several elements from each row in a single call:

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])
indexes = torch.tensor([[1, 1], [1, 2], [1, 2]])
dimension = 1
out = M.gather(dimension, indexes)  # shape [3, 2], the same as indexes

The output tensor is shown in the following table. Each row contains the elements gathered from the corresponding row of M according to indexes.

out row | first element | second element
0       | 2             | 2
1       | 7             | 18
2       | 9             | 23
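Gathering several columns at once is equivalent to gathering one column at a time and concatenating the results. The sketch below checks this on the same M and indexes. The slice indexes[:, k:k + 1] keeps each index column two-dimensional, as gather requires.

```python
import torch

M = torch.tensor([[1, 2, 3], [4, 7, 18], [19, 9, 23]])
indexes = torch.tensor([[1, 1], [1, 2], [1, 2]])

out = M.gather(1, indexes)
print(out.tolist())  # [[2, 2], [7, 18], [9, 23]]

# The same result, built from one single-column gather per index column.
columns = [M.gather(1, indexes[:, k:k + 1]) for k in range(indexes.size(1))]
stacked = torch.cat(columns, dim=1)
print(torch.equal(out, stacked))  # True
```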


Thank you for joining me in this dive into the PyTorch gather function. The library offers a wide range of functionality, and understanding the concepts behind some of it can improve both our productivity and our skills.