DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

diagflat() in PyTorch

Buy Me a Coffee

*Memos:

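diagflat() creates a 2D tensor whose diagonal (the main diagonal, or the one selected by offset) holds the elements of the flattened input, with 0, 0., 0.+0.j or False everywhere else. The input can be a tensor with any number of dimensions (0D or more) and zero or more elements, as shown below: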

*Memos:

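  • diagflat() can be called either as torch.diagflat() or as a tensor method (my_tensor.diagflat()).
  • The 1st argument (input) with torch, or the tensor itself when called as a method (Required-Type: tensor of int, float, complex or bool).
  • The 2nd argument with torch, or the 1st argument with a tensor, is offset (Optional-Default: 0-Type: int). A positive offset places the values above the main diagonal and a negative offset places them below it; either way the result grows by |offset| rows and columns.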
import torch

# 1D int64 input built from a tuple; values land on a diagonal and
# every other entry is filled with 0.
my_tensor = torch.tensor((7, -4, 5))

# Default offset (0): these three spellings give the same result.
torch.diagflat(my_tensor)
my_tensor.diagflat()
torch.diagflat(my_tensor, 0)
# tensor([[7, 0, 0],
#         [0, -4, 0],
#         [0, 0, 5]])

# Positive offset: values move above the main diagonal and the
# result grows to (3 + offset) x (3 + offset).
my_tensor.diagflat(offset=1)
# tensor([[0, 7, 0, 0],
#         [0, 0, -4, 0],
#         [0, 0, 0, 5],
#         [0, 0, 0, 0]])

# Negative offset: values move below the main diagonal.
my_tensor.diagflat(offset=-1)
# tensor([[0, 0, 0, 0],
#         [7, 0, 0, 0],
#         [0, -4, 0, 0],
#         [0, 0, 5, 0]])

my_tensor.diagflat(offset=2)
# tensor([[0, 0, 7, 0, 0],
#         [0, 0, 0, -4, 0],
#         [0, 0, 0, 0, 5],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]])

my_tensor.diagflat(offset=-2)
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0],
#         [0, 0, 5, 0, 0]])

# Float input: the off-diagonal fill becomes 0.
my_tensor = torch.tensor((7., -4., 5.))

my_tensor.diagflat()
# tensor([[7., 0., 0.],
#         [0., -4., 0.],
#         [0., 0., 5.]])

# Complex input: the off-diagonal fill becomes 0.+0.j.
my_tensor = torch.tensor((7.+0.j, -4.+0.j, 5.+0.j))

my_tensor.diagflat()
# tensor([[7.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, -4.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 5.+0.j]])

# Bool input (three True values): the off-diagonal fill becomes False.
my_tensor = torch.ones(3, dtype=torch.bool)

my_tensor.diagflat()
# tensor([[True, False, False],
#         [False, True, False],
#         [False, False, True]])

# 2D input: diagflat() first flattens the input in row-major order
# (7, -4, 5, -6, -3, 8, 9, 1, -2), so 9 values end up on the diagonal.
my_tensor = torch.tensor([[7, -4, 5],
                          [-6, -3, 8],
                          [9, 1, -2]])
torch.diagflat(input=my_tensor)
torch.diagflat(input=my_tensor, offset=0)
# tensor([[7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2]])

# As with 1D input, the result grows by |offset| rows and columns.
torch.diagflat(input=my_tensor, offset=1)
# tensor([[0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, -2],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

torch.diagflat(input=my_tensor, offset=-1)
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2, 0]])

torch.diagflat(input=my_tensor, offset=2)
# tensor([[0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

torch.diagflat(input=my_tensor, offset=-2)
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0]])

# Edge cases (my_tensor is intentionally left unchanged here).
# A 0D tensor is flattened to a single element, giving a 1x1 result
# (or (1 + |offset|) x (1 + |offset|) with an offset).
torch.diagflat(input=torch.tensor(7))
# tensor([[7]])

torch.diagflat(input=torch.tensor(7), offset=1)
# tensor([[0, 7],
#         [0, 0]])

# A tensor with zero elements gives an empty 0x0 result.
torch.diagflat(input=torch.tensor([]))
# tensor([], size=(0, 0))
Enter fullscreen mode Exit fullscreen mode

Top comments (0)