Reimplementing the PyTorch training loop in simple Python

Understand the internal abstractions of a training loop in PyTorch

To master a framework, you need to understand how it is built. Let’s try to understand and reimplement the internal abstractions of a training loop in PyTorch. I’ll start with a bare training loop that doesn’t use PyTorch’s dataloaders or optimizers. Then I’ll reimplement Dataset and DataLoader in Python. To update models, PyTorch relies on torch.nn.Parameter and torch.optim; I’ll show how to reproduce them in simple Python.

You can follow along by running this notebook on Google Colab.

Bare Training Loop

The training loop shown below doesn’t use any abstractions over training data or model updates. I manually index into the X and Y training tensors to iterate over batches of examples, and I update the model by walking through its layers one by one, adjusting each weight and bias by its gradient with respect to the loss.

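The loop assumes a model that exposes its layers through a .layers attribute, plus some training data and a few hyperparameters. Here is a minimal setup sketch for following along; the shapes, values, and the Model container below are placeholders I chose for illustration, not the notebook’s exact definitions.

import torch
from torch import nn
import torch.nn.functional as F

n, x_dim, h_dim, y_dim = 10_000, 784, 50, 10   # hypothetical dataset and layer sizes
bs, lr, epochs = 64, 0.5, 3                    # hypothetical hyperparameters

x_train = torch.randn(n, x_dim)
y_train = torch.randint(0, y_dim, (n,))

class Model(nn.Module):
    # plain container exposing .layers, matching how the loop below walks the model
    def __init__(self, layers):
        super().__init__()
        self.layers = nn.ModuleList(layers)
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

model = Model([nn.Linear(x_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, y_dim)])
loss_fn = F.cross_entropy
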
for epoch in range(epochs):
    num_batches = (n-1)//bs + 1   # number of batches needed to cover all n examples
    for i in range(num_batches):

        # TODO: training data abstraction
        start_idx, end_idx = bs*i, bs*(i+1)
        xb = x_train[start_idx:end_idx]
        yb = y_train[start_idx:end_idx]

        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()

        # TODO: model update abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()       

PyTorch Data Abstractions

Now I’ll rebuild the PyTorch abstractions over data. First, the Dataset simply pairs the X and Y values and exposes indexing and length.

class Dataset():
    
    def __init__(self, x, y):
        self.x, self.y = x, y
    
    def __len__(self):
        return self.x.shape[0]
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]

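A quick usage check of the Dataset, reusing the placeholder x_train and y_train tensors from the setup sketch:

train_ds = Dataset(x_train, y_train)
xb, yb = train_ds[0:bs]     # slicing returns a whole batch of X and Y values
assert len(train_ds) == x_train.shape[0]
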
The DataLoader is built on top of a Dataset. Its core job is to iterate over batches of X and Y items fetched from the Dataset.

class DataLoader():
    
    def __init__(self, dataset, batch_size, sampler, collate_fn):
        self.ds, self.bs = dataset, batch_size        
        self.sampler, self.collate_fn = sampler, collate_fn
    
    def __iter__(self):
        for batch_idxs in self.sampler:
            yield self.collate_fn([self.ds[idx] for idx in batch_idxs])

You can also customize how item indices are drawn to form batches (the sampler) and how the fetched items are combined into a single batch (the collate_fn). I’ll reimplement the random sampler that comes built in with PyTorch, and sketch a simple collate function right after it.

class Sampler():
    
    def __init__(self, data_size, batch_size, shuffle):
        self.n, self.bs, self.shuffle = data_size, batch_size, shuffle
        
    def __iter__(self):
        sample_idx = torch.randperm(self.n) if self.shuffle else list(range(self.n))
        for i in range(0, self.n, self.bs):
            yield sample_idx[i : i+self.bs]

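The collate function decides how the individual (x, y) pairs fetched from the Dataset are stacked into a single batch. Here is a minimal sketch, along with the wiring of the sampler and DataLoader; the variable names are my own:

def collate(batch):
    # batch is a list of (x, y) pairs; stack them into one X tensor and one Y tensor
    xs, ys = zip(*batch)
    return torch.stack(xs), torch.stack(ys)

train_sampler = Sampler(len(train_ds), bs, shuffle=True)
train_dl = DataLoader(train_ds, bs, sampler=train_sampler, collate_fn=collate)
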
Now I can rewrite the training loop to use my DataLoader.

for epoch in range(epochs):
    # dataloader abstraction
    for xb,yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()

        # TODO: optimizer abstraction
        with torch.no_grad():
            for layer in model.layers:
                if hasattr(layer, "weight"):
                    layer.weight -= lr * layer.weight.grad
                    layer.weight.grad.zero_()
                if hasattr(layer, "bias"):
                    layer.bias -= lr * layer.bias.grad
                    layer.bias.grad.zero_()

PyTorch Training Abstractions

Let’s reproduce the way you use PyTorch to update models at every training step.

nn.Parameter

Instead of requiring you to loop manually over the layers of a model to collect its parameters, PyTorch keeps an internal key-value store of layers and parameters, registering them as they are assigned. This can be reproduced with the __setattr__ hook from the Python data model. In the code snippet below, I use a _modules dictionary to record every layer as it is defined, and parameters() then walks over those layers.

class OurModule():
    
    def __init__(self, x_dim, y_dim, h_dim):
        self._modules = {}
        self.l1 = nn.Linear(x_dim, h_dim)
        self.l2 = nn.Linear(h_dim, y_dim)
    
    def __setattr__(self, k, v):
        if not k.startswith('_'):
            self._modules[k] = v
        super().__setattr__(k, v)
    
    def __repr__(self):
        return f"{self._modules}"
    
    def parameters(self):
        for m in self._modules.values():
            for p in m.parameters():
                yield p
    
    def __call__(self, x):
        x = F.relu(self.l1(x))
        return self.l2(x)

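With this in place, the model can be built and its parameters inspected; a quick sketch reusing the placeholder dimensions from earlier:

model = OurModule(x_dim, y_dim, h_dim)
print(model)                                        # prints the _modules dictionary
print(sum(p.numel() for p in model.parameters()))   # total number of trainable values
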
Now let’s update our training loop with the model.parameters() functionality.

for epoch in range(epochs):
    # dataloader abstraction
    for xb,yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()

        # TODO: optimizer abstraction
        with torch.no_grad():
            for p in model.parameters():
                p -= lr * p.grad
                p.grad.zero_()

nn.Sequential

PyTorch lets you quickly construct a model by defining a list of layers. Here’s how you can reproduce the nn.Sequential functionality in plain Python.

class Sequential(nn.Module):
    
    def __init__(self, *layers):
        super().__init__()
        self.layers = self.module_list(layers)
    
    def module_list(self, layers):
        # register each layer on the parent module so parameters() can find it
        for i, l in enumerate(layers):
            self.add_module(name = f'{i}',
                            module = l)
        return layers
    
    def forward(self, x):
        for l in self.layers:
            x = l(x)
        return x

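With this class, the two-layer model from before can be expressed as a flat list of layers; a sketch with the same placeholder dimensions:

model = Sequential(nn.Linear(x_dim, h_dim),
                   nn.ReLU(),
                   nn.Linear(h_dim, y_dim))
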
Optimizer

In PyTorch, you can use different optimization functions to update the model. Here’s how you can reimplement a basic SGD optimizer from scratch.

class Optimizer():
    
    def __init__(self, params, lr):
        self.params, self.lr = list(params), lr
    
    def step(self):
        # take one SGD step: move each parameter against its gradient
        with torch.no_grad():
            for p in self.params:
                p -= self.lr * p.grad
    
    def zero_grad(self):
        for p in self.params:
            p.grad.zero_()

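The opt object used in the final loop below is just an instance of this class, built from the model’s parameters; for example:

opt = Optimizer(model.parameters(), lr)
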
Let’s finally rewrite the entire training loop using the custom DataLoader and Optimizer.

for epoch in range(epochs):
    # dataloader abstraction
    for xb,yb in train_dl:
        yb_pred = model(xb)
        loss = loss_fn(yb_pred, yb)
        loss.backward()

        # optimizer abstraction
        opt.step()
        opt.zero_grad()

There you go. We have rebuilt the PyTorch training loop in plain Python.
