Reimplementing the PyTorch training loop in simple Python
Understand the internal abstractions of a training loop in PyTorch
To master a framework, you need to understand how it is built.
Let’s try to understand and reimplement the internal abstractions of a training loop in PyTorch.
I’ll start with a bare training loop that doesn’t use PyTorch’s dataloaders or optimizers. Then I’ll reimplement Dataset and DataLoader in Python. To update models, PyTorch relies on torch.nn.Parameter and torch.optim; I’ll show how to reproduce them in simple Python.
You can follow along by running this notebook on Google Colab.
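The post doesn’t show the surrounding setup, so here is a minimal sketch of what the snippets below assume: some training tensors, a small model that exposes a .layers list, a loss function, and a few hyperparameters. The concrete dimensions, the random data, and the ToyModel class are my own placeholders for illustration, not something from the original notebook.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical sizes and hyperparameters, chosen only for illustration.
n, x_dim, h_dim, y_dim = 1000, 20, 50, 1
bs, lr, epochs = 64, 0.1, 5

x_train = torch.randn(n, x_dim)
y_train = torch.randn(n, y_dim)

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # kept as a plain list so the bare loop below can walk model.layers directly
        self.layers = [nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, y_dim)]
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

model = ToyModel()
loss_fn = F.mse_loss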
Bare Training Loop
The training loop shown below doesn’t use any abstractions over training data or model updates.
I manually index through the X and Y values to iterate over batches of training examples. I update the model by going through the layers one by one and subtracting the learning rate times the loss gradient from each weight and bias.
for epoch in range(epochs):
    num_batches = (n - 1)//bs + 1  # include the final (possibly partial) batch
    for i in range(num_batches):
        # TODO: training data abstraction
        start_idx, end_idx = bs*i, bs*(i+1)
        xb = x_train[start_idx:end_idx]
        yb = y_train[start_idx:end_idx]
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: model update abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()
PyTorch Data Abstractions
Now I’ll rebuild the PyTorch abstractions over data.
First, the Dataset pairs the X and Y values into one indexable collection: indexing it returns a single (x, y) example.
class Dataset():
    def __init__(self, x, y):
        self.x, self.y = x, y
    def __len__(self):
        return self.x.shape[0]
    def __getitem__(self, i):
        return self.x[i], self.y[i]
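As a quick sanity check (using the hypothetical x_train and y_train from the setup sketch above), indexing the dataset returns paired items, and slicing simply forwards to the underlying tensors:

train_ds = Dataset(x_train, y_train)
x0, y0 = train_ds[0]      # a single (x, y) example
xb, yb = train_ds[0:bs]   # a slice of paired examples
assert len(train_ds) == x_train.shape[0]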
The DataLoader builds on top of a Dataset. Its core functionality is to iterate over batches of X and Y items fetched from the Dataset.
class DataLoader():
    def __init__(self, dataset, batch_size, sampler, collate_fn):
        self.ds, self.bs = dataset, batch_size
        self.sampler, self.collate_fn = sampler, collate_fn
    def __iter__(self):
        for batch_idxs in self.sampler:
            yield self.collate_fn([self.ds[idx] for idx in batch_idxs])
You can also customize which indices make up each batch (sampler) and how the fetched (x, y) items are combined into batch tensors (collate_fn).
I’ll reimplement the random sampler that comes built-in with PyTorch.
class Sampler():
    def __init__(self, data_size, batch_size, shuffle):
        self.n, self.bs, self.shuffle = data_size, batch_size, shuffle
    def __iter__(self):
        sample_idx = torch.randperm(self.n) if self.shuffle else list(range(self.n))
        for i in range(0, self.n, self.bs):
            yield sample_idx[i : i+self.bs]
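The training loop below iterates over a train_dl object, which the post doesn’t construct explicitly. Here is one plausible way to wire the pieces together, with a collate function (my own helper, not from the post) that stacks a list of (x, y) pairs into batch tensors:

def collate(batch):
    # batch is a list of (x, y) pairs fetched from the Dataset
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.stack(ys)

train_ds = Dataset(x_train, y_train)
train_sampler = Sampler(len(train_ds), bs, shuffle=True)
train_dl = DataLoader(train_ds, bs, train_sampler, collate)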
Now I can rewrite the training loop to use my DataLoader.
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: optimizer abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()
PyTorch Training Abstractions
Let’s reproduce the way you use PyTorch to update models at every training step.
nn.Parameter
Instead of manually looping over the layers of a model to get its parameters, PyTorch maintains a key-value store for all the parameters as they are defined. This can be achieved by using the __setattr__ method, courtesy of the Python data model. In the code snippet below, I use a _modules dictionary to store all the layers as they are defined.
class OurModule():
    def __init__(self, x_dim, y_dim, h_dim):
        self._modules = {}
        self.l1 = nn.Linear(x_dim, h_dim)
        self.l2 = nn.Linear(h_dim, y_dim)
    def __setattr__(self, k, v):
        if not k.startswith('_'):
            self._modules[k] = v
        super().__setattr__(k, v)
    def __repr__(self):
        return f"{self._modules}"
    def parameters(self):
        for m in self._modules.values():
            for p in m.parameters():
                yield p
    def __call__(self, x):
        x = F.relu(self.l1(x))
        return self.l2(x)
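A quick check that the registration works (the dimensions are again the hypothetical ones from the setup sketch): both linear layers end up in _modules, so all their weights and biases surface through parameters().

model = OurModule(x_dim, y_dim, h_dim)
print(model)                          # repr shows the _modules dict with l1 and l2
print(len(list(model.parameters())))  # 4: two weight matrices and two bias vectors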
Now let’s update our training loop with the model.parameters() functionality.
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # TODO: optimizer abstraction
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad
                p.grad.zero_()
nn.Sequential
PyTorch lets you quickly construct a model by defining a list of layers. Here’s how you can reproduce the nn.Sequential functionality in plain Python.
class Sequential(nn.Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = layers
        self.module_list(layers)
    def module_list(self, layers):
        # register each layer so nn.Module tracks its parameters
        for i, l in enumerate(layers):
            self.add_module(name=f'{i}', module=l)
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x
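Used like the real thing (again with the hypothetical sizes from the setup sketch); because it subclasses nn.Module, add_module has already registered every layer, so model.parameters() keeps working with the optimizer below:

model = Sequential(nn.Linear(x_dim, h_dim),
                   nn.ReLU(),
                   nn.Linear(h_dim, y_dim))
yb_pred = model(xb)   # nn.Module.__call__ dispatches to forward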
Optimizer
In PyTorch, you can use different optimizers from torch.optim to update the model. Here’s how you can reimplement basic SGD from scratch.
class Optimizer():
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
    def step(self):
        with torch.no_grad():
            for p in self.params:
                p -= self.lr * p.grad
    def zero_grad(self):
        for p in self.params:
            p.grad.zero_()
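The final loop below calls opt.step() and opt.zero_grad(), so the optimizer just needs to be constructed from the model’s parameters first, for example:

# learning rate lr comes from the hypothetical setup sketch above
opt = Optimizer(model.parameters(), lr)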
Let’s finally rewrite the entire training loop using the custom DataLoader and optimizer.
for epoch in range(epochs):
    # dataloader abstraction
    for xb, yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()
        # optimizer abstraction
        opt.step()
        opt.zero_grad()
There you go. We have rebuilt the PyTorch training loop in plain Python.