# Reimplementing the PyTorch training loop in simple Python

Understand the internal abstractions of a training loop in PyTorch

To master a framework, you need to understand how it is built.
Let’s try to understand and reimplement the internal abstractions of a training loop in PyTorch.
I’ll start with a bare training loop which doesn’t use PyTorch’s dataloaders or optimizers.
Then I’ll reimplement `Dataset` and `DataLoader` in Python.
To update models, PyTorch relies on `nn.Parameter` and `torch.optim`.
I’ll show how to reproduce them in simple Python.

You can follow along by running this notebook on Google Colab.

## Bare Training Loop

The training loop shown below doesn’t use any abstractions over training data or model updates.
I manually index through the list of `X` and `Y` values to iterate over batches of training examples.
I update the model by going through the layers one by one and updating the weights and biases with their loss gradients.

```
for epoch in range(epochs):
    num_batches = (n-1)//bs + 1   # +1 so the final partial batch is included
    for i in range(num_batches):
        # TODO: training data abstraction
        start_idx, end_idx = bs*i, bs*(i+1)
        xb = x_train[start_idx:end_idx]
        yb = y_train[start_idx:end_idx]
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: model update abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()
```
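
For reference, the loop above assumes that a model, loss function, data, and hyperparameters were already defined earlier in the notebook. A minimal sketch of such a setup, with placeholder sizes and values of my own choosing rather than the notebook’s exact ones, might look like this:

```
import torch
import torch.nn as nn
import torch.nn.functional as F

# placeholder data and hyperparameters (assumptions for illustration only)
x_train = torch.randn(10000, 784)
y_train = torch.randint(0, 10, (10000,))
n, bs, lr, epochs = x_train.shape[0], 64, 0.1, 3

# a tiny model that exposes a .layers attribute, as the loop above expects
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(784, 50), nn.ReLU(), nn.Linear(50, 10)])
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

model = Model()
loss_fn = F.cross_entropy
```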

## PyTorch Data Abstractions

Now I’ll rebuild the PyTorch abstractions over data.
First, the `Dataset` is a joint list over `X` and `Y` values.

```
class Dataset():
    def __init__(self, x, y):
        self.x, self.y = x, y
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, i):
        return self.x[i], self.y[i]
```
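
As a quick sanity check (reusing the `x_train` and `y_train` tensors assumed earlier), indexing the dataset returns an `(x, y)` pair, and slicing returns a batch of pairs:

```
train_ds = Dataset(x_train, y_train)
print(len(train_ds))       # number of training examples
x0, y0 = train_ds[0]       # a single (x, y) pair
xb, yb = train_ds[0:bs]    # slicing works too, since tensors support it
```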

The `DataLoader` builds on top of a `Dataset`.
Its core functionality is to iterate over batches of `X` and `Y` items fetched from the `Dataset`.

```
class DataLoader():
    def __init__(self, dataset, batch_size, sampler, collate_fn):
        self.ds, self.bs = dataset, batch_size
        self.sampler, self.collate_fn = sampler, collate_fn
    def __iter__(self):
        for batch_idxs in self.sampler:
            yield self.collate_fn([self.ds[idx] for idx in batch_idxs])
```

You can also customize how the `X` and `Y` items are sampled (`sampler`) and how they are assembled into batches (`collate_fn`).
I’ll reimplement the random sampler that comes built in with PyTorch.

```
class Sampler():
    def __init__(self, data_size, batch_size, shuffle):
        self.n, self.bs, self.shuffle = data_size, batch_size, shuffle
    def __iter__(self):
        sample_idx = torch.randperm(self.n) if self.shuffle else list(range(self.n))
        for i in range(0, self.n, self.bs):
            yield sample_idx[i : i+self.bs]
```
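
Here is a minimal sketch of wiring the three pieces together. The `collate` helper and the variable names are my own assumptions; PyTorch’s built-in default collate function does essentially the same stacking.

```
def collate(batch):
    # turn a list of (x, y) pairs into one batch of xs and one batch of ys
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.stack(ys)

train_samp = Sampler(len(train_ds), bs, shuffle=True)
train_dl = DataLoader(train_ds, bs, sampler=train_samp, collate_fn=collate)

xb, yb = next(iter(train_dl))   # one shuffled batch of bs examples
```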

Now I can rewrite the training loop to use my `DataLoader`.

```
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: optimizer abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()
```

## PyTorch Training Abstractions

Let’s reproduce the way you use PyTorch to update models at every training step.

### nn.Parameter

Instead of manually looping over the layers of a model to get its parameters,
PyTorch maintains a key-value store for all the parameters as they are defined.
This can be achieved with the `__setattr__` method, courtesy of the Python Data Model.
In the code snippet below, I use a `_modules` dictionary to store all the layers as they are defined.

```
class OurModule():
    def __init__(self, x_dim, y_dim, h_dim):
        self._modules = {}
        self.l1 = nn.Linear(x_dim, h_dim)
        self.l2 = nn.Linear(h_dim, y_dim)
    def __setattr__(self, k, v):
        if not k.startswith('_'):
            self._modules[k] = v
        super().__setattr__(k, v)
    def __repr__(self):
        return f"{self._modules}"
    def parameters(self):
        for m in self._modules.values():
            for p in m.parameters():
                yield p
    def __call__(self, x):
        x = F.relu(self.l1(x))
        return self.l2(x)
```
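
A quick sketch of how this class might be used; the layer sizes below are placeholders for something like flattened MNIST images:

```
model = OurModule(x_dim=784, y_dim=10, h_dim=50)
print(model)                                         # the _modules dict: {'l1': Linear(...), 'l2': Linear(...)}
print(sum(p.numel() for p in model.parameters()))    # total number of trainable values
```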

Now let’s update our training loop with the `model.parameters()` functionality.

```
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: optimizer abstraction
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad
                p.grad.zero_()
```

### nn.Sequential

PyTorch lets you quickly construct a model by defining a list of layers.
Here’s how you can reproduce the `nn.Sequential` functionality in plain Python.

```
class Sequential(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = layers
        self.module_list(layers)
    def module_list(self, layers):
        # register each layer so nn.Module tracks its parameters
        for i, l in enumerate(layers):
            self.add_module(name=f'{i}', module=l)
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x
```
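
To check that the layers are registered, you could build a small model with it; the sizes below are again placeholders:

```
model = Sequential(nn.Linear(784, 50), nn.ReLU(), nn.Linear(50, 10))
print(len(list(model.parameters())))   # 4: two weight tensors and two bias tensors
```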

### Optimizer

In PyTorch, you can use different optimizers to update the model. Here’s how you can reimplement a basic SGD optimizer from scratch.

```
class Optimizer():
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
    def step(self):
        with torch.no_grad():
            for p in self.params:
                p -= self.lr * p.grad
    def zero_grad(self):
        for p in self.params:
            p.grad.data.zero_()
```
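
Before the final loop, the optimizer is built from the model’s parameters. A minimal sketch, assuming the `model`, `lr`, `train_dl`, and `loss_fn` from earlier:

```
opt = Optimizer(model.parameters(), lr)

# one manual step as a sanity check
xb, yb = next(iter(train_dl))
loss = loss_fn(model(xb), yb)
loss.backward()
opt.step()
opt.zero_grad()
```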

Let’s finally rewrite the entire training loop using the custom `DataLoader` and optimizer.

```
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # optimizer abstraction
        opt.step()
        opt.zero_grad()
```

There you go. We have rebuilt the PyTorch training loop in plain Python.