DEV Community

Super Kai (Kazuya Ito)

where() in PyTorch

*Memos:

where() can get the 0D or more D tensor of zero or more elements by selecting each element from either input or other, depending on condition, as shown below:

*Memos:

  • where() can be used with torch or a tensor.
  • The 1st argument with torch or a tensor is condition(Required-Type:tensor of bool).
  • The 2nd argument with torch, or the tensor itself when using a tensor, is input(Required-Type:tensor or scalar of int, float, complex or bool): *Memos:
    • When input is a scalar, torch must be given condition and input positionally, without condition= and input= (other= can still be used).
    • A tensor cannot use a scalar as input because the tensor itself is input.
  • The 3rd argument with torch or the 2nd argument with a tensor is other(Required-Type:tensor or scalar of int, float, complex or bool).
  • There is the out argument with torch(Optional-Type:tensor): *Memos:
    • out= must be used.
    • My post explains the out argument.
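  • Where condition is True, the element of input (or the tensor) is selected; otherwise, the element of other is selected.
  • condition, input and other are broadcast to a common shape, which becomes the shape of the result.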
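To make the selection rule in the memos concrete, here is a minimal reference sketch in plain PyTorch. The helper name where_reference is hypothetical (it is not part of PyTorch). The sketch assumes input and other are tensors of the same dtype. It only illustrates the semantics and is not how torch.where() is implemented internally.

```python
import torch


def where_reference(condition, input, other):
    # Hypothetical helper for illustration only, not a PyTorch API.
    # Assumes input and other are tensors with the same dtype.
    # Broadcast condition, input and other to one common shape, as torch.where() does.
    condition, input, other = torch.broadcast_tensors(condition, input, other)
    # Start from a copy of other; clone() of a broadcast (expanded) view gives a
    # fresh, writable tensor with no overlapping memory.
    result = other.clone()
    # Overwrite every position where condition is True with the element of input.
    result[condition] = input[condition]
    return result


tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])

print(where_reference(tensor1 > 2, tensor1, tensor2).tolist())
# [[5, 70, 4], [60, 3, 80]]
print(torch.equal(where_reference(tensor1 > 2, tensor1, tensor2),
                  torch.where(tensor1 > 2, tensor1, tensor2)))
# True
```

This mask-assignment version is useful for checking your understanding. In real code, torch.where() itself is the idiomatic call, since it also accepts scalars for input or other.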
import torch

tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])

torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
tensor1.where(condition=tensor1 > 2, other=tensor2)
# tensor([[5, 70, 4],
#         [60, 3, 80]])

torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[60, 0, 80],
#         [0, 70, 1]])

torch.where(tensor1 > 2, 10, tensor2)
# tensor([[10, 70, 10],
#         [60, 10, 80]])

torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[5, 10, 4],
#         [10, 3, 10]])

torch.where(tensor1 > 2, 10, 20)
# tensor([[10, 20, 10],
#         [20, 10, 20]])

tensor1 = torch.tensor([[5., 0., 4.],
                        [0., 3., 1.]])
tensor2 = torch.tensor([60., 70., 80.])
tensor3 = torch.tensor(True)
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

torch.where(tensor3, 5., other=tensor2)
# tensor([5., 5., 5.])

torch.where(condition=tensor3, input=tensor1, other=60.)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

torch.where(tensor3, 5., other=60.)
# tensor(5.)

tensor1 = torch.tensor([[5.+0.j, 0.+0.j, 4.+0.j],
                        [0.+0.j, 3.+0.j, 1.+0.j]])
tensor2 = torch.tensor([60.+0.j, 70.+0.j, 80.+0.j])
tensor3 = torch.tensor(False)
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[60.+0.j, 70.+0.j, 80.+0.j],
#         [60.+0.j, 70.+0.j, 80.+0.j]])

torch.where(tensor3, 5.+0.j, other=tensor2)
# tensor([60.+0.j, 70.+0.j, 80.+0.j])

torch.where(condition=tensor3, input=tensor1, other=60.+0.j)
# tensor([[60.+0.j, 60.+0.j, 60.+0.j],
#         [60.+0.j, 60.+0.j, 60.+0.j]])

torch.where(tensor3, 5.+0.j, other=60.+0.j)
# tensor(60.+0.j)

tensor1 = torch.tensor([[True, False, True],
                        [False, True, False]])
tensor2 = torch.tensor([False, True, False])
tensor3 = torch.tensor(True)
torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=tensor2)
# tensor([True, True, True])

torch.where(condition=tensor3, input=tensor1, other=False)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=False)
# tensor(True)

tensor1 = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                        [[0, 7, 0], [0, 6, 8]]])
tensor2 = torch.tensor([60, 70, 80])

torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
# tensor([[[5, 70, 4],
#          [60, 3, 80]],
#         [[60, 7, 80],
#          [60, 6, 8]]])

torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[[60, 0, 80],
#          [0, 70, 1]],
#         [[0, 70, 0],
#          [0, 70, 80]]])

torch.where(tensor1 > 2, 10, tensor2)
# tensor([[[10, 70, 10],
#          [60, 10, 80]],
#         [[60, 10, 80],
#          [60, 10, 10]]])

torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[[5, 10, 4],
#          [10, 3, 10]],
#         [[10, 7, 10],
#          [10, 6, 8]]])

torch.where(tensor1 > 2, 10, 20)
# tensor([[[10, 20, 10],
#          [20, 10, 20]],
#         [[20, 10, 20],
#          [20, 10, 10]]])
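The memos above mention that torch.where() accepts an out argument. The following is a minimal sketch. It assumes tensor (not scalar) input and other, so the plain tensor form is used. out= is passed as a keyword, and the result is written into a preallocated tensor whose shape (2, 3) and dtype (int64) match the broadcast result. The last lines check that the method form tensor1.where() produces the same values.

```python
import torch

tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])

# Preallocate the output with the broadcast shape (2, 3) and the result dtype (int64).
result = torch.empty(2, 3, dtype=torch.int64)

# out= must be passed as a keyword argument; the selected elements are written into result.
torch.where(tensor1 > 2, tensor1, tensor2, out=result)
print(result.tolist())
# [[5, 70, 4], [60, 3, 80]]

# The method form uses tensor1 itself as input, so only condition and other are passed.
print(torch.equal(tensor1.where(tensor1 > 2, tensor2), result))
# True
```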
