DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

gather() in PyTorch

Buy Me a Coffee

*My post explains take() and take_along_dim().

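gather() gets elements from a 0D or more D input tensor along the dimension dim, using a 0D or more D tensor of indices (index), and returns a 0D or more D tensor of zero or more elements with the same shape as index. For a 2D tensor, out[i][j] = input[index[i][j]][j] with dim=0 and out[i][j] = input[i][index[i][j]] with dim=1, as shown below: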

*Memos:

  • gather() can be called with torch or a tensor.
  • The 1st argument(input) with torch, or the tensor itself when called as a tensor method (Required-Type:tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is dim(Required-Type:int).
  • The 3rd argument with torch or the 2nd argument with a tensor is index(Required-Type:tensor of int).
  • There is an out argument with torch(Optional-Type:tensor): *Memos:
    • out= must be used.
    • My post explains the out argument.
  • Basically, the input and index tensors must have the same number of dimensions (D), and the returned tensor has the same shape as index. index may be larger than input along dim, but not along any other dimension. As an exception, a 0D input with a 1D index, or a 1D input with a 0D index, is also allowed; the returned tensor then has the same D as index.
import torch


def _show(result, expected):
    """Print a gather() result and check it against the documented output.

    The check compares dtype, shape and every element (via ``tolist()``,
    which works for int, float, complex and bool tensors alike). It raises
    AssertionError explicitly instead of using ``assert`` so the check still
    runs under ``python -O``.
    """
    print(result)
    if (result.dtype != expected.dtype
            or result.shape != expected.shape
            or result.tolist() != expected.tolist()):
        raise AssertionError(f"expected {expected}, got {result}")


# For a 2D tensor: with dim=0, out[i][j] = input[index[i][j]][j];
# with dim=1, out[i][j] = input[i][index[i][j]]. out has index's shape.
my_tensor = torch.tensor([[10, 11, 12, 13],
                          [14, 15, 16, 17],
                          [18, 19, 20, 21]])

# The function form, the Tensor-method form and dim=-2 (the same axis as
# dim=0 for a 2D tensor) are equivalent; row-i indices reproduce the input.
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 0, 0, 0],
                                       [1, 1, 1, 1],
                                       [2, 2, 2, 2]])),
      torch.tensor([[10, 11, 12, 13],
                    [14, 15, 16, 17],
                    [18, 19, 20, 21]]))
_show(my_tensor.gather(dim=0, index=torch.tensor([[0, 0, 0, 0],
                                                  [1, 1, 1, 1],
                                                  [2, 2, 2, 2]])),
      torch.tensor([[10, 11, 12, 13],
                    [14, 15, 16, 17],
                    [18, 19, 20, 21]]))
_show(torch.gather(input=my_tensor, dim=-2,
                   index=torch.tensor([[0, 0, 0, 0],
                                       [1, 1, 1, 1],
                                       [2, 2, 2, 2]])),
      torch.tensor([[10, 11, 12, 13],
                    [14, 15, 16, 17],
                    [18, 19, 20, 21]]))

# Each column j picks its own rows.
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 2, 1, 0],
                                       [1, 0, 2, 1],
                                       [2, 1, 0, 2]])),
      torch.tensor([[10, 19, 16, 13],
                    [14, 11, 20, 17],
                    [18, 15, 12, 21]]))

# index may have fewer rows than the input ...
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 2, 1, 0],
                                       [1, 0, 2, 1]])),
      torch.tensor([[10, 19, 16, 13],
                    [14, 11, 20, 17]]))

# ... or more rows along dim (rows are reused), and fewer columns.
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 2],
                                       [1, 0],
                                       [2, 1],
                                       [0, 2],
                                       [1, 0]])),
      torch.tensor([[10, 19],
                    [14, 11],
                    [18, 15],
                    [10, 19],
                    [14, 11]]))

# dim=1 (or dim=-1, the same axis for a 2D tensor): each row i picks its
# own columns; column-j indices reproduce the input.
_show(torch.gather(input=my_tensor, dim=1,
                   index=torch.tensor([[0, 1, 2, 3],
                                       [0, 1, 2, 3],
                                       [0, 1, 2, 3]])),
      torch.tensor([[10, 11, 12, 13],
                    [14, 15, 16, 17],
                    [18, 19, 20, 21]]))
_show(torch.gather(input=my_tensor, dim=-1,
                   index=torch.tensor([[0, 1, 2, 3],
                                       [0, 1, 2, 3],
                                       [0, 1, 2, 3]])),
      torch.tensor([[10, 11, 12, 13],
                    [14, 15, 16, 17],
                    [18, 19, 20, 21]]))

_show(torch.gather(input=my_tensor, dim=1,
                   index=torch.tensor([[0, 1, 2, 3],
                                       [3, 0, 1, 2],
                                       [2, 3, 0, 1]])),
      torch.tensor([[10, 11, 12, 13],
                    [17, 14, 15, 16],
                    [20, 21, 18, 19]]))

_show(torch.gather(input=my_tensor, dim=1,
                   index=torch.tensor([[0, 1],
                                       [3, 0],
                                       [2, 3]])),
      torch.tensor([[10, 11],
                    [17, 14],
                    [20, 21]]))

# More columns than the input along dim (columns are reused), fewer rows.
_show(torch.gather(input=my_tensor, dim=1,
                   index=torch.tensor([[0, 1, 2, 3, 0, 1],
                                       [3, 0, 1, 2, 3, 0]])),
      torch.tensor([[10, 11, 12, 13, 10, 11],
                    [17, 14, 15, 16, 17, 14]]))

# float input: the result keeps the input's dtype.
my_tensor = torch.tensor([[10., 11., 12., 13.],
                          [14., 15., 16., 17.],
                          [18., 19., 20., 21.]])
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 0, 0, 0],
                                       [1, 1, 1, 1],
                                       [2, 2, 2, 2]])),
      torch.tensor([[10., 11., 12., 13.],
                    [14., 15., 16., 17.],
                    [18., 19., 20., 21.]]))

# complex input.
my_tensor = torch.tensor([[10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j],
                          [14.+0.j, 15.+0.j, 16.+0.j, 17.+0.j],
                          [18.+0.j, 19.+0.j, 20.+0.j, 21.+0.j]])
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 0, 0, 0],
                                       [1, 1, 1, 1],
                                       [2, 2, 2, 2]])),
      torch.tensor([[10.+0.j, 11.+0.j, 12.+0.j, 13.+0.j],
                    [14.+0.j, 15.+0.j, 16.+0.j, 17.+0.j],
                    [18.+0.j, 19.+0.j, 20.+0.j, 21.+0.j]]))

# bool input.
my_tensor = torch.tensor([[True, False, True, False],
                          [False, True, False, True],
                          [True, False, True, False]])
_show(torch.gather(input=my_tensor, dim=0,
                   index=torch.tensor([[0, 0, 0, 0],
                                       [1, 1, 1, 1],
                                       [2, 2, 2, 2]])),
      torch.tensor([[True, False, True, False],
                    [False, True, False, True],
                    [True, False, True, False]]))

# 1D input with a 0D index: the result is 0D, like the index.
my_tensor = torch.tensor([10, 11, 12, 13])
_show(torch.gather(input=my_tensor, dim=0, index=torch.tensor(2)),
      torch.tensor(12))

# 0D input with a 1D index: the result is 1D, like the index.
my_tensor = torch.tensor(2)
_show(torch.gather(input=my_tensor, dim=0, index=torch.tensor([0, 0, 0, 0])),
      torch.tensor([2, 2, 2, 2]))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)