DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

expand() in PyTorch

Buy Me a Coffee

*Memos:

expand() returns a 0D or more D view tensor of zero or more expanded elements from a 0D or more D tensor of zero or more elements, without copying data, as shown below:

*Memos:

  • expand() can be called on a tensor (my_tensor.expand()) but not from torch (there is no torch.expand()).
  • The tensor it is called on (Required-Type: tensor of int, float, complex or bool).
  • The 1st or more arguments of expand() are size (Required-Type: int, tuple of int, list of int or torch.Size, e.g. from size()): *Memos:
    • Its number of dimensions must be greater than or equal to the tensor's number of dimensions. Any extra dimensions are added at the front.
    • If at least one dimension is 0, an empty tensor is returned.
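    • The size= keyword can't be used when the dimensions are passed as separate ints (e.g. expand(2, 3)). Use size= only with a tuple, list or torch.Size (e.g. expand(size=(2, 3))).
    • You can set -1 to keep a dimension's size unchanged. *-1 can be used only for existing dimensions, not for new ones.
    • Only an existing dimension of size 1 can be expanded to a larger size. Any other existing dimension must keep its size (or be given as -1).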
import torch

# expand() is a Tensor method (there is no torch.expand()). It returns a view:
# new leading dimensions and existing size-1 dimensions are repeated with a
# stride of 0, so no element data is copied.
#
# The expressions below are REPL-style: each group's result is shown in the
# comment that follows it. Wrap a call in print() to see the result when
# running this file as a script.
#
# Each group shows equivalent calls: size as a tuple via size=, as separate
# positional ints, with -1 to keep an existing dimension unchanged, or as a
# torch.Size taken from another tensor's size().

# 1D tensor of shape (3,).
my_tensor = torch.tensor([7, 4, 2])

# A new leading dimension of size 0 gives an empty tensor.
my_tensor.expand(size=(0, 3))
my_tensor.expand(0, 3)
my_tensor.expand(size=(0, -1))
my_tensor.expand(0, -1)
# tensor([], size=(0, 3), dtype=torch.int64)

# Same number of dimensions and same size: the view equals the original.
my_tensor.expand(size=(3,))
my_tensor.expand(3)
my_tensor.expand(size=(-1,))
my_tensor.expand(-1)
my_tensor.expand(size=torch.tensor([5, 8, 1]).size())
# tensor([7, 4, 2])

# One new leading dimension; the existing dimension (3) is kept or given as -1.
my_tensor.expand(size=(1, 3))
my_tensor.expand(1, 3)
my_tensor.expand(size=(1, -1))
my_tensor.expand(1, -1)
my_tensor.expand(size=torch.tensor([[5, 8, 1]]).size())
# tensor([[7, 4, 2]])

my_tensor.expand(size=(2, 3))
my_tensor.expand(2, 3)
my_tensor.expand(size=(2, -1))
my_tensor.expand(2, -1)
my_tensor.expand(size=torch.tensor([[5, 8, 1], [9, 3, 0]]).size())
# tensor([[7, 4, 2], [7, 4, 2]])

my_tensor.expand(size=(3, 3))
my_tensor.expand(3, 3)
my_tensor.expand(size=(3, -1))
my_tensor.expand(3, -1)
# tensor([[7, 4, 2], [7, 4, 2], [7, 4, 2]])

my_tensor.expand(size=(4, 3))
my_tensor.expand(4, 3)
my_tensor.expand(size=(4, -1))
my_tensor.expand(4, -1)
# tensor([[7, 4, 2], [7, 4, 2], [7, 4, 2], [7, 4, 2]])
# etc. -- any number of leading rows works the same way.

# Two new leading dimensions.
my_tensor.expand(size=(1, 2, 3))
my_tensor.expand(1, 2, 3)
my_tensor.expand(size=(1, 2, -1))
my_tensor.expand(1, 2, -1)
# tensor([[[7, 4, 2], [7, 4, 2]]])

# A 0 in any new dimension again gives an empty tensor.
my_tensor.expand(size=(1, 0, 3))
my_tensor.expand(1, 0, 3)
my_tensor.expand(size=(1, 0, -1))
my_tensor.expand(1, 0, -1)
# tensor([], size=(1, 0, 3), dtype=torch.int64)

# 2D tensor of shape (3, 1): its size-1 second dimension can be expanded.
my_tensor = torch.tensor([[7], [4], [2]])

my_tensor.expand(size=(3, 1))
my_tensor.expand(3, 1)
my_tensor.expand(size=(3, -1))
my_tensor.expand(3, -1)
# tensor([[7], [4], [2]])

my_tensor.expand(size=(3, 4))
my_tensor.expand(3, 4)
# tensor([[7, 7, 7, 7],
#         [4, 4, 4, 4],
#         [2, 2, 2, 2]])

# 2D tensor of shape (2, 3) expanded with one new leading dimension.
my_tensor = torch.tensor([[7, 4, 2], [5, 1, 6]])

my_tensor.expand(size=(4, 2, 3))
my_tensor.expand(4, 2, 3)
my_tensor.expand(size=(4, 2, -1))
my_tensor.expand(4, 2, -1)
my_tensor.expand(size=(4, -1, 3))
my_tensor.expand(4, -1, 3)
my_tensor.expand(size=(4, -1, -1))
my_tensor.expand(4, -1, -1)
# tensor([[[7, 4, 2], [5, 1, 6]],
#         [[7, 4, 2], [5, 1, 6]],
#         [[7, 4, 2], [5, 1, 6]],
#         [[7, 4, 2], [5, 1, 6]]])

# The same works for float, complex and bool tensors.
my_tensor = torch.tensor([[7., 4., 2.], [5., 1., 6.]])

my_tensor.expand(size=(4, 2, 3))
# tensor([[[7., 4., 2.], [5., 1., 6.]],
#         [[7., 4., 2.], [5., 1., 6.]],
#         [[7., 4., 2.], [5., 1., 6.]],
#         [[7., 4., 2.], [5., 1., 6.]]])

my_tensor = torch.tensor([[7.+0.j, 4.+0.j, 2.+0.j],
                          [5.+0.j, 1.+0.j, 6.+0.j]])
my_tensor.expand(size=(4, 2, 3))
# tensor([[[7.+0.j, 4.+0.j, 2.+0.j],
#          [5.+0.j, 1.+0.j, 6.+0.j]],
#         [[7.+0.j, 4.+0.j, 2.+0.j],
#          [5.+0.j, 1.+0.j, 6.+0.j]],
#         [[7.+0.j, 4.+0.j, 2.+0.j],
#          [5.+0.j, 1.+0.j, 6.+0.j]],
#         [[7.+0.j, 4.+0.j, 2.+0.j],
#          [5.+0.j, 1.+0.j, 6.+0.j]]])

my_tensor = torch.tensor([[True, False, True], [False, True, False]])

my_tensor.expand(size=(4, 2, 3))
# tensor([[[True, False, True], [False, True, False]],
#         [[True, False, True], [False, True, False]],
#         [[True, False, True], [False, True, False]],
#         [[True, False, True], [False, True, False]]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)