DEV Community

Super Kai (Kazuya Ito)
Super Kai (Kazuya Ito)

Posted on

CelebA in PyTorch

Buy Me a Coffee

*My post explains CelebA.

CelebA() can load the CelebA dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split(Optional-Default:"train"-Type:str). *"train"(162,770 images), "valid"(19,867 images), "test"(19,962 images) or "all"(202,599 images) can be set to it.
  • The 3rd argument is target_type(Optional-Default:"attr"-Type:str or list of str): *Memos:
    • "attr", "identity", "bbox" and/or "landmarks" can be set to it.
    • An empty list can also be set to it.
    • The same value can be set multiple times.
    • The order of the values determines the order of the returned target elements.
  • The 4th argument is transform(Optional-Default:None-Type:callable).
  • The 5th argument is target_transform(Optional-Default:None-Type:callable).
  • The 6th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • gdown is required to download the dataset.
    • You can manually download and extract the dataset(img_align_celeba.zip with identity_CelebA.txt, list_attr_celeba.txt, list_bbox_celeba.txt, list_eval_partition.txt and list_landmarks_align_celeba.txt) from here to data/celeba/.
from torchvision.datasets import CelebA

# Only root is required; every other argument keeps its default
# (split="train", target_type="attr", transform=None,
# target_transform=None, download=False).
train_attr_data = CelebA(
    root="data"
)

# The same dataset with every default written out explicitly. This rebinds
# train_attr_data, so the object created just above is discarded.
train_attr_data = CelebA(
    root="data",
    split="train",
    target_type="attr",
    transform=None,
    target_transform=None,
    download=False
)

# Validation split; each sample's target is its identity annotation.
valid_identity_data = CelebA(
    root="data",
    split="valid",
    target_type="identity"
)

# Test split; each sample's target is its bounding-box annotation.
test_bbox_data = CelebA(
    root="data",
    split="test",
    target_type="bbox"
)

# Every split ("all"); each sample's target is its landmarks annotation.
all_landmarks_data = CelebA(
    root="data",
    split="all",
    target_type="landmarks"
)

# Every split with no targets: each sample's target is None.
all_empty_data = CelebA(
    root="data",
    split="all",
    target_type=[]
)

# Every split with all four targets; each sample's target is a tuple in
# exactly this order (attr, identity, bbox, landmarks).
all_all_data = CelebA(
    root="data",
    split="all",
    target_type=["attr", "identity", "bbox", "landmarks"]
)

# Number of samples per split; the three splits add up to "all" (202599).
len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)

# target_type does not change the number of samples.
len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)

# The dataset's repr summarizes its size, root, target types and split.
train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train

# Constructor arguments are kept as attributes.
train_attr_data.root
# 'data'

train_attr_data.split
# 'train'

# A single string target_type is stored as a one-element list.
train_attr_data.target_type
# ['attr']

print(train_attr_data.transform)
# None

print(train_attr_data.target_transform)
# None

# `download` is a method, not the stored constructor flag, so accessing it
# shows a bound method rather than False.
train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>

# All annotation tables are loaded whatever target_type is (this dataset
# only asked for "attr"); each has one row per sample of the split.
len(train_attr_data.attr), train_attr_data.attr
# (162770, tensor([[0, 1, 1, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  ...,
#                  [1, 0, 1, ..., 0, 1, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 1, 1, ..., 1, 0, 1]]))

# 41 names, but the last one is an empty string: there are 40 real
# attribute names, matching the 40 values in each attr row below.
len(train_attr_data.attr_names), train_attr_data.attr_names
# (41, ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',
#       'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#       'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#       ...
#       'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])

len(train_attr_data.identity), train_attr_data.identity
# (162770, tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]]))

# NOTE(review): boxes such as [376, 4, 372, 515] exceed the 178x218 aligned
# image size, so these presumably refer to the original (unaligned)
# images rather than the aligned crops -- confirm before drawing them.
len(train_attr_data.bbox), train_attr_data.bbox
# (162770, tensor([[95, 71, 226, 313],
#                  [72, 94, 221, 306],
#                  [216, 59, 91, 126],
#                  ...,
#                  [103, 103, 143, 198],
#                  [30, 59, 216, 280],
#                  [376, 4, 372, 515]]))

# Ten values per row; presumably five (x, y) landmark points -- verify.
len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770, tensor([[69, 109, 106, ..., 152, 108, 154],
#                  [69, 110, 107, ..., 151, 108, 153],
#                  [76, 112, 104, ..., 156, 98, 158],
#                  ...,
#                  [69, 113, 109, ..., 151, 110, 151],
#                  [68, 112, 109, ..., 150, 108, 151],
#                  [70, 111, 107, ..., 153, 102, 152]]))

# Indexing returns (PIL image, target). With a single target_type the
# target is that one tensor; here, the 40 attribute flags.
train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))

