
The PyTorch add_module() function


I have been building some bespoke PyTorch models and have just been stung by a bug: it turns out that using the add_module() method is sometimes critical to making a PyTorch model work. Without it, the program may crash outright, or it may appear to run while giving completely meaningless results.

Though there do not seem to be any hints about this in the documentation, PyTorch determines the layers (or other Modules) used within a particular Module by looking at the type of object assigned to each attribute of that Module. If an attribute does not hold a Module, PyTorch does not register it as a submodule, so its parameters are invisible to parameters() and will never be trained.

Here is an example. We take the Quickstart from the PyTorch tutorials webpage. The neural network is defined in it as follows:

from torch import nn

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

Note that the layers, nn.Flatten() and nn.Sequential(...), are stored as attributes of the NeuralNetwork object by writing self.flatten = ... and so on.
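
To see how this works, here is a minimal check (assuming torch is installed and the class above has been defined): because each layer was assigned to an attribute, PyTorch registers both as submodules, and their parameters are visible to an optimiser.

model = NeuralNetwork()

# Both layers were registered simply by being assigned to attributes.
print([name for name, _ in model.named_children()])
# ['flatten', 'linear_relu_stack']

# Their parameters are therefore found by model.parameters().
print(sum(p.numel() for p in model.parameters()))
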

Let us now change the code to store these layers in a list instead:

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = [flatten, linear_relu_stack]

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits

The forward() method is also modified to use the appropriate elements of the list of layers. But now PyTorch ignores the self.layers attribute, as a plain Python list is not a Module: the layers are never registered as submodules, and the rest of the tutorial code breaks quite badly.
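
A quick check (again assuming the class above) shows just how badly: the list is invisible to PyTorch, so the model reports no submodules and no parameters, and an optimiser built from model.parameters() has nothing to update.

model = NeuralNetwork()

# The plain list is not registered, so the model appears empty:
print(list(model.named_children()))    # []
print(len(list(model.parameters())))   # 0
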

The simplest way to fix this, while keeping the layers as a list, is to inform PyTorch about the existence of these layers using the add_module() method, as follows:

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = [flatten, linear_relu_stack]
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer_{i}", layer)

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits

The first argument of add_module() is a name that PyTorch uses to refer to the layer, for example when printing the model; the second is the layer itself. The name can also be used to access the layer as an attribute of the Module object, so the names must be unique within the Module, and it helps if they are valid Python identifiers (though if they are not, the layers can still be accessed using getattr()).
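
As a small illustration (assuming the class just above), the registered names show up when the model is printed, and each one becomes an attribute of the model:

model = NeuralNetwork()

# The names passed to add_module() appear in the printed model...
print(model)    # lists submodules layer_0 and layer_1

# ...and each registered layer is accessible as an attribute.
print(model.layer_0)                 # the nn.Flatten layer
print(getattr(model, "layer_1"))     # the nn.Sequential stack
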

And with that addition, the code once again works.

(If you are wondering why we would store the layers in a list in the first place, I had a use case where the network was constructed with a variable number of layers passed to __init__() as a list.)
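
To give a flavour of that, here is a hypothetical sketch (the class name and constructor argument are illustrative, not from my actual code): the layer sizes arrive as a list, and each constructed layer is registered with add_module() so that its parameters are trained.

from torch import nn

class VariableDepthNetwork(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        # sizes might be, say, [28*28, 512, 512, 10]; one Linear + ReLU
        # pair is built for each consecutive pair of sizes.
        self.layers = []
        for i, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
            layer = nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())
            self.layers.append(layer)
            # Register the layer so PyTorch can see its parameters.
            self.add_module(f"layer_{i}", layer)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

A three-layer model would then be constructed as VariableDepthNetwork([28*28, 512, 512, 10]).
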

Edits 27 April 2021

A colleague has just pointed out the nn.ModuleList class to me. So this problem could also be solved in a simpler way as follows:

class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = nn.ModuleList([flatten, linear_relu_stack])

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits
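
The same quick check as before (assuming this final version of the class) confirms that nn.ModuleList registers its contents, so the parameters are visible once more while the list-style indexing in forward() still works:

model = NeuralNetwork()

# The ModuleList itself is a registered submodule...
print([name for name, _ in model.named_children()])   # ['layers']

# ...so the parameters of the three Linear layers are back.
print(len(list(model.parameters())))   # 6 tensors: a weight and a bias per Linear layer
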