`narrow()` and `narrow_copy()` can extract a slice along one dimension of a tensor with one or more dimensions, as shown below:
*Memos:
- `narrow()` and `narrow_copy()` can be used with `torch` or a tensor.
- A tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
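- The 2nd argument (`int`) with `torch` or the 1st argument (`int`) with a tensor is `dim` (Required). A negative `dim` counts back from the last dimension, e.g. `-1` is the last dimension.
- The 3rd argument (`int`) with `torch` or the 2nd argument (`int`) with a tensor is `start` (Required). A negative `start` counts back from the end of `dim`, e.g. `-9` on a dimension of size 12 is `3`.
- The 4th argument (`int`) with `torch` or the 3rd argument (`int`) with a tensor is `length` (Required).
-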
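`narrow()` returns a view that shares memory with the original tensor, while `narrow_copy()` copies the selected elements into newly allocated memory. So `narrow()` is lighter and faster than `narrow_copy()`. However, modifying the result of `narrow()` also modifies the original tensor.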
import torch
# 1D example: my_tensor has shape (12,).
# A negative `start` counts back from the end of `dim` (-9 -> 12 - 9 = 3) and
# a negative `dim` counts back from the last dimension (-1 -> 0 for a 1D
# tensor), so every call below selects the 5 elements at indices 3..7.
my_tensor = torch.tensor([6, -4, 5, 7, 1, 9, -6, 0, 3, -2, -5, 8])
expected = torch.tensor([7, 1, 9, -6, 0])

# Each call is checked instead of being evaluated and discarded, so the
# expected output above cannot silently drift from what PyTorch returns.
for result in (
    torch.narrow(my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow(dim=0, start=3, length=5),
    torch.narrow(my_tensor, dim=0, start=-9, length=5),
    torch.narrow(my_tensor, dim=-1, start=3, length=5),
    torch.narrow(my_tensor, dim=-1, start=-9, length=5),
    torch.narrow_copy(my_tensor, dim=0, start=3, length=5),
    my_tensor.narrow_copy(dim=0, start=3, length=5),
    torch.narrow_copy(my_tensor, dim=0, start=-9, length=5),
    torch.narrow_copy(my_tensor, dim=-1, start=3, length=5),
    torch.narrow_copy(my_tensor, dim=-1, start=-9, length=5),
):
    assert torch.equal(result, expected)

# narrow() returns a view: its first element lives inside my_tensor's memory,
# `start` elements past my_tensor's first element. narrow_copy() allocates
# separate memory for its result.
narrowed = torch.narrow(my_tensor, dim=0, start=3, length=5)
copied = torch.narrow_copy(my_tensor, dim=0, start=3, length=5)
assert narrowed.data_ptr() == my_tensor.data_ptr() + 3 * my_tensor.element_size()
assert copied.data_ptr() != narrowed.data_ptr()
# 2D example: my_tensor has shape (3, 4).
my_tensor = torch.tensor([[6, -4, 5, 7],
                          [1, 9, -6, 0],
                          [3, -2, -5, 8]])

# Rows 1..2. For a 2D tensor dim=-2 is dim 0, and start=-2 on a dimension
# of size 3 is 3 - 2 = 1.
expected_rows = torch.tensor([[1, 9, -6, 0],
                              [3, -2, -5, 8]])
for result in (
    torch.narrow(my_tensor, dim=0, start=1, length=2),
    torch.narrow(my_tensor, dim=0, start=-2, length=2),
    torch.narrow(my_tensor, dim=-2, start=1, length=2),
    torch.narrow(my_tensor, dim=-2, start=-2, length=2),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=0, start=-2, length=2),
    torch.narrow_copy(my_tensor, dim=-2, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=-2, start=-2, length=2),
):
    assert torch.equal(result, expected_rows)

# Columns 1..2. For a 2D tensor dim=-1 is dim 1, and start=-3 on a dimension
# of size 4 is 4 - 3 = 1.
expected_cols = torch.tensor([[-4, 5],
                              [9, -6],
                              [-2, -5]])
for result in (
    torch.narrow(my_tensor, dim=1, start=1, length=2),
    torch.narrow(my_tensor, dim=1, start=-3, length=2),
    torch.narrow(my_tensor, dim=-1, start=1, length=2),
    torch.narrow(my_tensor, dim=-1, start=-3, length=2),
    torch.narrow_copy(my_tensor, dim=1, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=1, start=-3, length=2),
    torch.narrow_copy(my_tensor, dim=-1, start=1, length=2),
    torch.narrow_copy(my_tensor, dim=-1, start=-3, length=2),
):
    assert torch.equal(result, expected_cols)
# 3D example: my_tensor has shape (3, 2, 2).
# With length=1 the narrowed dimension is kept with size 1; narrow() and
# narrow_copy() never drop a dimension, so every result below is still 3D.
my_tensor = torch.tensor([[[6, -4], [5, 7]],
                          [[1, 9], [-6, 0]],
                          [[3, -2], [-5, 8]]])

# Along dim 0 (size 3): start=-2 is 3 - 2 = 1.
expected_dim0 = torch.tensor([[[1, 9],
                               [-6, 0]]])
for result in (
    torch.narrow(my_tensor, dim=0, start=1, length=1),
    torch.narrow(my_tensor, dim=0, start=-2, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=-2, length=1),
):
    assert torch.equal(result, expected_dim0)

# Along dim 1 (size 2): start=-1 is 2 - 1 = 1.
expected_dim1 = torch.tensor([[[5, 7]],
                              [[-6, 0]],
                              [[-5, 8]]])
for result in (
    torch.narrow(my_tensor, dim=1, start=1, length=1),
    torch.narrow(my_tensor, dim=1, start=-1, length=1),
    torch.narrow_copy(my_tensor, dim=1, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=1, start=-1, length=1),
):
    assert torch.equal(result, expected_dim1)

# Along dim 2 (size 2): start=-1 is 2 - 1 = 1.
expected_dim2 = torch.tensor([[[-4], [7]],
                              [[9], [0]],
                              [[-2], [8]]])
for result in (
    torch.narrow(my_tensor, dim=2, start=1, length=1),
    torch.narrow(my_tensor, dim=2, start=-1, length=1),
    torch.narrow_copy(my_tensor, dim=2, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=2, start=-1, length=1),
):
    assert torch.equal(result, expected_dim2)
# Mixed Python scalars (floats, bools, complex numbers): torch.tensor()
# infers one complex dtype for the whole tensor (True -> 1, False -> 0), and
# narrow()/narrow_copy() return results with that same dtype.
my_tensor = torch.tensor([[[6., -4.], [5., 7.]],
                          [[True, 9.], [-6+0j, False]],
                          [[3+0j, -2+0j], [-5+0j, 8+0j]]])
assert my_tensor.is_complex()

# Build the expected value with my_tensor's dtype so the comparison is
# exact and does not depend on the default complex dtype.
expected_mixed = torch.tensor([[[1+0j, 9+0j],
                                [-6+0j, 0j]]], dtype=my_tensor.dtype)
for result in (
    torch.narrow(my_tensor, dim=0, start=1, length=1),
    torch.narrow_copy(my_tensor, dim=0, start=1, length=1),
):
    assert result.dtype == my_tensor.dtype
    assert torch.equal(result, expected_mixed)
Top comments (0)