DEV Community

Super Kai (Kazuya Ito)

stack() and cat() in PyTorch

*My other post explains hstack(), vstack(), dstack() and column_stack().

stack() concatenates a sequence of tensors along a new dimension, as shown below:

import torch

tensor1 = torch.tensor(2) # The size is [].
tensor2 = torch.tensor(7) # The size is [].
tensor3 = torch.tensor(4) # The size is [].
torch.stack((tensor1, tensor2, tensor3))
# tensor([2, 7, 4])
# The size is [3].

tensor1 = torch.tensor([2, 7, 4]) # The size is [3].
tensor2 = torch.tensor([8, 3, 2]) # The size is [3].
tensor3 = torch.tensor([5, 0, 8]) # The size is [3].
torch.stack((tensor1, tensor2, tensor3))
# tensor([[2, 7, 4], [8, 3, 2], [5, 0, 8]])
# The size is [3, 3].

tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]]) # The size is [2, 3].
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]]) # The size is [2, 3].
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]]) # The size is [2, 3].
torch.stack((tensor1, tensor2, tensor3))
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]]])
# The size is [3, 2, 3].

tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
                       # The size is [2, 2, 3].
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
                       # The size is [2, 2, 3].
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])
                       # The size is [2, 2, 3].
torch.stack((tensor1, tensor2, tensor3))
torch.stack((tensor1, tensor2, tensor3), 0)
# tensor([[[[2, 7, 4], [8, 3, 2]],
#          [[5, 0, 8], [3, 6, 1]]],
#         [[[9, 4, 7], [1, 0, 5]],
#          [[6, 7, 4], [2, 1, 9]]],
#         [[[1, 6, 3], [9, 6, 0]],
#          [[0, 8, 7], [3, 5, 2]]]])
# The size is [3, 2, 2, 3].

torch.stack((tensor1, tensor2, tensor3), 1)
torch.stack((tensor1, tensor2, tensor3), -3)
# tensor([[[[2, 7, 4], [8, 3, 2]],
#          [[9, 4, 7], [1, 0, 5]],
#          [[1, 6, 3], [9, 6, 0]]],
#         [[[5, 0, 8], [3, 6, 1]],
#          [[6, 7, 4], [2, 1, 9]],
#          [[0, 8, 7], [3, 5, 2]]]])
# The size is [2, 3, 2, 3].

torch.stack((tensor1, tensor2, tensor3), 2)
torch.stack((tensor1, tensor2, tensor3), -2)
# tensor([[[[2, 7, 4], [9, 4, 7], [1, 6, 3]],
#          [[8, 3, 2], [1, 0, 5], [9, 6, 0]]],
#         [[[5, 0, 8], [6, 7, 4], [0, 8, 7]],
#          [[3, 6, 1], [2, 1, 9], [3, 5, 2]]]])
# The size is [2, 2, 3, 3].

torch.stack((tensor1, tensor2, tensor3), 3)
torch.stack((tensor1, tensor2, tensor3), -1)
# tensor([[[[2, 9, 1], [7, 4, 6], [4, 7, 3]],
#          [[8, 1, 9], [3, 0, 6], [2, 5, 0]]],
#         [[[5, 6, 0], [0, 7, 8], [8, 4, 7]],
#          [[3, 2, 3], [6, 1, 5], [1, 9, 2]]]])
# The size is [2, 2, 3, 3].
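
One way to see what the dim argument does is to index back into the result. This minimal sketch reuses the three [2, 2, 3] tensors from the example above. For every valid dim from -4 to 3, it checks that the new dimension has size 3 (one slot per input) and that selecting index k along it returns the k-th input tensor. torch.unbind() is used here as the inverse of stack().

```python
import torch

# The three [2, 2, 3] tensors from the example above.
tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]], [[5, 0, 8], [3, 6, 1]]])
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]], [[6, 7, 4], [2, 1, 9]]])
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]], [[0, 8, 7], [3, 5, 2]]])
tensors = (tensor1, tensor2, tensor3)

# For [2, 2, 3] inputs the result is 4D, so dim can be -4 to 3.
for dim in range(-4, 4):
    stacked = torch.stack(tensors, dim)
    # The new dimension holds one slot per input tensor.
    assert stacked.size(dim) == len(tensors)
    # Selecting index k along the new dimension returns the k-th input.
    for k, original in enumerate(tensors):
        assert torch.equal(stacked.select(dim, k), original)
    # torch.unbind() splits along dim and undoes the stack.
    parts = torch.unbind(stacked, dim)
    assert all(torch.equal(part, original) for part, original in zip(parts, tensors))

print(torch.stack(tensors, -1)[0, 0])
# tensor([[2, 9, 1],
#         [7, 4, 6],
#         [4, 7, 3]])
```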

*Memos:

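  • stack() can stack 0D or more D tensors.
  • All tensors must have exactly the same size.
  • The 2nd argument (dim) is the position of the new dimension. It defaults to 0 and, for nD input tensors, must be in the range [-(n+1), n].
  • If at least one tensor has a floating-point dtype (for example, because it was created from at least one floating-point number), the result is a floating-point tensor.
  • stack() can be called only as torch.stack(), not as a tensor method.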
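
The size and dim rules above can be checked with a small sketch. The two [4, 5] tensors are arbitrary placeholders chosen so that each resulting shape is easy to tell apart. The post does not state which exception an out-of-range dim raises, so the sketch catches both IndexError and RuntimeError.

```python
import torch

# Two tensors of the same size [4, 5] (n = 2 dimensions).
a = torch.zeros(4, 5)
b = torch.ones(4, 5)

# A new dimension of size 2 (one slot per tensor) is inserted at dim.
# For 2D inputs, dim may be -3 to 2.
shapes = {dim: tuple(torch.stack((a, b), dim).shape) for dim in range(-3, 3)}
print(shapes)
# {-3: (2, 4, 5), -2: (4, 2, 5), -1: (4, 5, 2), 0: (2, 4, 5), 1: (4, 2, 5), 2: (4, 5, 2)}

# A dim outside [-3, 2] is rejected.
try:
    torch.stack((a, b), 3)
except (IndexError, RuntimeError) as e:
    print(type(e).__name__, e)

# Tensors of different sizes cannot be stacked, even if only one dimension differs.
try:
    torch.stack((a, torch.ones(4, 6)))
except RuntimeError as e:
    print("RuntimeError:", e)
```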

cat() concatenates a sequence of tensors along an existing dimension, as shown below:

import torch

tensor1 = torch.tensor([2, 7, 4]) # The size is [3].
tensor2 = torch.tensor([8, 3, 2]) # The size is [3].
tensor3 = torch.tensor([5, 0, 8]) # The size is [3].
torch.cat((tensor1, tensor2, tensor3))
# tensor([2, 7, 4, 8, 3, 2, 5, 0, 8])
# The size is [9].

tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]]) # The size is [2, 3].
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]]) # The size is [2, 3].
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]]) # The size is [2, 3].
torch.cat((tensor1, tensor2, tensor3))
# tensor([[2, 7, 4],
#         [8, 3, 2],
#         [5, 0, 8],
#         [3, 6, 1],
#         [9, 4, 7],
#         [1, 0, 5]])
# The size is [6, 3].

tensor1 = torch.tensor([[[2, 7, 4], [8, 3, 2]],
                        [[5, 0, 8], [3, 6, 1]]])
                       # The size is [2, 2, 3].
tensor2 = torch.tensor([[[9, 4, 7], [1, 0, 5]],
                        [[6, 7, 4], [2, 1, 9]]])
                       # The size is [2, 2, 3].
tensor3 = torch.tensor([[[1, 6, 3], [9, 6, 0]],
                        [[0, 8, 7], [3, 5, 2]]])
                       # The size is [2, 2, 3].
torch.cat((tensor1, tensor2, tensor3))
torch.cat((tensor1, tensor2, tensor3), 0)
torch.cat((tensor1, tensor2, tensor3), -3)
# tensor([[[2, 7, 4], [8, 3, 2]],
#         [[5, 0, 8], [3, 6, 1]],
#         [[9, 4, 7], [1, 0, 5]],
#         [[6, 7, 4], [2, 1, 9]],
#         [[1, 6, 3], [9, 6, 0]],
#         [[0, 8, 7], [3, 5, 2]]])
# The size is [6, 2, 3].

torch.cat((tensor1, tensor2, tensor3), 1)
torch.cat((tensor1, tensor2, tensor3), -2)
# tensor([[[2, 7, 4],
#          [8, 3, 2],
#          [9, 4, 7],
#          [1, 0, 5],
#          [1, 6, 3],
#          [9, 6, 0]],
#         [[5, 0, 8],
#          [3, 6, 1],
#          [6, 7, 4],
#          [2, 1, 9],
#          [0, 8, 7],
#          [3, 5, 2]]])
# The size is [2, 6, 3].

torch.cat((tensor1, tensor2, tensor3), 2)
torch.cat((tensor1, tensor2, tensor3), -1)
# tensor([[[2, 7, 4, 9, 4, 7, 1, 6, 3],
#          [8, 3, 2, 1, 0, 5, 9, 6, 0]],
#         [[5, 0, 8, 6, 7, 4, 0, 8, 7],
#          [3, 6, 1, 2, 1, 9, 3, 5, 2]]])
# The size is [2, 2, 9].
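
cat() keeps the number of dimensions, while stack() adds one. As a sketch of how the two relate (not necessarily how PyTorch implements it internally), stacking along dim gives the same result as first giving each tensor a new size-1 dimension at dim with unsqueeze() and then calling cat() along that dim. The check below uses the three [2, 3] tensors from the 2D examples.

```python
import torch

tensor1 = torch.tensor([[2, 7, 4], [8, 3, 2]])  # The size is [2, 3].
tensor2 = torch.tensor([[5, 0, 8], [3, 6, 1]])  # The size is [2, 3].
tensor3 = torch.tensor([[9, 4, 7], [1, 0, 5]])  # The size is [2, 3].
tensors = (tensor1, tensor2, tensor3)

for dim in range(-3, 3):
    via_stack = torch.stack(tensors, dim)
    # Add a size-1 dimension at dim to every tensor, then join along it.
    via_cat = torch.cat([t.unsqueeze(dim) for t in tensors], dim)
    assert torch.equal(via_stack, via_cat)

# cat() keeps the number of dimensions; stack() adds one.
print(torch.cat(tensors).shape)    # torch.Size([6, 3])
print(torch.stack(tensors).shape)  # torch.Size([3, 2, 3])
```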

*Memos:

  • cat() can concatenate 1D or more D tensors. 0D tensors raise an error.
  • Tensors of different sizes can be concatenated only if they differ along the concatenation dimension. Every other dimension must match.
  • The 2nd argument (dim) is the existing dimension to concatenate along. It defaults to 0 and, for nD tensors, must be in the range [-n, n-1].
  • If at least one tensor has a floating-point dtype (for example, because it was created from at least one floating-point number), the result is a floating-point tensor.
  • cat() can be called only as torch.cat(), not as a tensor method.
  • torch.concat() is an alias of torch.cat(), so the two are the same.
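
The size rules above can be illustrated with two small tensors. The [2, 3] and [2, 2] tensors here are arbitrary examples: they can be joined along dim 1, where only their sizes differ, but not along dim 0. The sketch also shows that 0D tensors are rejected and that torch.concat() gives the same result as torch.cat().

```python
import torch

a = torch.tensor([[2, 7, 4], [8, 3, 2]])  # The size is [2, 3].
b = torch.tensor([[5, 0], [3, 6]])        # The size is [2, 2].

# Sizes may differ only along the concatenation dimension (here dim 1).
print(torch.cat((a, b), 1))
# tensor([[2, 7, 4, 5, 0],
#         [8, 3, 2, 3, 6]])

# Along dim 0, dimension 1 would have to match (3 vs 2), so this fails.
try:
    torch.cat((a, b), 0)
except RuntimeError as e:
    print("RuntimeError:", e)

# 0D tensors cannot be concatenated (stack() them instead).
try:
    torch.cat((torch.tensor(2), torch.tensor(7)))
except RuntimeError as e:
    print("RuntimeError:", e)

# torch.concat() is an alias of torch.cat().
assert torch.equal(torch.concat((a, b), 1), torch.cat((a, b), 1))
```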
