# The Python Magic Behind PyTorch

Amit Chaudhary · Originally published at amitness.com

PyTorch has emerged as one of the go-to deep learning frameworks in recent years. This popularity can be attributed to its easy-to-use API and its more "pythonic" design. PyTorch leverages numerous native features of Python to give us a consistent and clean API. In this article, I will explain those native features in detail. Learning them will help you better understand why you do things a certain way in PyTorch and make better use of what it has to offer.

## Magic Methods in Layers

Layers such as `nn.Linear` and `nn.Conv2d` are some of the basic constructs in PyTorch that we use to build our models. You import a layer and apply it to tensors.

```python
import torch
import torch.nn as nn

x = torch.rand(1, 784)
layer = nn.Linear(784, 10)
output = layer(x)
```

Here we call `layer` on the tensor `x`, so it must be a function, right? Is `nn.Linear()` returning a function? Let's verify by checking its type.

```python
>>> type(layer)
<class 'torch.nn.modules.linear.Linear'>
```

Surprise! `nn.Linear` is actually a class, and `layer` is an object of that class.

"What! How could we call it then? Aren't only functions supposed to be callable?"

Nope, we can create callable objects as well. Python provides a native way to make objects created from classes callable by using magic methods.
Let's see a simple example of a class that doubles a number.

```python
class Double(object):
    def __call__(self, x):
        return 2*x
```

Here we define the magic method `__call__` in the class to double any number passed to it. Now you can create an object from this class and call it on a number.

```python
>>> d = Double()
>>> d(2)
4
```

Alternatively, the above code can be combined into a single line.

```python
>>> Double()(2)
4
```

This mechanism is not limited to our own classes: everything in Python is an object, including functions. See the example below of a function that doubles a number.

```python
def double(x):
    return 2*x

>>> double(2)
4
```

Functions are objects too, and they also have a `__call__` method; calling it directly gives the same result as calling the function.

```python
>>> double.__call__(2)
4
```
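To tie this back to `nn.Linear`: conceptually, a layer is an object that stores its parameters as attributes and is applied to inputs by calling it. Below is a minimal pure-Python sketch of that idea. `TinyLinear` is a hypothetical class written only for illustration (plain lists instead of tensors, no gradients), not PyTorch's implementation.

```python
class TinyLinear:
    """Hypothetical layer-like callable: computes W @ x + b with plain lists."""

    def __init__(self, weight, bias):
        self.weight = weight  # list of rows, one row per output feature
        self.bias = bias      # one bias value per output feature

    def __call__(self, x):
        # Each output is the dot product of one weight row with x, plus its bias
        return [
            sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(self.weight, self.bias)
        ]


layer = TinyLinear(weight=[[1, 2], [3, 4]], bias=[0, 1])
print(callable(layer))  # True: the class defines __call__
print(layer([1, 1]))    # [3, 8]
```

The parameters live on the object, and `__call__` is just the entry point that applies them, which is the same shape as `layer = nn.Linear(784, 10)` followed by `layer(x)`.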

## Magic methods in Forward Pass

Let's see an example of a model that applies a single fully connected layer to MNIST images to get 10 outputs.

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # 784 = 28 * 28 input pixels, 10 output classes (digits 0-9)
        self.fc1 = nn.Linear(784, 10)

    def forward(self, x):
        return self.fc1(x)
```

The following code should be familiar to you: we compute the output of this model on a tensor `x`.

```python
x = torch.rand(10, 784)
model = Model()
output = model(x)
```

We know that calling the model directly on a tensor executes its `.forward()` method. How does that work?

The reason is the same as in the previous example. Our model inherits from `nn.Module`, which internally defines a `__call__()` magic method that calls `.forward()`. So when we override the `.forward()` method in our subclass, our version is the one that gets executed.

```python
# nn.Module (simplified)
class Module(object):
    def __call__(self, *args, **kwargs):
        # The actual implementation also runs registered hooks
        # and does other bookkeeping around forward().
        return self.forward(*args, **kwargs)
```

Thus, we were able to call the model directly on tensors.

```python
output = model(x)
# model.__call__(x)
#   -> model.forward(x)
```
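A runnable toy version makes the dispatch easier to see. It also shows a practical consequence: whatever extra work `__call__` does around `forward()` (in PyTorch, running registered hooks) is skipped if you call `.forward()` directly, which is why calling `model(x)` is the recommended form. `MiniModule` and its `register_forward_hook` method below are a hypothetical simplification that only mimics the idea, not PyTorch's code.

```python
class MiniModule:
    """Hypothetical, heavily simplified stand-in for nn.Module."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # hook(module, inputs, output) runs after every call through __call__
        self._forward_hooks.append(hook)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must override forward()")

    def __call__(self, *args, **kwargs):
        output = self.forward(*args, **kwargs)  # dispatches to the subclass's forward()
        for hook in self._forward_hooks:
            hook(self, args, output)  # in this sketch, hooks only observe the output
        return output


class AddOne(MiniModule):
    def forward(self, x):
        return x + 1


seen = []
model = AddOne()
model.register_forward_hook(lambda module, inputs, output: seen.append(output))

print(model(41))         # 42, and the hook records it
print(model.forward(1))  # 2, but the hook is skipped
print(seen)              # [42]
```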

## Magic methods in Dataset

In PyTorch, it is common to create a custom class inheriting from the `Dataset` class to prepare our training and test datasets. Have you ever wondered why we define methods with obscure names like `__len__` and `__getitem__` in it?

```python
from torch.utils.data import Dataset

class Numbers(Dataset):
    def __init__(self, x, y):
        self.data = x
        self.labels = y

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return (self.data[i], self.labels[i])

>>> dataset = Numbers([1, 2, 3], [0, 1, 0])
>>> print(len(dataset))
3
>>> print(dataset[0])
(1, 0)
```

These are built-in magic methods of Python. You know how we can get the length of sequences like lists and tuples using the `len` function.

```python
>>> x = [10, 20, 30]
>>> len(x)
3
```

Python allows defining a `__len__` on our custom class so that `len()` works on it. For example,

```python
class Data(object):
    def __len__(self):
        return 10

>>> d = Data()
>>> len(d)
10
```

Similarly, you know how we can access elements of lists and tuples using index notation.

```python
>>> x = [10, 20, 30]
>>> x[0]
10
```

Python provides the `__getitem__` magic method to enable the same indexing syntax on custom classes. For example,

```python
class Data(object):
    def __init__(self):
        self.x = [10, 20, 30]

    def __getitem__(self, i):
        return self.x[i]

>>> d = Data()
>>> d[0]
10
```
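`__getitem__` does more than enable `d[0]`. If a class accepts integer indices starting from 0 in `__getitem__` and raises `IndexError` past the end, Python can also loop over it with `for`, even without an `__iter__` method (the older sequence protocol). The hypothetical `Squares` class below sketches a dataset-like object built only from `__len__` and `__getitem__`, the same two methods a map-style PyTorch `Dataset` defines.

```python
class Squares:
    """Hypothetical sequence of the first n square numbers."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        # Raising IndexError past the end is what lets for-loops stop
        if not 0 <= i < self.n:
            raise IndexError("index out of range")
        return i * i


sq = Squares(4)
print(len(sq))   # 4
print(sq[3])     # 9
print(list(sq))  # [0, 1, 4, 9]: iteration falls back to __getitem__
```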

With this concept, you can now easily understand built-in datasets like MNIST and what you can do with them.

```python
from torchvision.datasets import MNIST

trainset = MNIST(root='mnist', train=True, download=True)

>>> print(len(trainset))
60000
>>> print(trainset[1])
(<PIL.Image.Image image mode=L size=28x28 at 0x7F06DC654128>, 0)
```

Let's create a data loader for the MNIST training set.

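```python
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

trainset = MNIST(root='mnist',
                 train=True,
                 download=True,
                 transform=transforms.ToTensor())
trainloader = DataLoader(trainset, batch_size=64)  # batch size chosen for illustration
```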

Now, let's try accessing the first batch from the data loader directly, without looping. If we try to access it via an index, we get an exception.

```python
>>> trainloader[0]
TypeError: 'DataLoader' object does not support indexing
```
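The exception makes sense once you know that indexing (`obj[i]`) and iteration are separate protocols: `DataLoader` implements `__iter__` but not `__getitem__`. The hypothetical `OnlyIterable` class below reproduces the same situation in plain Python.

```python
class OnlyIterable:
    """Hypothetical class that is iterable but defines no __getitem__."""

    def __init__(self, items):
        self.items = items

    def __iter__(self):
        # Delegate to the list's own iterator
        return iter(self.items)


obj = OnlyIterable([10, 20, 30])
print(next(iter(obj)))  # 10: iteration works

try:
    obj[0]  # indexing needs __getitem__, which this class lacks
except TypeError as err:
    print("TypeError:", err)
```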

You may be used to doing it this way instead:

```python
images, labels = next(iter(trainloader))
```

Have you ever wondered why we wrap `trainloader` in `iter()` and then call `next()`? Let's demystify this.

Consider a list `x` with 3 elements. In Python, we can create an iterator out of `x` using the `iter` function.

```python
x = [1, 2, 3]
y = iter(x)
```

Iterators are useful because they allow lazy loading: elements are produced one at a time, on demand, instead of all at once. Each call to `next()` returns the next element.

```python
>>> next(y)
1
>>> next(y)
2
>>> next(y)
3
>>> next(y)
StopIteration
```

Each call returns one element, and once we reach the end of the list, a `StopIteration` exception is raised.
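The lazy part is easiest to see with a generator function, whose body runs only when `next()` asks for the next element. In the sketch below, `load_items` and the `loaded` log are hypothetical stand-ins for expensive loading work such as reading images from disk.

```python
loaded = []  # records which items have actually been "loaded" so far


def load_items(n):
    """Hypothetical lazy loader: produces item i only when next() asks for it."""
    for i in range(n):
        loaded.append(i)  # stand-in for expensive work such as reading a file
        yield i * 10


items = load_items(3)
print(loaded)       # []: creating the generator runs none of its body yet
print(next(items))  # 0
print(loaded)       # [0]: only the first item has been produced
print(list(items))  # [10, 20]: the remaining items, produced on demand
print(loaded)       # [0, 1, 2]
```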

This pattern matches our usual machine learning workflow, where we load a small batch of data into memory at a time and do the forward and backward passes on it. PyTorch's `DataLoader` incorporates the same pattern.

To make objects of our own classes work as iterators in Python, we define the magic methods `__iter__` and `__next__`:

```python
class ExampleLoader(object):
    def __init__(self, data):
        self.data = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)

>>> l = ExampleLoader([1, 2, 3])
>>> next(iter(l))
1
```

Here, the `iter()` function calls the `__iter__` magic method of the class, which returns that same object. Then the `next()` function calls the `__next__` magic method of the class to return the next element of our data.
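This is also roughly what a `for` loop does behind the scenes: it calls `iter()` once, then calls `next()` repeatedly until `StopIteration` is raised. The sketch below spells that out by hand for the `ExampleLoader` class above (repeated so the snippet runs on its own); it is a simplified equivalent, not CPython's actual implementation.

```python
class ExampleLoader(object):
    def __init__(self, data):
        self.data = iter(data)

    def __iter__(self):
        return self

    def __next__(self):
        return next(self.data)


loader = ExampleLoader([1, 2, 3])

# Roughly what `for item in loader: collected.append(item)` does:
collected = []
iterator = iter(loader)        # 1. call __iter__ once
while True:
    try:
        item = next(iterator)  # 2. call __next__ on every step
    except StopIteration:      # 3. stop when the iterator is exhausted
        break
    collected.append(item)

print(collected)     # [1, 2, 3]
print(list(loader))  # []: __iter__ returns self, so this object supports only one pass
```

Because `ExampleLoader.__iter__` returns `self`, a second loop over the same object yields nothing, which is worth keeping in mind when looking at how `DataLoader` separates the two roles.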

In PyTorch, `DataLoader` implements this pattern roughly as follows (heavily simplified):

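```python
# torch.utils.data.DataLoader (heavily simplified)
class DataLoader(object):
    def __iter__(self):
        # Choose the iterator class based on the number of worker processes
        if self.num_workers == 0:
            return _SingleProcessDataLoaderIter(self)
        else:
            return _MultiProcessingDataLoaderIter(self)


class _BaseDataLoaderIter(object):
    # Parent class of both iterator classes above
    def __next__(self):
        # logic to return the next batch from the whole data
        ...
```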

In other words, iterator creation is decoupled from the actual data loading.

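• When you call `iter()` on the data loader, it checks whether data should be loaded in the main process (`num_workers == 0`) or by multiple worker processes.

• Based on that, it returns an iterator object of a different class for single-process or multi-process loading:

```python
>>> type(iter(trainloader))
# _SingleProcessDataLoaderIter when num_workers == 0,
# _MultiProcessingDataLoaderIter otherwise
```

• That iterator class defines a `__next__` method, which returns a batch of `batch_size` samples each time we call `next()` on it:

```python
images, labels = next(iter(trainloader))

# equivalent to:
iterator = trainloader.__iter__()     # the single- or multi-process iterator
images, labels = iterator.__next__()  # one batch
```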

Thus, we get images and labels for a single batch.
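The same decoupling can be sketched in a few lines of plain Python. `MiniLoader` and `_MiniLoaderIter` below are hypothetical, single-process simplifications (no shuffling, samplers, workers, or tensor collation); they only illustrate the design of an `__iter__` that hands out a fresh iterator object whose `__next__` returns one batch at a time.

```python
class MiniLoader:
    """Hypothetical, simplified stand-in for DataLoader (single process only)."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def __iter__(self):
        # A fresh iterator object per call, so every pass starts from index 0
        return _MiniLoaderIter(self)


class _MiniLoaderIter:
    def __init__(self, loader):
        self.loader = loader
        self.index = 0

    def __iter__(self):
        return self

    def __next__(self):
        dataset = self.loader.dataset
        if self.index >= len(dataset):  # uses __len__
            raise StopIteration  # no samples left for this pass
        end = min(self.index + self.loader.batch_size, len(dataset))
        samples = [dataset[i] for i in range(self.index, end)]  # uses __getitem__
        self.index = end
        # Simple "collate": [(x0, y0), (x1, y1)] -> ([x0, x1], [y0, y1])
        inputs, labels = zip(*samples)
        return list(inputs), list(labels)


data = [(1, 0), (2, 1), (3, 0)]  # any object with __len__ and __getitem__ works
loader = MiniLoader(data, batch_size=2)

images, labels = next(iter(loader))
print(images, labels)  # [1, 2] [0, 1]
print(list(loader))    # [([1, 2], [0, 1]), ([3], [0])]
print(list(loader))    # same again: each iter() call starts a fresh pass
```

Unlike `ExampleLoader`, looping over `MiniLoader` twice works, because each `iter()` call builds a new iterator; that is one practical benefit of separating the two roles.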

## Conclusion

We saw how PyTorch builds on several native Python features in its API design. I hope this article helped demystify how these concepts work behind the scenes and will help you become a better PyTorch user.

## Connect

If you enjoyed this blog post, feel free to connect with me on Twitter, where I share new blog posts every week.
