DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

take() and take_along_dim() in PyTorch

Buy Me a Coffee

*My post explains gather().

take() can get a 0D or more D tensor of zero or more elements from a 0D or more D tensor of zero or more elements, using a 0D or more D tensor of zero or more indices into the flattened input, as shown below:

*Memos:

  • take() can be used with torch or a tensor.
  • The 1st argument with torch is input; when take() is called on a tensor, that tensor is the input (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is index (Required-Type: tensor of int64). *It decides the shape of the returned tensor.
import torch

# torch.take() / Tensor.take() index into `input` as if it had been flattened
# to 1D in row-major order; the returned tensor has the same shape as `index`.
# Flattened, the tensor below reads [9, 5, 0, 6, 2, 7, 1, 3, 4, 8], so flat
# position 3 holds 6, and a negative position counts back from the end
# (-7 names the same element as 3 because the tensor has 10 elements).
my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

# 0D index -> 0D result (function form, method form, negative index).
scalar_idx = torch.tensor(3)
torch.take(input=my_tensor, index=scalar_idx)
my_tensor.take(index=scalar_idx)
torch.take(input=my_tensor, index=torch.tensor(-7))
# tensor(6)

# 1D index -> 1D result. Both index tensors name the same flat positions.
vec_idx = torch.tensor([3, 0, 7, 4])
vec_idx_neg = torch.tensor([-7, -10, -3, -6])
torch.take(input=my_tensor, index=vec_idx)
torch.take(input=my_tensor, index=vec_idx_neg)
# tensor([6, 9, 3, 2])

# 2D index -> 2D result.
mat_idx = torch.tensor([[3, 0], [7, 4]])
mat_idx_neg = torch.tensor([[-7, -10], [-3, -6]])
torch.take(input=my_tensor, index=mat_idx)
torch.take(input=my_tensor, index=mat_idx_neg)
# tensor([[6, 9], [3, 2]])

# 3D index -> 3D result. cube_idx is reused for every dtype below.
cube_idx = torch.tensor([[[3, 0], [7, 4]],
                         [[8, 2], [3, 5]]])
cube_idx_neg = torch.tensor([[[-7, -10], [-3, -6]],
                             [[-2, -8], [-7, -5]]])
torch.take(input=my_tensor, index=cube_idx)
torch.take(input=my_tensor, index=cube_idx_neg)
# tensor([[[6, 9], [3, 2]], [[4, 0], [6, 7]]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])
torch.take(input=my_tensor, index=cube_idx)
# tensor([[[6., 9.], [3., 2.]], [[4., 0.], [6., 7.]]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take(input=my_tensor, index=cube_idx)
# tensor([[[6.+0.j, 9.+0.j], [3.+0.j, 2.+0.j]],
#         [[4.+0.j, 0.+0.j], [6.+0.j, 7.+0.j]]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take(input=my_tensor, index=cube_idx)
# tensor([[[False, True], [False, True]],
#         [[True, True], [False, False]]])
Enter fullscreen mode Exit fullscreen mode

take_along_dim() can get a 1D or more D tensor of zero or more elements from a 0D or more D tensor of zero or more elements, using a 0D or more D tensor of zero or more indices, as shown below:

*Memos:

  • take_along_dim() can be used with torch or a tensor.
  • The 1st argument with torch is input; when take_along_dim() is called on a tensor, that tensor is the input (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch or the 1st argument with a tensor is indices (Required-Type: tensor of int64).
  • The 3rd argument with torch or the 2nd argument with a tensor is dim(Optional-Type:int): *Memos:
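    • Not setting dim flattens both input and indices, so a 1D tensor is returned.
    • If dim is set, input and indices must have the same number of dimensions, and the returned tensor has that number of dimensions too. indices is broadcast against input in every dimension except dim.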
  • There is out argument with torch(Optional-Type:tensor): *Memos:
    • out= must be used.
    • My post explains out argument.
import torch

# torch.take_along_dim() / Tensor.take_along_dim() select values from `input`
# at the positions given by `indices`:
#  - dim omitted: input and indices are both flattened to 1D, so the result is
#    always 1D (even for a 0D index).
#  - dim set: input and indices have the same number of dimensions; indices
#    picks positions along `dim` and is broadcast against input in every other
#    dimension (e.g. a (4, 1) index with dim=0 yields a (4, 5) result below).
my_tensor = torch.tensor([[9, 5, 0, 6, 2], [7, 1, 3, 4, 8]])

torch.take_along_dim(input=my_tensor, indices=torch.tensor(3))
my_tensor.take_along_dim(indices=torch.tensor(3))
# Equivalent with gather(): it takes `index` (not `indices`) and a required
# `dim`, and its index must have as many dimensions as its input, so flatten
# the input and pass a 1D index.
torch.gather(input=my_tensor.flatten(), dim=0, index=torch.tensor([3]))
# tensor([6])

# Without dim, the index tensor's shape does not matter: it is flattened too.
torch.take_along_dim(input=my_tensor, indices=torch.tensor([3, 0, 7, 4]))
torch.take_along_dim(input=my_tensor, indices=torch.tensor([[3, 0], [7, 4]]))
# tensor([6, 9, 3, 2])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[[3, 0], [7, 4]],
                                           [[8, 2], [3, 5]]]))
# tensor([6, 9, 3, 2, 4, 0, 6, 7])

# dim=0 (same as dim=-2 for a 2D input): each index value picks a row. The
# (4, 1) index is broadcast across the 5 columns.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]),
                     dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0], [1], [0], [1]]),
                     dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8]])

# dim=0 with a full (4, 5) index: the row is chosen per column.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]),
                     dim=0)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[0, 0, 0, 0, 0],
                                           [1, 1, 1, 1, 1],
                                           [0, 1, 0, 1, 0],
                                           [1, 0, 1, 0, 1]]),
                     dim=-2)
# tensor([[9, 5, 0, 6, 2],
#         [7, 1, 3, 4, 8],
#         [9, 1, 0, 4, 2],
#         [7, 5, 3, 6, 8]])

# dim=1 (same as dim=-1): each index value picks a column. The (1, 3) index
# is broadcast across both rows.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4]]),
                     dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4]]),
                     dim=-1)
# tensor([[6, 9, 2], [4, 7, 8]])

# dim=1 with one index row per input row.
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=-1)
# tensor([[6, 9, 2], [8, 1, 3]])

# float input
my_tensor = torch.tensor([[9., 5., 0., 6., 2.], [7., 1., 3., 4., 8.]])

torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[6., 9., 2.], [8., 1., 3.]])

# complex input
my_tensor = torch.tensor([[9.+0.j, 5.+0.j, 0.+0.j, 6.+0.j, 2.+0.j],
                          [7.+0.j, 1.+0.j, 3.+0.j, 4.+0.j, 8.+0.j]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[6.+0.j, 9.+0.j, 2.+0.j],
#         [8.+0.j, 1.+0.j, 3.+0.j]])

# bool input
my_tensor = torch.tensor([[True, False, True, False, True],
                          [False, True, False, True, False]])
torch.take_along_dim(input=my_tensor,
                     indices=torch.tensor([[3, 0, 4], [4, 1, 2]]),
                     dim=1)
# tensor([[False, True, True],
#         [False, True, False]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)