*Memos:
- My post explains flatten() and ravel().
- My post explains Unflatten().
- My post explains unflatten().
Flatten() can remove zero or more dimensions by merging the dimensions from start_dim to end_dim of a 0D or more D tensor of zero or more elements into a single dimension, returning a 1D or more D tensor of zero or more elements, as shown in the examples below.
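As a rough illustration of the shape rule, the dimensions from `start_dim` to `end_dim` are merged into one dimension whose size is their product. This is a minimal sketch: the tensor shape is arbitrary and the `expected_shape` helper is made up for this example and is not part of PyTorch.

```python
import math

import torch
from torch import nn


def expected_shape(shape, start_dim=1, end_dim=-1):
    """Hypothetical helper: predict the shape Flatten() returns for a 1D or more D tensor."""
    ndim = len(shape)
    start = start_dim % ndim  # wrap negative indices
    end = end_dim % ndim
    merged = math.prod(shape[start:end + 1])  # size of the merged dimension
    return tuple(shape[:start]) + (merged,) + tuple(shape[end + 1:])


my_tensor = torch.zeros(2, 3, 4, 5)

flat_shape = tuple(nn.Flatten(start_dim=1, end_dim=2)(my_tensor).shape)
# (2, 12, 5)

predicted_shape = expected_shape(my_tensor.shape, start_dim=1, end_dim=2)
# (2, 12, 5)
```

Dimensions 1 and 2 (sizes 3 and 4) are merged into one dimension of size 12, while dimensions 0 and 3 are left alone.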
*Memos:
- The 1st argument for initialization is `start_dim` (Optional, Default: `1`, Type: `int`).
- The 2nd argument for initialization is `end_dim` (Optional, Default: `-1`, Type: `int`).
- The 1st argument is `input` (Required, Type: `tensor` of `int`, `float`, `complex` or `bool`).
- `Flatten()` can change a 0D tensor to a 1D tensor (with `start_dim` set to `0` or `-1`, as in the examples below).
- `Flatten()` does nothing for a 1D tensor (again with `start_dim` set to `0` or `-1`).
- The difference between Flatten() and flatten() is:
  - The default value of `start_dim` for `Flatten()` is `1` while the default value of `start_dim` for `flatten()` is `0`.
  - Basically, `Flatten()` is a module that is used as a layer to define a model, while `flatten()` is a function that is applied directly to a tensor and is not used to define a model.
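To make the difference in defaults concrete, here is a small sketch. The batch-like shape is an arbitrary choice for illustration. With the defaults, `nn.Flatten()` keeps dimension 0, while `torch.flatten()` merges every dimension:

```python
import torch
from torch import nn

my_tensor = torch.zeros(2, 3, 4)  # e.g. a batch of 2 samples

module_out = nn.Flatten()(my_tensor)     # start_dim=1, end_dim=-1
function_out = torch.flatten(my_tensor)  # start_dim=0, end_dim=-1

module_out.shape
# torch.Size([2, 12])

function_out.shape
# torch.Size([24])

# Passing start_dim=0 to the module matches the function's default.
same = torch.equal(nn.Flatten(start_dim=0)(my_tensor), function_out)
# True
```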
```python
import torch
from torch import nn

flatten = nn.Flatten()

flatten
# Flatten(start_dim=1, end_dim=-1)

flatten.start_dim
# 1

flatten.end_dim
# -1

# In each group below, the consecutive `flatten = ...` lines are alternatives
# that give the same result; only the last assignment is used by the call.

# 0D tensor -> 1D tensor
my_tensor = torch.tensor(7)

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7])

# 1D tensor: nothing changes
my_tensor = torch.tensor([7, 1, -8, 3, -6, 0])

flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=0)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

# 2D tensor: merge both dimensions
my_tensor = torch.tensor([[7, 1, -8], [3, -6, 0]])

flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

# 2D tensor: start_dim and end_dim point to the same dimension, so nothing changes
flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=1)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=0)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

# 3D tensor: merge all three dimensions
my_tensor = torch.tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

flatten = nn.Flatten(start_dim=0, end_dim=2)
flatten = nn.Flatten(start_dim=0, end_dim=-1)
flatten = nn.Flatten(start_dim=-3, end_dim=2)
flatten = nn.Flatten(start_dim=-3, end_dim=-1)
flatten(input=my_tensor)
# tensor([7, 1, -8, 3, -6, 0])

# 3D tensor: start_dim and end_dim point to the same dimension, so nothing changes
flatten = nn.Flatten(start_dim=0, end_dim=0)
flatten = nn.Flatten(start_dim=0, end_dim=-3)
flatten = nn.Flatten(start_dim=1, end_dim=1)
flatten = nn.Flatten(start_dim=1, end_dim=-2)
flatten = nn.Flatten(start_dim=2, end_dim=2)
flatten = nn.Flatten(start_dim=2, end_dim=-1)
flatten = nn.Flatten(start_dim=-1, end_dim=2)
flatten = nn.Flatten(start_dim=-1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=1)
flatten = nn.Flatten(start_dim=-2, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=0)
flatten = nn.Flatten(start_dim=-3, end_dim=-3)
flatten(input=my_tensor)
# tensor([[[7], [1], [-8]], [[3], [-6], [0]]])

# 3D tensor: merge the first two dimensions
flatten = nn.Flatten(start_dim=0, end_dim=1)
flatten = nn.Flatten(start_dim=0, end_dim=-2)
flatten = nn.Flatten(start_dim=-3, end_dim=1)
flatten = nn.Flatten(start_dim=-3, end_dim=-2)
flatten(input=my_tensor)
# tensor([[7], [1], [-8], [3], [-6], [0]])

# 3D tensor: merge the last two dimensions (the default behavior)
flatten = nn.Flatten()
flatten = nn.Flatten(start_dim=1, end_dim=2)
flatten = nn.Flatten(start_dim=1, end_dim=-1)
flatten = nn.Flatten(start_dim=-2, end_dim=2)
flatten = nn.Flatten(start_dim=-2, end_dim=-1)
flatten(input=my_tensor)
# tensor([[7, 1, -8], [3, -6, 0]])

# float tensor
my_tensor = torch.tensor([[[7.], [1.], [-8.]], [[3.], [-6.], [0.]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7., 1., -8.], [3., -6., 0.]])

# complex tensor
my_tensor = torch.tensor([[[7.+0.j], [1.+0.j], [-8.+0.j]],
                          [[3.+0.j], [-6.+0.j], [0.+0.j]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[7.+0.j, 1.+0.j, -8.+0.j],
#         [3.+0.j, -6.+0.j, 0.+0.j]])

# bool tensor
my_tensor = torch.tensor([[[True], [False], [True]],
                          [[False], [True], [False]]])

flatten = nn.Flatten()
flatten(input=my_tensor)
# tensor([[True, False, True],
#         [False, True, False]])
```
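Since `Flatten()` is a module, it can sit between layers when defining a model. The sketch below assumes a made-up batch of four 1x28x28 inputs and an arbitrary output size of 10. It is only meant to show where `Flatten()` typically goes, not a recommended architecture:

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Flatten(),                # (N, 1, 28, 28) -> (N, 784), dimension 0 is kept
    nn.Linear(1 * 28 * 28, 10),  # hypothetical layer sizes for illustration
)

batch = torch.zeros(4, 1, 28, 28)
output = model(batch)

output.shape
# torch.Size([4, 10])
```

The default `start_dim=1` is what keeps the batch dimension intact here, so each sample is flattened on its own.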