DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

narrow() and narrow_copy() in PyTorch

narrow() and narrow_copy() can extract a narrowed tensor (length consecutive elements along one dimension, beginning at start) from a 1D or higher-dimensional tensor, as shown below:

*Memos:

  • narrow() and narrow_copy() can be called either from torch (e.g. torch.narrow(my_tensor, ...)) or as a tensor method (e.g. my_tensor.narrow(...)).
  • A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used. It must have at least one dimension (narrow() cannot be applied to a 0D tensor).
  • The 2nd argument (int) with torch or the 1st argument (int) with a tensor is dim (Required). A negative dim counts back from the last dimension (e.g. -1 is the last dimension).
  • The 3rd argument (int) with torch or the 2nd argument (int) with a tensor is start (Required). A negative start counts back from the end of dim (e.g. -9 on a dimension of size 12 means 3).
  • The 4th argument (int) with torch or the 3rd argument (int) with a tensor is length (Required).
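  • narrow() returns a view that shares storage with the original tensor (so writing to the result also changes the original), while narrow_copy() copies the data into a new tensor, taking more memory, so narrow() is lighter and faster than narrow_copy().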
import torch

my_tensor = torch.tensor([6, -4, 5, 7, 1, 9, -6, 0, 3, -2, -5, 8])

# Each call below selects the same 5 elements starting at index 3.
# A negative `start` counts back from the end of `dim` (-9 + 12 == 3), and
# dim=-1 names the last (here the only) dimension, so all ten are equivalent.
# The results are collected (bare expressions are discarded outside a REPL)
# so that the documented output can be checked below.
narrow_1d_results = [
    torch.narrow(my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow(dim=0, start=3, length=5),
    torch.narrow(my_tensor, dim=0, start=-9, length=5),
    torch.narrow(my_tensor, dim=-1, start=3, length=5),
    torch.narrow(my_tensor, dim=-1, start=-9, length=5),
    torch.narrow_copy(my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow_copy(dim=0, start=3, length=5),
    torch.narrow_copy(my_tensor, dim=0, start=-9, length=5),
    torch.narrow_copy(my_tensor, dim=-1, start=3, length=5),
    torch.narrow_copy(my_tensor, dim=-1, start=-9, length=5),
]
# tensor([7, 1, 9, -6, 0])
expected_1d = torch.tensor([7, 1, 9, -6, 0])
assert all(torch.equal(result, expected_1d) for result in narrow_1d_results)

# narrow() returns a view sharing storage with its input, so writing through
# it modifies the source; narrow_copy() returns an independent copy, so
# writing to it leaves the source untouched. A clone is used here so that
# my_tensor itself stays unchanged.
scratch = my_tensor.clone()
torch.narrow(scratch, dim=0, start=3, length=5)[0] = 100
assert scratch[3] == 100
torch.narrow_copy(scratch, dim=0, start=3, length=5)[0] = -100
assert scratch[3] == 100

my_tensor = torch.tensor([[6, -4, 5, 7],
                          [1, 9, -6, 0],
                          [3, -2, -5, 8]])

# Narrowing along dim 0 (rows) of this (3, 4) tensor: start=-2 wraps to
# 3 - 2 == 1, and dim=-2 is the same dimension as dim=0 for a 2D tensor.
narrow_rows_results = [
    torch.narrow(my_tensor, dim=0, start=1, length=2),
    torch.narrow(my_tensor, dim=0, start=-2, length=2),
    torch.narrow(my_tensor, dim=-2, start=1, length=2),
    torch.narrow(my_tensor, dim=-2, start=-2, length=2),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=0, start=-2, length=2),
    torch.narrow_copy(my_tensor, dim=-2, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=-2, start=-2, length=2),
]
# tensor([[1, 9, -6, 0],
#         [3, -2, -5, 8]])
expected_rows = torch.tensor([[1, 9, -6, 0],
                              [3, -2, -5, 8]])
assert all(torch.equal(result, expected_rows)
           for result in narrow_rows_results)

# Narrowing along dim 1 (columns): start=-3 wraps to 4 - 3 == 1, and
# dim=-1 is the same dimension as dim=1 for a 2D tensor.
narrow_cols_results = [
    torch.narrow(my_tensor, dim=1, start=1, length=2),
    torch.narrow(my_tensor, dim=1, start=-3, length=2),
    torch.narrow(my_tensor, dim=-1, start=1, length=2),
    torch.narrow(my_tensor, dim=-1, start=-3, length=2),
    torch.narrow_copy(my_tensor, dim=1, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=1, start=-3, length=2),
    torch.narrow_copy(my_tensor, dim=-1, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=-1, start=-3, length=2),
]
# tensor([[-4, 5],
#         [9, -6],
#         [-2, -5]])
expected_cols = torch.tensor([[-4, 5],
                              [9, -6],
                              [-2, -5]])
assert all(torch.equal(result, expected_cols)
           for result in narrow_cols_results)

my_tensor = torch.tensor([[[6, -4], [5, 7]],
                          [[1, 9], [-6, 0]],
                          [[3, -2], [-5, 8]]])

# This tensor has shape (3, 2, 2); a negative start wraps by the size of
# the narrowed dimension (dim 0: -2 -> 1, dims 1 and 2: -1 -> 1).

# Along dim 0.
narrow_dim0_results = [
    torch.narrow(my_tensor, dim=0, start=1, length=1),
    torch.narrow(my_tensor, dim=0, start=-2, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=-2, length=1),
]
# tensor([[[1, 9],
#          [-6, 0]]])
expected_dim0 = torch.tensor([[[1, 9],
                               [-6, 0]]])
assert all(torch.equal(result, expected_dim0)
           for result in narrow_dim0_results)

# Along dim 1.
narrow_dim1_results = [
    torch.narrow(my_tensor, dim=1, start=1, length=1),
    torch.narrow(my_tensor, dim=1, start=-1, length=1),
    torch.narrow_copy(my_tensor, dim=1, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=1, start=-1, length=1),
]
# tensor([[[5, 7]],
#        [[-6, 0]],
#        [[-5, 8]]])
expected_dim1 = torch.tensor([[[5, 7]],
                              [[-6, 0]],
                              [[-5, 8]]])
assert all(torch.equal(result, expected_dim1)
           for result in narrow_dim1_results)

# Along dim 2.
narrow_dim2_results = [
    torch.narrow(my_tensor, dim=2, start=1, length=1),
    torch.narrow(my_tensor, dim=2, start=-1, length=1),
    torch.narrow_copy(my_tensor, dim=2, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=2, start=-1, length=1),
]
# tensor([[[-4], [7]],
#         [[9], [0]],
#         [[-2], [8]]])
expected_dim2 = torch.tensor([[[-4], [7]],
                              [[9], [0]],
                              [[-2], [8]]])
assert all(torch.equal(result, expected_dim2)
           for result in narrow_dim2_results)

# Mixing floats, booleans and complex numbers in one literal promotes every
# element to a complex dtype (True -> 1+0j, False -> 0+0j).
my_tensor = torch.tensor([[[6., -4.], [5., 7.]],
                          [[True, 9.], [-6+0j, False]],
                          [[3+0j, -2+0j], [-5+0j, 8+0j]]])
narrow_mixed_results = [
    torch.narrow(my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=1),
]
# tensor([[[1.+0.j, 9.+0.j],
#          [-6.+0.j, 0.+0.j]]])
expected_mixed = torch.tensor([[[1+0j, 9+0j],
                                [-6+0j, 0+0j]]])
assert all(torch.equal(result, expected_mixed)
           for result in narrow_mixed_results)
Enter fullscreen mode Exit fullscreen mode

Top comments (0)