DEV Community

Cover image for The Python Magic Behind PyTorch
Amit Chaudhary
Amit Chaudhary

Posted on • Edited on • Originally published at amitness.com

The Python Magic Behind PyTorch

PyTorch has emerged as one of the go-to deep learning frameworks in recent years. This popularity can be attributed to its easy to use API and it being more "pythonic".

Python and PyTorch combined Logo

PyTorch leverages numerous native features of Python to give us a consistent and clean API. In this article, I will explain those native features in detail. Learning these will help you better understand why you do things a certain way in PyTorch and make better use of what it has to offer.

Magic Methods in Layers

Layers such as


 are some of the basic constructs in PyTorch that we use to build our models. You import the layer and apply them to tensors.



```python
import torch
import torch.nn as nn

x = torch.rand(1, 784)
layer = nn.Linear(784, 10)
output = layer(x)
Enter fullscreen mode Exit fullscreen mode

Here we are able to call layer on some tensor x, so it must be a function right? Is nn.Linear() returning a function? Let's verify it by checking the type.

>>> type(layer)
<class 'torch.nn.modules.linear.Linear'>
Enter fullscreen mode Exit fullscreen mode

Surprise! nn.Linear is actually a class and layer an object of that class.

"What! How could we call it then? Aren't only functions supposed to be callable?"

Nope, we can create callable objects as well. Python provides a native way to make objects created from classes callable by using magic functions.
Let's see a simple example of a class that doubles a number.

class Double(object):
    def __call__(self, x):
        return 2*x
Enter fullscreen mode Exit fullscreen mode

Here we add a magic method __call__ in the class to double any number passed to it. Now, you can create an object out of this class and call it on some number.

>>> d = Double()
>>> d(2)
4
Enter fullscreen mode Exit fullscreen mode

Alternatively, the above code can be combined in the single line itself.

>>> Double()(2)
4
Enter fullscreen mode Exit fullscreen mode

This works because everything in Python is an object. See an example of a function below that doubles a number.

def double(x):
    return 2*x

>>> double(2)
4
Enter fullscreen mode Exit fullscreen mode

Even functions invoke the__call__ method behind the scenes.

>>> double.__call__(2)
4
Enter fullscreen mode Exit fullscreen mode

Magic methods in Forward Pass

Let's see an example of a model that applies a single fully connected layer to MNIST images to get 10 outputs.

import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 10)        

    def forward(self, x):
        return self.fc1(x)
Enter fullscreen mode Exit fullscreen mode

The following code should be familiar to you. We are computing output of this model on some tensor x.

x = torch.rand(10, 784)
model = Model()
output = model(x)
Enter fullscreen mode Exit fullscreen mode

We know calling the model directly on some tensor executes the .forward() function on it. How does that work?

It's the same reason in previous example. We're inheriting the class nn.Module. Internally, nn.Module has a __call__() magic method that calls the .forward(). So, when we override .forward() method later, it's executed.

# nn.Module
class Module(object):
    def __call__(self, x):
        # Simplified
        # Actual implementation has validation and gradient tracking.
        return self.forward(x)
Enter fullscreen mode Exit fullscreen mode

Thus, we were able to call the model directly on tensors.

output = model(x)
# model.__call__(x) 
#   -> model.forward(x)
Enter fullscreen mode Exit fullscreen mode

Magic methods in Dataset

In PyTorch, it is common to create a custom class inheriting from the Dataset class to prepare our training and test datasets. Have you ever wondered why we define methods with obscure names like __len__ and __getitem__ in it?

from torch.utils.data import Dataset

class Numbers(Dataset):
    def __init__(self, x, y):
        self.data = x
        self.labels = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return (self.data[i], self.labels[i])

>>> dataset = Numbers([1, 2, 3], [0, 1, 0])
>>> print(len(dataset))
3
>>> print(dataset[0])
(1, 0)
Enter fullscreen mode Exit fullscreen mode

These methods are builtin magic methods of Python. You know how we can get the length of iterables like list and tuples using len function.

>>> x = [10, 20, 30]
>>> len(x)
3
Enter fullscreen mode Exit fullscreen mode

Python allows defining a __len__ on our custom class so that len() works on it. For example,

class Data(object):
    def __len__(self):
        return 10

>>> d = Data()
>>> len(d)
10
Enter fullscreen mode Exit fullscreen mode

Similarly, you know how we can access elements of list and tuples using index notation.

>>> x = [10, 20, 30]
>>> x[0]
10
Enter fullscreen mode Exit fullscreen mode

Python allows a __getitem__ magic method to allow such functionality for custom classes. For example,

class Data(object):
    def __init__(self):
        self.x = [10, 20, 30]

    def __getitem__(self, i):
        return x[i]

>>> d = Data()
>>> d[0]
10
Enter fullscreen mode Exit fullscreen mode

With the above concept, now you can easily understand the builtin dataset like MNIST and what you can do with them.

from torchvision.datasets import MNIST

>>> trainset = MNIST(root='mnist', download=True, train=True)
>>> print(len(trainset))
60000
>>> print(trainset[0])
(<PIL.Image.Image image mode=L size=28x28 at 0x7F06DC654128>, 0)
Enter fullscreen mode Exit fullscreen mode

Magic methods in DataLoader

Let's create a dataloader for a training dataset of MNIST digits.

from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

trainset = MNIST(root='mnist', 
                 download=True, 
                 train=True, 
                 transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=32, shuffle=True)
Enter fullscreen mode Exit fullscreen mode

Now, let's try accessing first batch from the data loader directly without looping. If we try to access it via index, we get an exception.

>>> trainloader[0]
TypeError: 'DataLoader' object does not support indexing
Enter fullscreen mode Exit fullscreen mode

You might have been used to doing it in this way.

images, labels =  next(iter(trainloader))
Enter fullscreen mode Exit fullscreen mode

Have you ever wondered why do we wrap trainloader by iter() and then call next()? Let's demystify this.

Consider a list x with 3 elements. In Python, we can create an iterator out of x using the iter function.

x = [1, 2, 3]
y = iter(x)
Enter fullscreen mode Exit fullscreen mode

Iterators are used because they allow lazy loading such that only one element is loaded in memory at a time.

>>> next(x)
1
>>> next(x)
2
>>> next(x)
3
>>> next(x)
StopIteration:
Enter fullscreen mode Exit fullscreen mode

We get each element and when we reach the end of the list, we get a StopIteration exception.

This pattern matches our usual machine learning workflow where we take small batches of data at a time in memory and do the forward and backward pass. So, DataLoader also incorporates this pattern in PyTorch.

To create iterators out of classes in Python, we need to define magic methods __iter__ and __next__

class ExampleLoader(object):
    def __init__(self, data):
        self.data = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)        

>>> l = ExampleLoader([1, 2, 3])
>>> next(iter(l))
1
Enter fullscreen mode Exit fullscreen mode

Here, the iter() function calls the __iter__ magic method of the class returning that same object. Then, the next() function calls the __next__ magic method of the class to return next element present in our data.

In PyTorch, the implementation of DataLoader implements this pattern as follows:

class DataLoader(object):
    def __iter__(self):
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)

class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
    def __next__(self):
        # logic to return batch from whole data
        ...
Enter fullscreen mode Exit fullscreen mode

So, they decouple the iterator creation part and the actual data loading part.

  • When you call iter() on the data loader, it checks if we are using single or multiple workers
  • Based on that, it returns another iterator class for either single or multiple worker
>>> type(iter(trainloader))
torch.utils.data.dataloader._SingleProcessDataLoaderIter
Enter fullscreen mode Exit fullscreen mode
  • That iterator class has a __next__ method defined which returns the actual data of set batch_size when we call next() on it
images, labels =  next(iter(trainloader))

# equivalent to:
images, labels = trainloader.__iter__().__next__()
Enter fullscreen mode Exit fullscreen mode

Thus, we get images and labels for a single batch.

Conclusion

Thus, we saw how PyTorch borrows several advanced concepts from native Python itself in its API design. I hope the article was helpful to demystify how these concepts work behind the scenes and will help you become a better PyTorch user.

Connect

If you enjoyed this blog post, feel free to connect with me on Twitter where I share new blog posts every week.

Top comments (0)