The PyTorch add_module() function
I have been building some bespoke PyTorch models and have just been stung by a bug: it turns out that using the add_module() method is sometimes critical to making a PyTorch model work. Without it, the program may simply crash, or it may appear to work while giving completely meaningless results.
Though there do not seem to be any hints about this in the documentation, it seems that PyTorch determines the layers, or other Modules, used in a particular Module by looking at the type of object stored in each member of that Module. If an object stored in a member is not a Module, PyTorch does not recognise it and will not backpropagate through it.
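As a minimal sketch of this behaviour (the class and attribute names here are purely illustrative), a Module assigned to an attribute is registered, while the same kind of Module hidden inside a plain Python list is not:

from torch import nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.registered = nn.Linear(4, 2)   # a Module: PyTorch registers it as a child
        self.hidden = [nn.Linear(4, 2)]      # a plain list: invisible to PyTorch

tiny = Tiny()
print([name for name, _ in tiny.named_children()])  # ['registered'] -- 'hidden' is missing
print(len(list(tiny.parameters())))                  # 2 (just the weight and bias of 'registered')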
Here is an example, taken from the Quickstart on the PyTorch tutorials webpage, where the neural network is defined as follows:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
Note that the layers, nn.Flatten() and nn.Sequential(...), are stored as members of the NeuralNetwork object, by writing self.flatten = ... and so on.
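One quick way to see that both layers really are registered (assuming the class above has been defined) is to list the model's children and count its parameters:

model = NeuralNetwork()
print([name for name, _ in model.named_children()])  # ['flatten', 'linear_relu_stack']
print(len(list(model.parameters())))                  # 6: a weight and a bias for each Linear layer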
Let us now change the code to store these layers in a list instead:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = [flatten, linear_relu_stack]

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits
The forward() method is also modified to use the appropriate element of the list of layers. But now PyTorch ignores the self.layers member, as it is not a Module, and the code breaks quite badly.
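A sketch of the failure, using the list-based class above: none of the layers' parameters are registered with the model, so an optimiser has nothing to optimise.

model = NeuralNetwork()              # the list-based version
print(list(model.parameters()))      # [] -- no parameters are registered
# torch.optim.SGD(model.parameters(), lr=1e-3) would typically fail with
# "optimizer got an empty parameter list", and model.to(device) would not move the layers.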
The simplest way to fix this, while keeping the layers as a list, is to inform PyTorch about the existence of these layers using the add_module() method, as follows:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = [flatten, linear_relu_stack]
        for i, layer in enumerate(self.layers):
            self.add_module(f"layer_{i}", layer)

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits
The first parameter of the add_module() method is a name that PyTorch will use to refer to the layer when printing the neural network model, while the second is the layer itself. The name can also be used to refer to the layer as an attribute of the Module object, so it is presumably important that the names are unique within the Module, and potentially helpful if they are valid Python identifiers (though if they are not, the layers can still be accessed using getattr()).
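For example, with the names used above, each registered layer can be reached either directly as an attribute or via getattr():

model = NeuralNetwork()
print(model.layer_0)                 # the nn.Flatten module, as an attribute
print(getattr(model, "layer_1"))     # the nn.Sequential stack
# A name that is not a valid identifier, say "layer 0", would only be reachable via getattr().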
And with that addition, the code once again works.
(If you are wondering why we would store the layers in a list in the first place, I had a use case where the network was constructed with a variable number of layers passed to __init__() as a list.)
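A rough sketch of that kind of use case (the class name and layer sizes here are invented for illustration), using add_module() to register a variable number of layers:

import torch
from torch import nn

class VariableNetwork(nn.Module):
    def __init__(self, sizes):
        super().__init__()
        self.layers = []
        # one Linear layer between each pair of consecutive sizes
        for i, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            layer = nn.Linear(n_in, n_out)
            self.layers.append(layer)
            self.add_module(f"layer_{i}", layer)   # register each layer with the Module

    def forward(self, x):
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
        return self.layers[-1](x)

model = VariableNetwork([28*28, 512, 512, 10])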
Edits 27 April 2021
A colleague has just pointed out the nn.ModuleList class to me. So this problem could also be solved in a simpler way, as follows:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        flatten = nn.Flatten()
        linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            nn.ReLU()
        )
        self.layers = nn.ModuleList([flatten, linear_relu_stack])

    def forward(self, x):
        x = self.layers[0](x)
        logits = self.layers[1](x)
        return logits
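nn.ModuleList registers its contents with the parent Module, so the check from before gives the same answer as for the original Quickstart version:

model = NeuralNetwork()
print(len(list(model.parameters())))   # 6 again, so an optimiser now sees all the weights

ModuleList also behaves like an ordinary list for indexing and iteration, so forward() does not need to change.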