DEV Community

Super Kai (Kazuya Ito)

Posted on • Edited on

Set `keepdim` argument in PyTorch

Buy Me a Coffee

*Memos:

You can set the `keepdim` argument as shown below:

*Memos:

sum(). *My post explains sum():

import torch

my_tensor = torch.tensor([1, 2, 3, 4])

# With the default keepdim=False, the reduced dimension is dropped, so the
# result is a 0-d tensor. Omitting dim reduces over every element.
torch.sum(my_tensor)
torch.sum(my_tensor, dim=0)
# tensor(10)

# keepdim=True retains the reduced dimension with size 1.
torch.sum(my_tensor, dim=0, keepdim=True)
# tensor([10])
Enter fullscreen mode Exit fullscreen mode

prod(). *My post explains prod():

import torch

my_tensor = torch.tensor([1, 2, 3, 4])

# Without keepdim, multiplying along the only dimension leaves a 0-d tensor.
# Omitting dim multiplies every element.
torch.prod(my_tensor)
torch.prod(my_tensor, dim=0)
# tensor(24)

# keepdim=True keeps dimension 0 in the result, with size 1.
torch.prod(my_tensor, dim=0, keepdim=True)
# tensor([24])
Enter fullscreen mode Exit fullscreen mode

mean(). *My post explains mean():

import torch

# torch.mean() requires a floating-point (or complex) input, which is why
# this example uses float literals, unlike the integer tensors elsewhere.
my_tensor = torch.tensor([5., 4., 7., 7.])

# Without keepdim, the reduced dimension is dropped, giving a 0-d tensor.
torch.mean(input=my_tensor)
torch.mean(input=my_tensor, dim=0)
# tensor(5.7500)

# keepdim=True keeps dimension 0 in the result, with size 1.
torch.mean(input=my_tensor, dim=0, keepdim=True)
# tensor([5.7500])
Enter fullscreen mode Exit fullscreen mode

median(). *My post explains median():

import torch

my_tensor = torch.tensor([5, 4, 7, 7])

# When dim is given, median() returns a (values, indices) named tuple.
# For an even number of elements it reports the lower of the two middle
# values: sorted [4, 5, 7, 7] -> 5, which sits at index 0.
torch.median(my_tensor, dim=0)
# torch.return_types.median(
# values=tensor(5),
# indices=tensor(0))

# keepdim=True makes both values and indices keep dimension 0, with size 1.
torch.median(my_tensor, dim=0, keepdim=True)
# torch.return_types.median(
# values=tensor([5]),
# indices=tensor([0]))
Enter fullscreen mode Exit fullscreen mode

min(). *My post explains min():

import torch

my_tensor = torch.tensor([5, 4, 7, 7])

# When dim is given, min() returns a (values, indices) named tuple. The
# smallest element, 4, sits at index 1.
torch.min(my_tensor, dim=0)
# torch.return_types.min(
# values=tensor(4),
# indices=tensor(1))

# keepdim=True keeps the reduced dimension, with size 1, in both fields.
torch.min(my_tensor, dim=0, keepdim=True)
# torch.return_types.min(
# values=tensor([4]),
# indices=tensor([1]))
Enter fullscreen mode Exit fullscreen mode

max(). *My post explains max():

import torch

my_tensor = torch.tensor([5, 4, 7, 7])

# When dim is given, max() returns a (values, indices) named tuple. The
# value 7 appears twice; the output below reports index 2.
torch.max(my_tensor, dim=0)
# torch.return_types.max(
# values=tensor(7),
# indices=tensor(2))

# keepdim=True keeps the reduced dimension, with size 1, in both fields.
torch.max(my_tensor, dim=0, keepdim=True)
# torch.return_types.max(
# values=tensor([7]),
# indices=tensor([2]))
Enter fullscreen mode Exit fullscreen mode

argmin(). *My post explains argmin():

import torch

my_tensor = torch.tensor([5, 4, 7, 7])

# argmin() returns only the index tensor. With no dim, it searches the
# flattened input.
torch.argmin(my_tensor)
torch.argmin(my_tensor, dim=0)
# tensor(1)

# keepdim=True yields a size-1 dimension. The output below shows this even
# when dim is omitted.
# NOTE(review): keepdim without dim may behave differently on older
# PyTorch releases; confirm against the installed version.
torch.argmin(my_tensor, keepdim=True)
torch.argmin(my_tensor, dim=0, keepdim=True)
# tensor([1])
Enter fullscreen mode Exit fullscreen mode

argmax(). *My post explains argmax():

import torch

my_tensor = torch.tensor([5, 4, 7, 7])

# argmax() returns only the index tensor. The value 7 appears at indices 2
# and 3; the output below reports 2.
torch.argmax(my_tensor)
torch.argmax(my_tensor, dim=0)
# tensor(2)

# keepdim=True yields a size-1 dimension. The output below shows this even
# when dim is omitted.
# NOTE(review): keepdim without dim may behave differently on older
# PyTorch releases; confirm against the installed version.
torch.argmax(my_tensor, keepdim=True)
torch.argmax(my_tensor, dim=0, keepdim=True)
# tensor([2])
Enter fullscreen mode Exit fullscreen mode

all(). *My post explains all():

import torch

my_tensor = torch.tensor([True, False, True, False])

torch.all(input=my_tensor)
torch.all(input=my_tensor, dim=0)
# tensor(False)

torch.all(input=my_tensor, keepdim=True)
torch.all(input=my_tensor, dim=0, keepdim=True)
# tensor([False])
Enter fullscreen mode Exit fullscreen mode

any(). *My post explains any():

import torch

my_tensor = torch.tensor([True, False, True, False])

torch.any(input=my_tensor)
torch.any(input=my_tensor, dim=0)
# tensor(True)

torch.any(input=my_tensor, keepdim=True)
torch.any(input=my_tensor, dim=0, keepdim=True)
# tensor([True])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)