## DEV Community

Super Kai (Kazuya Ito)

Posted on • Updated on

# sort(), argsort() and msort() in PyTorch

`sort()` can get the sorted elements (zero or more) and their indices from a 0D or higher-dimensional tensor, in ascending (default) or descending order, as shown below:

*Memos:

• `sort()` can be used with `torch` and a tensor.
• Only a tensor of integers, floating-point numbers or boolean values (with zero or more elements) can be used, so a tensor of complex numbers cannot be used, except for a 0D tensor of a complex number.
• The 2nd argument(`int`) with `torch` or the 1st argument(`int`) with a tensor is `dim`(Optional-Default:`-1`), which is the dimension to sort along.
• The 3rd argument(`bool`) with `torch` or the 2nd argument(`bool`) with a tensor is `descending`(Optional-Default:`False`). *`False` is ascending order and `True` is descending order.
• The 4th argument(`bool`) with `torch` or the 3rd argument(`bool`) with a tensor is `stable`(Optional-Default:`False`). *Memos:
• `True` guarantees that equal values keep their original relative order, while `False` does not guarantee it (although in practice `False` often keeps it too).
• You must pass `stable` as a keyword argument (`stable=`). *When setting `stable`, `dim` and `descending` must also be passed as keyword arguments (`dim=` and `descending=`).
``````python
import torch

my_tensor = torch.tensor([7, 1, -5, 7, 9, -3, 0, -3])

# Ascending order (the default). For a 1D tensor, dim=0 and the default
# dim=-1 refer to the same dimension. With stable=False the relative order of
# the indices of equal values (the two 7s and the two -3s) is not guaranteed;
# the outputs shown are what stable=True always produces.
torch.sort(my_tensor)
my_tensor.sort()
torch.sort(my_tensor, dim=0, descending=False, stable=False)
my_tensor.sort(dim=0, descending=False, stable=False)
torch.sort(my_tensor, dim=-1, descending=False, stable=False)
my_tensor.sort(dim=-1, descending=False, stable=False)
torch.sort(my_tensor, dim=0, descending=False, stable=True)
my_tensor.sort(dim=0, descending=False, stable=True)
torch.sort(my_tensor, dim=-1, descending=False, stable=True)
my_tensor.sort(dim=-1, descending=False, stable=True)
# torch.return_types.sort(
# values=tensor([-5, -3, -3, 0, 1, 7, 7, 9]),
# indices=tensor([2, 5, 7, 6, 1, 0, 3, 4]))

# Descending order. Once `stable=` is given, `dim` must be passed only as a
# keyword: passing it both positionally and as `dim=`
# (e.g. `my_tensor.sort(0, dim=0, ...)`) raises a TypeError.
torch.sort(my_tensor, dim=0, descending=True, stable=False)
my_tensor.sort(dim=0, descending=True, stable=False)
torch.sort(my_tensor, dim=-1, descending=True, stable=False)
my_tensor.sort(dim=-1, descending=True, stable=False)
torch.sort(my_tensor, dim=0, descending=True, stable=True)
my_tensor.sort(dim=0, descending=True, stable=True)
torch.sort(my_tensor, dim=-1, descending=True, stable=True)
my_tensor.sort(dim=-1, descending=True, stable=True)
# torch.return_types.sort(
# values=tensor([9, 7, 7, 1, 0, -3, -3, -5]),
# indices=tensor([4, 0, 3, 1, 6, 5, 7, 2]))

my_tensor = torch.tensor([[7, 1, -5, 7], [9, -3, 0, -3]])

# Sort every row. For a 2D tensor the default dim=-1 is the same axis as dim=1,
# so all of these calls produce the output below.
torch.sort(my_tensor)
my_tensor.sort()
for is_stable in (False, True):
    for row_dim in (-1, 1):
        torch.sort(my_tensor, dim=row_dim, descending=False, stable=is_stable)
        my_tensor.sort(dim=row_dim, descending=False, stable=is_stable)
# torch.return_types.sort(
# values=tensor([[-5, 1, 7, 7], [-3, -3, 0, 9]]),
# indices=tensor([[2, 1, 0, 3], [1, 3, 2, 0]]))

