DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

reshape() and view() in PyTorch

Buy Me a Coffee

*Memos:

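reshape() or view() can get the 0D or more D reshaped tensor of zero or more elements from the 0D or more D tensor of zero or more elements without losing data, as shown below. *The new shape must have the same total number of elements as the original tensor (e.g. 12 = 2 x 6 = 3 x 4 = 2 x 2 x 3).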

*Memos:

  • reshape() can be used with torch or a tensor while view() can be used with a tensor but not with torch.
  • For reshape(), the 1st argument(input) with torch, or the tensor itself when reshape() is called on a tensor, is the input tensor(Required-Type:tensor of int, float, complex or bool).
  • For reshape(), the 2nd argument with torch(Type:tuple of int, list of int or size()) or the 1st or more arguments with a tensor(Type:int, tuple of int, list of int or size()) are shape(Required).
  • For view(), the tensor view() is called on is the input tensor(Required-Type:tensor of int, float, complex or bool).
  • For view(), the 1st or more arguments with a tensor are *shape(Required-Type:int, tuple of int, list of int or size()). *Don't use any keyword like *shape= or shape=.
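  • For view(), the 1st argument with a tensor can instead be dtype(Type:dtype), which is used in place of shape: *Memos:
    • If dtype is given, the tensor's data is reinterpreted (not converted) as that dtype; if the element sizes differ, the size of the last dimension changes proportionally.
    • If dtype is not given, the returned tensor keeps the dtype of the original tensor.
    • My post explains dtype argument.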
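  • For reshape() and view(), setting -1 for a dimension makes PyTorch infer that dimension's size automatically from the total number of elements, so you don't need to set 12, 2, 3, 4 or 6 as the 1st number below. *-1 can be used in any position (e.g. reshape(3, -1)), but only for one dimension at a time.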
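  • reshape() returns a view when possible but creates a copy of a tensor when it must (e.g. for a non-contiguous tensor such as a transpose), taking more memory. view() never creates a copy of a tensor; instead, it raises an error if the tensor's strides are incompatible with the new shape. So view() can be lighter and faster than reshape(), while reshape() works in more cases.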
import torch

# Demo: one 4 x 3 tensor of 12 elements reshaped into several shapes with
# torch.reshape(), Tensor.reshape() and Tensor.view(). Every target shape
# holds exactly 12 elements; -1 lets PyTorch infer one dimension from the
# others. Each group of expressions evaluates (e.g. in a REPL or notebook)
# to the tensor shown in the comment beneath it.
my_tensor = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])

torch.reshape(input=my_tensor, shape=(12,))
torch.reshape(input=my_tensor, shape=(-1,))
my_tensor.reshape(shape=(12,))
my_tensor.reshape(shape=(-1,))
my_tensor.reshape(12)
my_tensor.reshape(-1)
my_tensor.view((12,))
my_tensor.view((-1,))
my_tensor.view(12)
my_tensor.view(-1)
# tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

torch.reshape(input=my_tensor, shape=(1, 12))
torch.reshape(input=my_tensor, shape=(-1, 12))
my_tensor.reshape(shape=(1, 12))
my_tensor.reshape(shape=(-1, 12))
my_tensor.reshape(1, 12)
my_tensor.reshape(-1, 12)
my_tensor.view((1, 12))
my_tensor.view((-1, 12))
my_tensor.view(1, 12)
my_tensor.view(-1, 12)
# tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])

torch.reshape(input=my_tensor, shape=(2, 6))
torch.reshape(input=my_tensor, shape=(-1, 6))
my_tensor.reshape(shape=(2, 6))
my_tensor.reshape(shape=(-1, 6))
my_tensor.reshape(2, 6)
my_tensor.reshape(-1, 6)
my_tensor.view((2, 6))
my_tensor.view((-1, 6))
my_tensor.view(2, 6)
my_tensor.view(-1, 6)
# tensor([[0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]])

torch.reshape(input=my_tensor, shape=(3, 4))
torch.reshape(input=my_tensor, shape=(-1, 4))
my_tensor.reshape(shape=(3, 4))
my_tensor.reshape(shape=(-1, 4))
my_tensor.reshape(3, 4)
my_tensor.reshape(-1, 4)
my_tensor.reshape(3, -1)  # -1 may stand for any one dimension, not just the 1st
my_tensor.view((3, 4))
my_tensor.view((-1, 4))
my_tensor.view(3, 4)
my_tensor.view(-1, 4)
my_tensor.view(3, -1)
# tensor([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]])

torch.reshape(input=my_tensor, shape=(4, 3))
torch.reshape(input=my_tensor, shape=(-1, 3))
torch.reshape(input=my_tensor, shape=my_tensor.size())
my_tensor.reshape(shape=(4, 3))
my_tensor.reshape(shape=(-1, 3))
my_tensor.reshape(4, 3)
my_tensor.reshape(-1, 3)
my_tensor.reshape(shape=my_tensor.size())
my_tensor.view((4, 3))
my_tensor.view((-1, 3))
my_tensor.view(4, 3)
my_tensor.view(-1, 3)
my_tensor.view(my_tensor.size())
# tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]])

torch.reshape(input=my_tensor, shape=(6, 2))
torch.reshape(input=my_tensor, shape=(-1, 2))
my_tensor.reshape(shape=(6, 2))
my_tensor.reshape(shape=(-1, 2))
my_tensor.reshape(6, 2)
my_tensor.reshape(-1, 2)
my_tensor.view((6, 2))
my_tensor.view((-1, 2))
my_tensor.view(6, 2)
my_tensor.view(-1, 2)
# tensor([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11]])

torch.reshape(input=my_tensor, shape=(12, 1))
torch.reshape(input=my_tensor, shape=(-1, 1))
my_tensor.reshape(shape=(12, 1))
my_tensor.reshape(shape=(-1, 1))
my_tensor.reshape(12, 1)
my_tensor.reshape(-1, 1)
my_tensor.view((12, 1))
my_tensor.view((-1, 1))
my_tensor.view(12, 1)
my_tensor.view(-1, 1)
# tensor([[0], [1], [2], [3], [4], [5], [6], [7], [8], [9], [10], [11]])

torch.reshape(input=my_tensor, shape=(2, 2, 3))
torch.reshape(input=my_tensor, shape=(-1, 2, 3))
my_tensor.reshape(shape=(2, 2, 3))
my_tensor.reshape(shape=(-1, 2, 3))
my_tensor.reshape(2, 2, 3)
my_tensor.reshape(-1, 2, 3)
my_tensor.view((2, 2, 3))
my_tensor.view((-1, 2, 3))
my_tensor.view(2, 2, 3)
my_tensor.view(-1, 2, 3)
# tensor([[[0, 1, 2], [3, 4, 5]],
#         [[6, 7, 8], [9, 10, 11]]])
# etc. -- any other 3D shape with 12 elements, e.g. (3, 2, 2) or (2, 3, 2),
# works the same way.

torch.reshape(input=my_tensor, shape=(1, 2, 2, 3))
torch.reshape(input=my_tensor, shape=(-1, 2, 2, 3))
my_tensor.reshape(shape=(1, 2, 2, 3))
my_tensor.reshape(shape=(-1, 2, 2, 3))
my_tensor.reshape(1, 2, 2, 3)
my_tensor.reshape(-1, 2, 2, 3)
my_tensor.view((1, 2, 2, 3))
my_tensor.view((-1, 2, 2, 3))
my_tensor.view(1, 2, 2, 3)
my_tensor.view(-1, 2, 2, 3)
# tensor([[[[0, 1, 2], [3, 4, 5]],
#          [[6, 7, 8], [9, 10, 11]]]])
# etc. -- any other 4D shape with 12 elements, e.g. (2, 1, 2, 3) or
# (1, 1, 3, 4), works the same way.

# reshape() vs. view() on a non-contiguous tensor: the transpose .T shares
# memory with my_tensor but has incompatible strides, so reshape() has to
# return a copy, while view() raises instead of copying.
my_tensor.T.reshape(-1)
# tensor([0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11])
# my_tensor.T.view(-1)
# RuntimeError: view size is not compatible with input tensor's size and stride ...

# The same calls work for float tensors.
my_tensor = torch.tensor([[0., 1., 2.], [3., 4., 5.],
                          [6., 7., 8.], [9., 10., 11.]])
torch.reshape(input=my_tensor, shape=(12,))
torch.reshape(input=my_tensor, shape=(-1,))
my_tensor.view((12,))
my_tensor.view((-1,))
# tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])

# ... for complex tensors ...
my_tensor = torch.tensor([[0.+0.j, 1.+0.j, 2.+0.j],
                          [3.+0.j, 4.+0.j, 5.+0.j],
                          [6.+0.j, 7.+0.j, 8.+0.j],
                          [9.+0.j, 10.+0.j, 11.+0.j]])
torch.reshape(input=my_tensor, shape=(12,))
torch.reshape(input=my_tensor, shape=(-1,))
my_tensor.view((12,))
my_tensor.view((-1,))
# tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j, 4.+0.j, 5.+0.j,
#         6.+0.j, 7.+0.j, 8.+0.j, 9.+0.j, 10.+0.j, 11.+0.j])

# ... and for bool tensors.
my_tensor = torch.tensor([[True, False, True], [True, False, True],
                          [False, True, False], [False, True, False]])
torch.reshape(input=my_tensor, shape=(12,))
torch.reshape(input=my_tensor, shape=(-1,))
my_tensor.view((12,))
my_tensor.view((-1,))
# tensor([True, False, True, True, False, True,
#         False, True, False, False, True, False])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)