DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Edited on

unflatten in PyTorch

Buy Me a Coffee

*Memos:

unflatten() expands one dimension of a 1D or higher-dimensional tensor of zero or more elements into one or more dimensions. It can therefore add zero or more dimensions, and it returns a 1D or higher-dimensional tensor of zero or more elements, as shown below:

*Memos:

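  • unflatten() can be called from torch (torch.unflatten()) or from a tensor (Tensor.unflatten()).
  • With torch, the 1st argument is input. With a tensor, the tensor itself is the input (Required-Type: tensor of int, float, complex or bool).
  • dim is the 2nd argument with torch and the 1st argument with a tensor (Required-Type: int). It is the dimension to unflatten; a negative value counts from the last dimension (e.g. -1 is the last one).
  • sizes is the 3rd argument with torch and the 2nd argument with a tensor (Required-Type: tuple of int or list of int). It is the new shape of dimension dim, so the product of its elements must equal the size of that dimension. *-1 infers and adjusts that size, and at most one element can be -1.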
  • The differences between Unflatten() and unflatten() are:
    • Unflatten() has an unflattened_size argument, which plays the same role as the sizes argument of unflatten().
    • Basically, Unflatten() is used as a layer to define a model, while unflatten() is called directly on a tensor and is not used to define a model.
import torch

# A 1D tensor of 6 elements: shape (6,).
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

# sizes=(6,) or (-1,) unflattens the only dimension into itself, so the
# shape stays (6,). Here dim=-1 is the same dimension as dim=0, and
# my_tensor.unflatten(...) is the method form of torch.unflatten(input=my_tensor, ...).
torch.unflatten(input=my_tensor, dim=0, sizes=(6,))
my_tensor.unflatten(dim=0, sizes=(6,))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1,))
# tensor([7, 1, -8, 3, -6, 0])

# Shape (1, 6): a -1 entry is inferred so that the product of sizes is 6.
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 6))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 6))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 6))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 6))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, -1))
# tensor([[7, 1, -8, 3, -6, 0]])

# Shape (2, 3).
torch.unflatten(input=my_tensor, dim=0, sizes=(2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(2, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(2, -1))
# tensor([[7, 1, -8], [3, -6, 0]])

# Shape (3, 2).
torch.unflatten(input=my_tensor, dim=0, sizes=(3, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(3, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, 2))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, -1))
# tensor([[7, 1], [-8, 3], [-6, 0]])

# Shape (6, 1).
torch.unflatten(input=my_tensor, dim=0, sizes=(6, 1))
torch.unflatten(input=my_tensor, dim=0, sizes=(6, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6, 1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6, -1))
# tensor([[7], [1], [-8], [3], [-6], [0]])

# Shape (1, 2, 3): two dimensions are added at once.
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 2, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 2, -1))
# tensor([[[7, 1, -8], [3, -6, 0]]])

# etc. -- any other sizes whose product is 6 (with at most one -1) work the
# same way, e.g. sizes=(2, 1, 3) or (1, 6, 1).

# A 2D tensor of shape (2, 3).
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

# Each loop below runs a group of equivalent (dim, sizes) calls in the same
# order as writing them out one per line; the comment after each loop is the
# tensor every call in that group returns. dim=-2 names the same axis as
# dim=0, dim=-1 the same axis as dim=1, and a -1 inside sizes is inferred.

# Shape (2, 3): a one-element sizes leaves the axis as it is.
for dim, sizes in ((0, (2,)), (0, (-1,)),
                   (1, (3,)), (1, (-1,)),
                   (-1, (3,)), (-1, (-1,)),
                   (-2, (2,)), (-2, (-1,))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[7, 1, -8], [3, -6, 0]])

# Shape (1, 2, 3): axis 0 (size 2) becomes (1, 2).
for dim, sizes in ((0, (1, 2)), (0, (-1, 2)),
                   (-2, (1, 2)), (-2, (-1, 2))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[[7, 1, -8], [3, -6, 0]]])

# Shape (2, 1, 3): axis 0 becomes (2, 1), or equivalently axis 1 (size 3)
# becomes (1, 3).
for dim, sizes in ((0, (2, 1)), (0, (2, -1)),
                   (1, (1, 3)), (1, (-1, 3)),
                   (-1, (1, 3)), (-1, (-1, 3)),
                   (-2, (2, 1)), (-2, (2, -1))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[[7, 1, -8]], [[3, -6, 0]]])

# Shape (2, 3, 1): axis 1 (size 3) becomes (3, 1).
for dim, sizes in ((1, (3, 1)), (1, (3, -1)),
                   (-1, (3, 1)), (-1, (3, -1))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

# Shape (1, 1, 2, 3): axis 0 (size 2) becomes (1, 1, 2).
for dim, sizes in ((0, (1, 1, 2)), (0, (-1, 1, 2)),
                   (0, (1, -1, 2)), (0, (1, 1, -1)),
                   (-2, (1, 1, 2)), (-2, (-1, 1, 2)),
                   (-2, (1, -1, 2)), (-2, (1, 1, -1))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[[[7, 1, -8], [3, -6, 0]]]])

# Shape (2, 1, 1, 3): axis 1 (size 3) becomes (1, 1, 3).
for dim, sizes in ((1, (1, 1, 3)), (1, (-1, 1, 3)),
                   (1, (1, -1, 3)), (1, (1, 1, -1)),
                   (-1, (1, 1, 3)), (-1, (-1, 1, 3)),
                   (-1, (1, -1, 3)), (-1, (1, 1, -1))):
    torch.unflatten(input=my_tensor, dim=dim, sizes=sizes)
# tensor([[[[7, 1, -8]]], [[[3, -6, 0]]]])

# sizes=(2,) keeps dim 0 at size 2, so the (2, 3) shape is unchanged; the
# same call works for float, complex and bool tensors alike.
for my_tensor in (
    # float -> tensor([[7., 1., -8.], [3., -6., 0.]])
    torch.tensor([[7., 1., -8.], [3., -6., 0.]]),
    # complex -> tensor([[7.+0.j, 1.+0.j, -8.+0.j],
    #                    [3.+0.j, -6.+0.j, 0.+0.j]])
    torch.tensor([[7.+0.j, 1.+0.j, -8.+0.j],
                  [3.+0.j, -6.+0.j, 0.+0.j]]),
    # bool -> tensor([[True, False, True], [False, True, False]])
    torch.tensor([[True, False, True], [False, True, False]]),
):
    torch.unflatten(input=my_tensor, dim=0, sizes=(2,))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)