train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))

train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))

# An identity target is a 0-d tensor.
valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))

valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))

valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))

# A bbox target is four values (x, y, width, height).
test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))

test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))

test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))

# A landmarks target is ten values; the same values appear in row 0 of
# landmarks_align above.
all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))

all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))

all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))

# With target_type=[] the target is None.
all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

# With several target types the target is a tuple in target_type order:
# (attr, identity, bbox, landmarks).
all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))

all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))

all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#   tensor(8692),
#   tensor([216, 59, 91, 126]),
#   tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

def _celeba_target_parts(target, target_types):
    """Return ``target`` as a tuple with one element per entry of ``target_types``.

    A CelebA sample's target is ``None`` when ``target_types`` is empty, the
    bare tensor when it has one entry, and a tuple (in ``target_types``
    order) otherwise. This normalizes all three cases to a tuple.
    """
    if not target_types:
        return ()
    if len(target_types) == 1:
        return (target,)
    return tuple(target)


def _draw_celeba_target(axis, target_type, value):
    """Draw one CelebA annotation ``value`` of kind ``target_type`` on ``axis``.

    Annotations are added with ``axis.add_patch()``: ``axis.scatter()`` and
    ``axis.plot()`` were observed to shrink the images in a
    ``plt.subplots()`` grid.
    """
    if target_type == "identity":
        # A 0-d tensor such as tensor(2880); show the plain int as the title.
        axis.set_title(label=value.item())
    elif target_type == "bbox":
        # NOTE(review): bbox values appear to refer to the original
        # (unaligned) images rather than the 178x218 aligned crops (some
        # boxes fall outside the image), so boxes may not line up with faces.
        x, y, w, h = value.tolist()
        axis.add_patch(p=Rectangle(xy=(x, y), width=w, height=h,
                                   linewidth=3, edgecolor='r',
                                   facecolor='none'))
    elif target_type == "landmarks":
        # Flat [x0, y0, x1, y1, ...]: even indices are x, odd indices are y.
        coords = value.tolist()
        for px, py in zip(coords[0::2], coords[1::2]):
            axis.add_patch(p=Circle(xy=(px, py)))
    # "attr" (a vector of 0/1 attribute flags) is intentionally not drawn.


def show_images(data, main_title=None):
    """Plot the first 10 samples of a CelebA dataset in a 2x5 grid.

    The annotations selected by ``data.target_type`` are drawn on each image:

    * ``"identity"``  -- the identity number is the subplot title.
    * ``"bbox"``      -- a red (x, y, width, height) rectangle.
    * ``"landmarks"`` -- a circle at each (x, y) landmark point.
    * ``"attr"``      -- nothing is drawn.

    Any combination, order or repetition of target types is handled,
    including an empty list (images only).

    Args:
        data: A torchvision ``CelebA`` dataset, or any iterable of
            ``(image, target)`` pairs with a ``target_type`` list shaped the
            same way. Images must be accepted by ``imshow`` (e.g. PIL images
            from ``transform=None``); tensor images from a ``ToTensor``
            transform presumably will not display -- untested.
        main_title: Optional title shown above the whole grid.
    """
    target_types = data.target_type
    fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
    fig.suptitle(t=main_title, y=1.0, fontsize=14)
    # The axes come first in zip() so iteration stops as soon as the ten
    # cells are filled, without loading an 11th sample from the dataset.
    for axis, (im, target) in zip(axes.ravel(), data):
        axis.imshow(X=im)
        parts = _celeba_target_parts(target, target_types)
        for target_type, value in zip(target_types, parts):
            _draw_celeba_target(axis, target_type, value)
    fig.tight_layout(h_pad=3.0)
    plt.show()

# Plot every dataset built above, each titled with its own variable name,
# in the same order as they were created.
for _title, _dataset in (
    ("train_attr_data", train_attr_data),
    ("valid_identity_data", valid_identity_data),
    ("test_bbox_data", test_bbox_data),
    ("all_landmarks_data", all_landmarks_data),
    ("all_empty_data", all_empty_data),
    ("all_all_data", all_all_data),
):
    show_images(data=_dataset, main_title=_title)
# Drop the loop variables so no extra module-level names are left behind.
del _title, _dataset
Enter fullscreen mode Exit fullscreen mode

Image description

Image description

Image description

Image description

Image description

Image description

Top comments (0)