DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

unflatten() in PyTorch

*My post explains flatten() and ravel().

unflatten() can add zero or more dimensions to a 1D or higher-dimensional tensor by splitting one of its dimensions into several, as shown below:

*Memos:

  • unflatten() can be used with torch or a tensor.
  • The 1st argument with torch, or the tensor itself when unflatten() is called on a tensor, is input(Required-Type:tensor of int, float, complex or bool).
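  • The 2nd argument(int) with torch and the 1st argument(int) with a tensor is dim(Required), which is the dimension to unflatten. *A negative dim counts from the last dimension, e.g. -1 is the last dimension.
  • The 3rd argument(tuple of int or list of int) with torch and the 2nd argument(tuple of int or list of int) with a tensor is sizes(Required), which is the new shape of the unflattened dimension. *At most one element of sizes can be -1, in which case that size is inferred; the product of sizes must equal the size of the dimension dim.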
import torch

# 1D input: its only dimension can be addressed as dim=0 or as dim=-1.
# Within each group below, every call is expected to give the result shown
# in the comment that closes the group.
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0]) # 1D tensor
                                               # Size:[6]
# A single entry in sizes adds zero dimensions; a -1 entry is inferred.
torch.unflatten(input=my_tensor, dim=0, sizes=(6,))
# Same operation, called as a tensor method instead of through torch.
my_tensor.unflatten(dim=0, sizes=(6,))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1,))
# tensor([7, 1, -8, 3, -6, 0])
# Size:[6]

# Split the 6 elements as 1 x 6: a leading dimension of size 1 is added.
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 6))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 6))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 6))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 6))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, -1))
# tensor([[7, 1, -8, 3, -6, 0]])
# Size:[1, 6]

# Split the 6 elements as 2 x 3.
torch.unflatten(input=my_tensor, dim=0, sizes=(2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(2, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(2, -1))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

# Split the 6 elements as 3 x 2.
torch.unflatten(input=my_tensor, dim=0, sizes=(3, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(3, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, 2))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, -1))
# tensor([[7, 1], [-8, 3], [-6, 0]])
# Size:[3, 2]

# Split the 6 elements as 6 x 1: a trailing dimension of size 1 is added.
torch.unflatten(input=my_tensor, dim=0, sizes=(6, 1))
torch.unflatten(input=my_tensor, dim=0, sizes=(6, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6, 1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(6, -1))
# tensor([[7], [1], [-8], [3], [-6], [0]])
# Size:[6, 1]

# Three entries in sizes split the one dimension into three (1 x 2 x 3).
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 2, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 2, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 2, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 2, -1))
# tensor([[[7, 1, -8], [3, -6, 0]]])
# Size:[1, 2, 3]
# etc.

# 2D input: dim=0 (or -2) addresses the first dimension of size 2, and
# dim=1 (or -1) addresses the second dimension of size 3. Only the chosen
# dimension is replaced by sizes; the other dimension is kept as it is.
# Within each group below, every call is expected to give the result shown
# in the comment that closes the group.
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]]) # 2D tensor
                                                   # Size:[2, 3]
# First dimension, single entry in sizes: zero dimensions are added.
torch.unflatten(input=my_tensor, dim=0, sizes=(2,))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1,))
torch.unflatten(input=my_tensor, dim=-2, sizes=(2,))
torch.unflatten(input=my_tensor, dim=-2, sizes=(-1,))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

# First dimension split as 1 x 2: a size-1 dimension is added in front.
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 2))
torch.unflatten(input=my_tensor, dim=-2, sizes=(1, 2))
torch.unflatten(input=my_tensor, dim=-2, sizes=(-1, 2))
# tensor([[[7, 1, -8], [3, -6, 0]]])
# Size:[1, 2, 3]

# First dimension split as 2 x 1: a size-1 dimension is added in the middle.
torch.unflatten(input=my_tensor, dim=0, sizes=(2, 1))
torch.unflatten(input=my_tensor, dim=0, sizes=(2, -1))
torch.unflatten(input=my_tensor, dim=-2, sizes=(2, 1))
torch.unflatten(input=my_tensor, dim=-2, sizes=(2, -1))
# tensor([[[7, 1, -8]], [[3, -6, 0]]])
# Size:[2, 1, 3]

# First dimension split as 1 x 1 x 2: two size-1 dimensions are added in front.
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 1, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(-1, 1, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, -1, 2))
torch.unflatten(input=my_tensor, dim=0, sizes=(1, 1, -1))
torch.unflatten(input=my_tensor, dim=-2, sizes=(1, 1, 2))
torch.unflatten(input=my_tensor, dim=-2, sizes=(-1, 1, 2))
torch.unflatten(input=my_tensor, dim=-2, sizes=(1, -1, 2))
torch.unflatten(input=my_tensor, dim=-2, sizes=(1, 1, -1))
# tensor([[[[7, 1, -8], [3, -6, 0]]]])
# Size:[1, 1, 2, 3]

# Second dimension, single entry in sizes: zero dimensions are added.
torch.unflatten(input=my_tensor, dim=1, sizes=(3,))
torch.unflatten(input=my_tensor, dim=1, sizes=(-1,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3,))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1,))
# tensor([[7, 1, -8], [3, -6, 0]])
# Size:[2, 3]

# Second dimension split as 3 x 1: a size-1 dimension is added at the end.
torch.unflatten(input=my_tensor, dim=1, sizes=(3, 1))
torch.unflatten(input=my_tensor, dim=1, sizes=(3, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, 1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(3, -1))
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])
# Size:[2, 3, 1]

# Second dimension split as 1 x 3: a size-1 dimension is added in the middle.
torch.unflatten(input=my_tensor, dim=1, sizes=(1, 3))
torch.unflatten(input=my_tensor, dim=1, sizes=(-1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 3))
# tensor([[[7, 1, -8]], [[3, -6, 0]]])
# Size:[2, 1, 3]

# Second dimension split as 1 x 1 x 3: two size-1 dimensions are added in the middle.
torch.unflatten(input=my_tensor, dim=1, sizes=(1, 1, 3))
torch.unflatten(input=my_tensor, dim=1, sizes=(-1, 1, 3))
torch.unflatten(input=my_tensor, dim=1, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=1, sizes=(1, 1, -1))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(-1, 1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, -1, 3))
torch.unflatten(input=my_tensor, dim=-1, sizes=(1, 1, -1))
# tensor([[[[7, 1, -8]]], [[[3, -6, 0]]]])
# Size:[2, 1, 1, 3]

# The same call repeated with float, complex and bool tensors: with
# dim=0 and sizes=(2,) no dimension is added, so each expected result
# below repeats the input's values.

# Float input.
my_tensor = torch.tensor([[7., 1., -8.], [3., -6., 0.]]) # 2D tensor
                                                         # Size:[2, 3]
torch.unflatten(input=my_tensor, dim=0, sizes=(2,))
# tensor([[7., 1., -8.], [3., -6., 0.]])
# Size:[2, 3]

# Complex input.
my_tensor = torch.tensor([[7.+0.j, 1.+0.j, -8.+0.j],  # 2D tensor
                          [3.+0.j, -6.+0.j, 0.+0.j]]) # Size:[2, 3]
torch.unflatten(input=my_tensor, dim=0, sizes=(2,))
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])
# Size:[2, 3]

# Bool input.
my_tensor = torch.tensor([[True, False, True],   # 2D tensor
                          [False, True, False]]) # Size:[2, 3]
torch.unflatten(input=my_tensor, dim=0, sizes=(2,))
# tensor([[True, False, True],
#         [False, True, False]])
# Size:[2, 3]
# etc.
Enter fullscreen mode Exit fullscreen mode

Top comments (0)