DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on • Updated on

eye(), diag() and diagflat() in PyTorch

*My post explains diagonal() and diag_embed().

eye() creates a 2D tensor with ones (1., 1, 1.+0.j or True, depending on the dtype) on the main diagonal and zeros (0., 0, 0.+0.j or False) everywhere else; a size of 0 gives an empty tensor, as shown below:

*Memos:

  • eye() can be used with torch but not with a tensor.
  • The returned tensor has zero or more floating-point numbers(Default), integers, complex numbers or boolean values.
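  • The 1st argument(int) with torch is n(Required) which is the number of rows.
  • The 2nd argument(int) with torch is m(Optional-Default:n) which is the number of columns.
  • dtype(torch.dtype, keyword-only, Optional) sets the element type, e.g. torch.int64, torch.complex64 or torch.bool as shown below.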
import torch

# torch.eye(n, m) builds an n x m matrix with ones on the main diagonal and
# zeros everywhere else. When m is omitted it falls back to n (square result).

# Square matrices: only the row count n is given.
torch.eye(n=0)
# tensor([], size=(0, 0))

torch.eye(n=1)
# tensor([[1.]])

torch.eye(n=2)
# tensor([[1., 0.],
#         [0., 1.]])

torch.eye(n=3)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.]])

torch.eye(n=4)
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 0., 1.]])

# Rectangular matrices: 4 rows and a growing number of columns m. The run
# of ones stops as soon as either the rows or the columns run out.
torch.eye(n=4, m=0)
# tensor([], size=(4, 0))

torch.eye(n=4, m=1)
# tensor([[1.],
#         [0.],
#         [0.],
#         [0.]])

torch.eye(n=4, m=2)
# tensor([[1., 0.],
#         [0., 1.],
#         [0., 0.],
#         [0., 0.]])

torch.eye(n=4, m=3)
# tensor([[1., 0., 0.],
#         [0., 1., 0.],
#         [0., 0., 1.],
#         [0., 0., 0.]])

torch.eye(n=4, m=4)
# tensor([[1., 0., 0., 0.],
#         [0., 1., 0., 0.],
#         [0., 0., 1., 0.],
#         [0., 0., 0., 1.]])

torch.eye(n=4, m=5)
# tensor([[1., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0.],
#         [0., 0., 1., 0., 0.],
#         [0., 0., 0., 1., 0.]])

torch.eye(n=4, m=6)
# tensor([[1., 0., 0., 0., 0., 0.],
#         [0., 1., 0., 0., 0., 0.],
#         [0., 0., 1., 0., 0., 0.],
#         [0., 0., 0., 1., 0., 0.]])

# dtype picks the element type; without it the result is floating point.
torch.eye(n=4, m=6, dtype=torch.int64)
# tensor([[1, 0, 0, 0, 0, 0],
#         [0, 1, 0, 0, 0, 0],
#         [0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 1, 0, 0]])

torch.eye(n=4, m=6, dtype=torch.complex64)
# tensor([[1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 0.+0.j, 1.+0.j, 0.+0.j, 0.+0.j]])

torch.eye(n=4, m=6, dtype=torch.bool)
# tensor([[True, False, False, False, False, False],
#         [False, True, False, False, False, False],
#         [False, False, True, False, False, False],
#         [False, False, False, True, False, False]])
Enter fullscreen mode Exit fullscreen mode

diag() either creates a 2D tensor with a 1D tensor placed on one of its diagonals and 0, 0., 0.+0.j or False elsewhere, or extracts one diagonal of a 2D tensor as a 1D tensor, as shown below:

*Memos:

  • diag() can be used with torch or a tensor.
  • Only a 2D or 1D tensor can be used.
  • A 2D tensor input returns a 1D tensor (the extracted diagonal).
  • A 1D tensor input returns a 2D tensor (the constructed matrix).
  • The tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
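  • The 2nd argument(int) with torch or the 1st argument(int) with a tensor is diagonal(Optional-Default:0), which selects the diagonal: 0 is the main diagonal, a positive value is above it and a negative value is below it.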
import torch

# torch.diag() works in two directions:
#   * 1D input -> 2D output: the values are written onto one diagonal.
#   * 2D input -> 1D output: one diagonal is read back out.
# `diagonal` picks which diagonal: 0 = main, > 0 = above, < 0 = below.

# --- 1D input: build a 2D tensor ---
my_tensor = torch.tensor([7, -4, 5])

torch.diag(my_tensor)
torch.diag(my_tensor, diagonal=0)
# tensor([[7, 0, 0],
#         [0, -4, 0],
#         [0, 0, 5]])

# A non-zero diagonal grows the result to (3 + |diagonal|) x (3 + |diagonal|)
# so that all three values still fit.
torch.diag(my_tensor, diagonal=1)
# tensor([[0, 7, 0, 0],
#         [0, 0, -4, 0],
#         [0, 0, 0, 5],
#         [0, 0, 0, 0]])

torch.diag(my_tensor, diagonal=-1)
# tensor([[0, 0, 0, 0],
#         [7, 0, 0, 0],
#         [0, -4, 0, 0],
#         [0, 0, 5, 0]])

torch.diag(my_tensor, diagonal=2)
# tensor([[0, 0, 7, 0, 0],
#         [0, 0, 0, -4, 0],
#         [0, 0, 0, 0, 5],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]])

torch.diag(my_tensor, diagonal=-2)
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0],
#         [0, 0, 5, 0, 0]])

# The fill value off the diagonal follows the input dtype.
my_tensor = torch.tensor([7., -4., 5.])

torch.diag(my_tensor)
# tensor([[7., 0., 0.],
#         [0., -4., 0.],
#         [0., 0., 5.]])

my_tensor = torch.tensor([7+0j, -4+0j, 5+0j])

torch.diag(my_tensor)
# tensor([[7.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, -4.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 5.+0.j]])

my_tensor = torch.tensor([True, True, True])

torch.diag(my_tensor)
# tensor([[True, False, False],
#         [False, True, False],
#         [False, False, True]])

# --- 2D input: extract one diagonal as a 1D tensor ---
my_tensor = torch.tensor([[7, -4, 5],
                          [-6, -3, 8],
                          [9, 1, -2]])
torch.diag(my_tensor)
torch.diag(my_tensor, diagonal=0)
# tensor([7, -3, -2])

torch.diag(my_tensor, diagonal=1)
# tensor([-4, 8])

# Elements [1, 0] and [2, 1].
torch.diag(my_tensor, diagonal=-1)
# tensor([-6, 1])

torch.diag(my_tensor, diagonal=2)
# tensor([5])

# Output kept as a comment: a bare `tensor([9])` here would raise NameError.
torch.diag(my_tensor, diagonal=-2)
# tensor([9])
Enter fullscreen mode Exit fullscreen mode

diagflat() creates a 2D tensor whose diagonal holds the flattened elements of a tensor with any number of dimensions (0D or more), with 0, 0., 0.+0.j or False elsewhere, as shown below:

*Memos:

  • diagflat() can be used with torch or a tensor.
  • The tensor of zero or more integers, floating-point numbers, complex numbers or boolean values can be used.
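  • The 2nd argument(int) with torch or the 1st argument(int) with a tensor is offset(Optional-Default:0), which selects the diagonal: 0 is the main diagonal, a positive value is above it and a negative value is below it.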
import torch

# diagflat() flattens its input and writes the elements onto one diagonal
# of a new square 2D tensor. It is available both as torch.diagflat(t, ...)
# and as the tensor method t.diagflat(...); both forms appear below.

# --- 1D input ---
my_tensor = torch.tensor([7, -4, 5])

torch.diagflat(my_tensor)
my_tensor.diagflat(offset=0)
# tensor([[7, 0, 0],
#         [0, -4, 0],
#         [0, 0, 5]])

# A non-zero offset grows the result to (3 + |offset|) x (3 + |offset|).
my_tensor.diagflat(offset=1)
# tensor([[0, 7, 0, 0],
#         [0, 0, -4, 0],
#         [0, 0, 0, 5],
#         [0, 0, 0, 0]])

my_tensor.diagflat(offset=-1)
# tensor([[0, 0, 0, 0],
#         [7, 0, 0, 0],
#         [0, -4, 0, 0],
#         [0, 0, 5, 0]])

my_tensor.diagflat(offset=2)
# tensor([[0, 0, 7, 0, 0],
#         [0, 0, 0, -4, 0],
#         [0, 0, 0, 0, 5],
#         [0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0]])

my_tensor.diagflat(offset=-2)
# tensor([[0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0],
#         [0, 0, 5, 0, 0]])

# The fill value off the diagonal follows the input dtype.
my_tensor = torch.tensor([7., -4., 5.])

torch.diagflat(my_tensor)
# tensor([[7., 0., 0.],
#         [0., -4., 0.],
#         [0., 0., 5.]])

my_tensor = torch.tensor([7+0j, -4+0j, 5+0j])

torch.diagflat(my_tensor)
# tensor([[7.+0.j, 0.+0.j, 0.+0.j],
#         [0.+0.j, -4.+0.j, 0.+0.j],
#         [0.+0.j, 0.+0.j, 5.+0.j]])

my_tensor = torch.tensor([True, True, True])

torch.diagflat(my_tensor)
# tensor([[True, False, False],
#         [False, True, False],
#         [False, False, True]])

# --- 2D input: the 9 elements are flattened row by row (7, -4, 5, -6, ...)
# and laid on the diagonal of a (9 + |offset|)-square result. ---
my_tensor = torch.tensor([[7, -4, 5],
                          [-6, -3, 8],
                          [9, 1, -2]])
torch.diagflat(my_tensor)
my_tensor.diagflat(offset=0)
# tensor([[7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2]])

my_tensor.diagflat(offset=1)
# tensor([[0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, -2],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

my_tensor.diagflat(offset=-1)
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2, 0]])

my_tensor.diagflat(offset=2)
# tensor([[0, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -4, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, -6, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, -3, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 8, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 9, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -2],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

my_tensor.diagflat(offset=-2)
# tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, -4, 0, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, -6, 0, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, -3, 0, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 8, 0, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 9, 0, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
#         [0, 0, 0, 0, 0, 0, 0, 0, -2, 0, 0]])
Enter fullscreen mode Exit fullscreen mode

Top comments (0)