Update Weights of a Neural Network Model in PyTorch

When you use a pre-trained model, you may not always want the default weights that come with it. Or you may build your own model and want it to reuse the weights of another model. In this article, we will explain how to update a model's weights in PyTorch.

Photo by Jesper Aggergaard on Unsplash

Let’s walk through it with the “squeezenet1_0” model. SqueezeNet is a DNN model that gives decent results with far fewer parameters than other high-performing DNNs, and it is very small in size. That makes it convenient to work with when you want quick results at a reasonably high accuracy, whether to get a sense of your data or to test new tricks.

We will start with the imports for this task and then create a model:

If we set pretrained to False, PyTorch initializes the weights randomly from scratch, using one of the initialization functions (normal_, kaiming_uniform_, constant_) depending on the module and on whether the parameter is a weight or a bias.
Since the weights are assigned randomly, we get different initial weight values each time we run the code.
If we set pretrained to True, on the other hand, PyTorch loads weights that have already been trained. Our model starts from there and continues fine-tuning by updating those weights with our dataset.

You can easily see your model's architecture by evaluating the model variable in an interactive session (or by printing it). Check the end of the article to see squeezenet1_0's.

model_wu

From this output, you can see that SqueezeNet has two main parts:
1) features
2) classifier

If you would like to see the details of each part's modules, you can inspect them as below:

We can use the torch.distributions package to generate samples from various probability distributions. Here, we will give an example with a Normal (Gaussian) distribution:
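A minimal sketch of creating a Normal distribution and drawing from it. The parameters match the mean and standard deviation used in the next step; the names normal_dist and draws are illustrative.

```python
import torch
from torch.distributions import Normal

# Normal distribution with mean 0 and standard deviation 0.05.
normal_dist = Normal(loc=0.0, scale=0.05)

# sample() takes the desired shape of the output tensor.
draws = normal_dist.sample((5,))

print(draws)
print(normal_dist.mean, normal_dist.stddev)
```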

Using model_wu.features[0].weight, we can access, and therefore reassign, the weights of the first Conv2d module in features.

Assume that we would like to assign values drawn from a normal distribution with mean 0 and standard deviation 0.05. We can then generate a sample with the same shape as this module's weight tensor:

We can assign the new values to the module now:

model_wu.features[0].weight = torch.nn.Parameter(sample_norm)

We can also list the modules programmatically:

Not all modules have weights. We can list the modules that do, as below, and update whichever one we would like.

Also, if we would like to use specific initializations for specific modules, we can use the code below:

Below we can see the whole structure of the model:

