*Memos:
- My post explains count_nonzero().
- My post explains argwhere() and nonzero().
where() can get the 0D or more D tensor of zero or more elements by selecting each element from either input or other, depending on condition, as shown below:
*Memos:
- where() can be used with torch or a tensor.
- The 1st argument with torch or a tensor is condition(Required-Type:tensor of bool).
- The 2nd argument with torch, or the tensor itself when using a tensor, is input(Required-Type:tensor or scalar of int, float, complex or bool): *Memos:
  - torch must use input with a scalar without condition= and input=, so condition and input must be passed positionally while other= can still be used.
  - A tensor cannot use input with a scalar because the tensor itself is input.
- The 3rd argument with torch or the 2nd argument with a tensor is other(Required-Type:tensor or scalar of int, float, complex or bool).
- There is out argument with torch(Optional-Default:None-Type:tensor): *Memos:
  - out= must be used.
  - My post explains out argument.
- If condition is True, the element of input or a tensor is selected, otherwise the element of other is selected.
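The list above says that out= must be passed by keyword and that the tensor method form treats the tensor itself as input. A minimal sketch checking both, assuming the same int tensors as the first examples; result is a hypothetical preallocated tensor whose shape (2, 3) and dtype torch.int64 were picked by hand to match what torch.where() returns for these inputs:

```python
import torch

tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])
condition = tensor1 > 2

# Hypothetical preallocated output: shape (2, 3) and dtype int64 are chosen
# by hand to match the result of torch.where() for these inputs.
result = torch.empty(2, 3, dtype=torch.int64)
torch.where(condition, tensor1, tensor2, out=result)  # out is keyword-only

# Method form: tensor1 itself plays the role of input.
method_result = tensor1.where(condition, other=tensor2)

print(result.tolist())                     # [[5, 70, 4], [60, 3, 80]]
print(torch.equal(result, method_result))  # True
```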
```python
import torch

tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])

torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
tensor1.where(condition=tensor1 > 2, other=tensor2)
# tensor([[5, 70, 4],
#         [60, 3, 80]])

torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[60, 0, 80],
#         [0, 70, 1]])

torch.where(tensor1 > 2, 10, tensor2)
# tensor([[10, 70, 10],
#         [60, 10, 80]])

torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[5, 10, 4],
#         [10, 3, 10]])

torch.where(tensor1 > 2, 10, 20)
# tensor([[10, 20, 10],
#         [20, 10, 20]])
```
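In these examples condition (tensor1 > 2) has shape (2, 3) while tensor2 has shape (3,), so condition, input and other are broadcast to a common shape before the selection happens, and a scalar such as 10 acts like a tensor filled with that value. A short sketch that checks this by building the broadcast operands explicitly with expand() and full_like(); the variable names are only for illustration:

```python
import torch

tensor1 = torch.tensor([[5, 0, 4],
                        [0, 3, 1]])
tensor2 = torch.tensor([60, 70, 80])
condition = tensor1 > 2

# tensor2 (shape (3,)) is broadcast like [[60, 70, 80], [60, 70, 80]].
expanded = tensor2.expand(2, 3)
same_as_expanded = torch.equal(torch.where(condition, tensor1, tensor2),
                               torch.where(condition, tensor1, expanded))

# The scalar 10 behaves like a (2, 3) int64 tensor filled with 10.
filled = torch.full_like(tensor1, 10)
same_as_filled = torch.equal(torch.where(condition, tensor1, 10),
                             torch.where(condition, tensor1, filled))

print(same_as_expanded, same_as_filled)  # True True
```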
```python
import torch

tensor1 = torch.tensor([[5., 0., 4.],
                        [0., 3., 1.]])
tensor2 = torch.tensor([60., 70., 80.])
tensor3 = torch.tensor(True)

torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

torch.where(tensor3, 5., other=tensor2)
# tensor([5., 5., 5.])

torch.where(condition=tensor3, input=tensor1, other=60.)
# tensor([[5., 0., 4.],
#         [0., 3., 1.]])

torch.where(tensor3, 5., other=60.)
# tensor(5.)
```
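With a 0D condition such as tensor3, the shape of the result follows the same broadcasting: it is the broadcast of the shapes of condition, input and other, with a Python scalar counting as 0D. A quick check with torch.broadcast_shapes(), reusing the float tensors above:

```python
import torch

tensor1 = torch.tensor([[5., 0., 4.],
                        [0., 3., 1.]])
tensor2 = torch.tensor([60., 70., 80.])
tensor3 = torch.tensor(True)  # 0D condition

# The broadcast of (), (2, 3) and (3,) is (2, 3).
print(torch.broadcast_shapes(tensor3.shape, tensor1.shape, tensor2.shape))
# torch.Size([2, 3])
print(torch.where(tensor3, tensor1, tensor2).shape)  # torch.Size([2, 3])

# Scalars count as 0D, so only tensor2 contributes a dimension here.
print(torch.where(tensor3, 5., other=tensor2).shape)  # torch.Size([3])
print(torch.where(tensor3, 5., other=60.).shape)      # torch.Size([])
```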
```python
import torch

tensor1 = torch.tensor([[5.+0.j, 0.+0.j, 4.+0.j],
                        [0.+0.j, 3.+0.j, 1.+0.j]])
tensor2 = torch.tensor([60.+0.j, 70.+0.j, 80.+0.j])
tensor3 = torch.tensor(False)

torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[60.+0.j, 70.+0.j, 80.+0.j],
#         [60.+0.j, 70.+0.j, 80.+0.j]])

torch.where(tensor3, 5.+0.j, other=tensor2)
# tensor([60.+0.j, 70.+0.j, 80.+0.j])

torch.where(condition=tensor3, input=tensor1, other=60.+0.j)
# tensor([[60.+0.j, 60.+0.j, 60.+0.j],
#         [60.+0.j, 60.+0.j, 60.+0.j]])

torch.where(tensor3, 5.+0.j, other=60.+0.j)
# tensor(60.+0.j)
```
```python
import torch

tensor1 = torch.tensor([[True, False, True],
                        [False, True, False]])
tensor2 = torch.tensor([False, True, False])
tensor3 = torch.tensor(True)

torch.where(condition=tensor3, input=tensor1, other=tensor2)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=tensor2)
# tensor([True, True, True])

torch.where(condition=tensor3, input=tensor1, other=False)
# tensor([[True, False, True],
#         [False, True, False]])

torch.where(tensor3, True, other=False)
# tensor(True)
```
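For bool tensors, the selection rule can also be written as plain logic: each element is (condition AND input) OR (NOT condition AND other). A small sketch comparing where() with that formula, using hypothetical 1D bool tensors so that both True and False conditions are exercised (the examples above only use a 0D True condition):

```python
import torch

condition = torch.tensor([True, False, True])
input_ = torch.tensor([False, False, True])
other = torch.tensor([True, True, False])

by_where = torch.where(condition, input_, other)
# &, | and ~ are element-wise AND, OR and NOT on bool tensors.
by_logic = (condition & input_) | (~condition & other)

print(by_where.tolist())                # [False, True, True]
print(torch.equal(by_where, by_logic))  # True
```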
```python
import torch

tensor1 = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                        [[0, 7, 0], [0, 6, 8]]])
tensor2 = torch.tensor([60, 70, 80])

torch.where(condition=tensor1 > 2, input=tensor1, other=tensor2)
# tensor([[[5, 70, 4],
#          [60, 3, 80]],
#         [[60, 7, 80],
#          [60, 6, 8]]])

torch.where(condition=tensor1 > 2, input=tensor2, other=tensor1)
# tensor([[[60, 0, 80],
#          [0, 70, 1]],
#         [[0, 70, 0],
#          [0, 70, 80]]])

torch.where(tensor1 > 2, 10, tensor2)
# tensor([[[10, 70, 10],
#          [60, 10, 80]],
#         [[60, 10, 80],
#          [60, 10, 10]]])

torch.where(condition=tensor1 > 2, input=tensor1, other=10)
# tensor([[[5, 10, 4],
#          [10, 3, 10]],
#         [[10, 7, 10],
#          [10, 6, 8]]])

torch.where(tensor1 > 2, 10, 20)
# tensor([[[10, 20, 10],
#          [20, 10, 20]],
#         [[20, 10, 20],
#          [20, 10, 10]]])
```
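Putting the pieces together, the behaviour in the examples above can be mimicked element by element: broadcast condition, input and other to one shape, then take each element from input where condition is True and from other otherwise. A hypothetical reference sketch (where_reference is not a PyTorch function), assuming tensor arguments that share a dtype; it is much slower than torch.where() and is only meant to make the rule concrete:

```python
import torch

def where_reference(condition, input_, other):
    """Hypothetical element-by-element stand-in for torch.where().

    Assumes condition is a bool tensor and input_/other are tensors of the
    same dtype; scalars and type promotion are not handled.
    """
    # Broadcast all three tensors to one common shape.
    c, i, o = torch.broadcast_tensors(condition, input_, other)
    # Walk the elements in order and apply the selection rule.
    picked = [iv if cv else ov
              for cv, iv, ov in zip(c.flatten().tolist(),
                                    i.flatten().tolist(),
                                    o.flatten().tolist())]
    return torch.tensor(picked, dtype=i.dtype).reshape(c.shape)


tensor1 = torch.tensor([[[5, 0, 4], [0, 3, 1]],
                        [[0, 7, 0], [0, 6, 8]]])
tensor2 = torch.tensor([60, 70, 80])
condition = tensor1 > 2

ref = where_reference(condition, tensor1, tensor2)
print(ref.shape)  # torch.Size([2, 2, 3])
print(torch.equal(ref, torch.where(condition, tensor1, tensor2)))  # True
print(torch.equal(where_reference(condition, tensor2, tensor1),
                  torch.where(condition, tensor2, tensor1)))  # True
```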