# Sort every column: dim=0 and dim=-2 name the same axis of a 2D tensor.
for is_stable in (False, True):
    for col_dim in (0, -2):
        torch.sort(my_tensor, dim=col_dim, descending=False, stable=is_stable)
        my_tensor.sort(dim=col_dim, descending=False, stable=is_stable)
# torch.return_types.sort(
# values=tensor([[7, -3, -5, -3], [9, 1, 0, 7]]),
# indices=tensor([[0, 1, 0, 1], [1, 0, 1, 0]]))

# The booleans become 1. and 0. because the rest of the elements are floats.
my_tensor = torch.tensor([[7., True, -5., 7.], [9., -3., False, -3.]])

# Row-wise sort: dim=-1 (the default) and dim=1 are the same axis here.
torch.sort(my_tensor)
my_tensor.sort()
for is_stable in (False, True):
    for last_dim in (-1, 1):
        torch.sort(my_tensor, dim=last_dim, descending=False, stable=is_stable)
        my_tensor.sort(dim=last_dim, descending=False, stable=is_stable)
# torch.return_types.sort(
# values=tensor([[-5., 1., 7., 7.], [-3., -3., 0., 9.]]),
# indices=tensor([[2, 1, 0, 3], [1, 3, 2, 0]]))

my_tensor = torch.tensor(7+5j)

# A 0D tensor has a single element, so every combination of dim (0 or -1),
# descending and stable returns the element unchanged with index 0.
torch.sort(my_tensor)
my_tensor.sort()
for is_descending in (False, True):
    for is_stable in (False, True):
        for axis in (0, -1):
            torch.sort(my_tensor, dim=axis, descending=is_descending,
                       stable=is_stable)
            my_tensor.sort(dim=axis, descending=is_descending,
                           stable=is_stable)
# torch.return_types.sort(
# values=tensor(7.+5.j),
# indices=tensor(0))
``````

`argsort()` can get the indices of the sorted elements (zero or more) of a 0D or higher-dimensional tensor, in ascending (default) or descending order, as shown below:

*Memos:

• `argsort()` can be used with `torch` and a tensor.
• Only a tensor of integers, floating-point numbers or boolean values (with zero or more elements) can be used, so a tensor of complex numbers cannot be used, except for a 0D tensor of a complex number.
• The 2nd argument(`int`) with `torch` or the 1st argument(`int`) with a tensor is `dim`(Optional-Default:`-1`), which is the dimension to sort along.
• The 3rd argument(`bool`) with `torch` or the 2nd argument(`bool`) with a tensor is `descending`(Optional-Default:`False`). *`False` is ascending order and `True` is descending order.
• The 4th argument(`bool`) with `torch` or the 3rd argument(`bool`) with a tensor is `stable`(Optional-Default:`False`). *Memos:
• `True` guarantees that equal values keep their original relative order, while `False` does not guarantee it (although in practice `False` often keeps it too).
• You must pass `stable` as a keyword argument (`stable=`). *When setting `stable`, `dim` and `descending` must also be passed as keyword arguments (`dim=` and `descending=`).
``````python
import torch

my_tensor = torch.tensor([7, 1, -5, 7, 9, -3, 0, -3])

# Ascending order (the default). For a 1D tensor, dim=0 and the default
# dim=-1 refer to the same dimension. With stable=False the relative order of
# the indices of equal values (the two 7s and the two -3s) is not guaranteed;
# the outputs shown are what stable=True always produces.
torch.argsort(my_tensor)
my_tensor.argsort()
torch.argsort(my_tensor, dim=0, descending=False, stable=False)
my_tensor.argsort(dim=0, descending=False, stable=False)
torch.argsort(my_tensor, dim=-1, descending=False, stable=False)
my_tensor.argsort(dim=-1, descending=False, stable=False)
torch.argsort(my_tensor, dim=0, descending=False, stable=True)
my_tensor.argsort(dim=0, descending=False, stable=True)
torch.argsort(my_tensor, dim=-1, descending=False, stable=True)
my_tensor.argsort(dim=-1, descending=False, stable=True)
# tensor([2, 5, 7, 6, 1, 0, 3, 4])

# Descending order. Once `stable=` is given, `dim` must be passed only as a
# keyword: passing it both positionally and as `dim=`
# (e.g. `my_tensor.argsort(0, dim=0, ...)`) raises a TypeError.
torch.argsort(my_tensor, dim=0, descending=True, stable=False)
my_tensor.argsort(dim=0, descending=True, stable=False)
torch.argsort(my_tensor, dim=-1, descending=True, stable=False)
my_tensor.argsort(dim=-1, descending=True, stable=False)
torch.argsort(my_tensor, dim=0, descending=True, stable=True)
my_tensor.argsort(dim=0, descending=True, stable=True)
torch.argsort(my_tensor, dim=-1, descending=True, stable=True)
my_tensor.argsort(dim=-1, descending=True, stable=True)
# tensor([4, 0, 3, 1, 6, 5, 7, 2])

my_tensor = torch.tensor([[7, 1, -5, 7], [9, -3, 0, -3]])

# Indices that sort each row; dim=-1 (the default) equals dim=1 for 2D input.
torch.argsort(my_tensor)
my_tensor.argsort()
for is_stable in (False, True):
    for row_dim in (-1, 1):
        torch.argsort(my_tensor, dim=row_dim, descending=False,
                      stable=is_stable)
        my_tensor.argsort(dim=row_dim, descending=False, stable=is_stable)
# tensor([[2, 1, 0, 3], [1, 3, 2, 0]])

# Indices that sort each column; dim=0 equals dim=-2 for 2D input.
for is_stable in (False, True):
    for col_dim in (0, -2):
        torch.argsort(my_tensor, dim=col_dim, descending=False,
                      stable=is_stable)
        my_tensor.argsort(dim=col_dim, descending=False, stable=is_stable)
# tensor([[0, 1, 0, 1], [1, 0, 1, 0]])

# The booleans become 1. and 0. because the rest of the elements are floats.
my_tensor = torch.tensor([[7., True, -5., 7.], [9., -3., False, -3.]])

# Row-wise argsort: dim=-1 (the default) and dim=1 are the same axis here.
torch.argsort(my_tensor)
my_tensor.argsort()
for is_stable in (False, True):
    for last_dim in (-1, 1):
        torch.argsort(my_tensor, dim=last_dim, descending=False,
                      stable=is_stable)
        my_tensor.argsort(dim=last_dim, descending=False, stable=is_stable)
# tensor([[2, 1, 0, 3], [1, 3, 2, 0]])

my_tensor = torch.tensor(7+5j)

# A 0D tensor holds a single element, so every combination of dim (0 or -1),
# descending and stable yields index 0.
torch.argsort(my_tensor)
my_tensor.argsort()
for is_descending in (False, True):
    for is_stable in (False, True):
        for axis in (0, -1):
            torch.argsort(my_tensor, dim=axis, descending=is_descending,
                          stable=is_stable)
            my_tensor.argsort(dim=axis, descending=is_descending,
                              stable=is_stable)
# tensor(0)
``````

`msort()` can get the sorted elements (zero or more) of a 0D or higher-dimensional tensor, sorted along dimension 0 in ascending order, as shown below:

*Memos:

• `msort()` can be used with `torch` and a tensor.
• Only a tensor of integers, floating-point numbers or boolean values (with zero or more elements) can be used, so a tensor of complex numbers cannot be used, except for a 0D tensor of a complex number.
• `msort()` doesn't have the `dim`, `descending` or `stable` arguments.
• `msort(input)` is equivalent to `torch.sort(input, dim=0)[0]`, i.e. it always sorts along dimension 0 in ascending order.
``````python
import torch

# msort() always sorts along dim 0 in ascending order; torch.msort(t) and
# t.msort() are the same operation. Each tensor's result follows it.
for my_tensor in (
    torch.tensor([7, 1, -5, 7, 9, -3, 0, -3]),
    # -> tensor([-5, -3, -3, 0, 1, 7, 7, 9])
    torch.tensor([[7, 1, -5, 7], [9, -3, 0, -3]]),
    # -> tensor([[7, -3, -5, -3], [9, 1, 0, 7]])
    torch.tensor([[7., True, -5., 7.], [9., -3., False, -3.]]),
    # -> tensor([[7., -3., -5., -3.], [9., 1., 0., 7.]])
    torch.tensor(7+5j),
    # -> tensor(7.+5.j)
):
    torch.msort(my_tensor)
    my_tensor.msort()
``````