In some cases, you might not want to use the default weights that come with a pre-trained model. Or you may build your own model but want to initialize it with weights taken from another model. In this article, we will explain how to update a model's weights in PyTorch.
Let’s explain it using the “squeezenet1_0” model. SqueezeNet is a DNN that achieves decent results with far fewer parameters than other high-performing DNNs, and it is very small in size. That makes it convenient when you want quick results with accuracy high enough to get a feel for your data, test new tricks, and so on.
We will start with the imports for this task and create a model:
If we set pretrained to False, PyTorch will initialize the weights from scratch “randomly”, using one of its initialization functions (e.g. normal_, kaiming_uniform_, constant_) depending on the module and on whether it has a bias.
Since the weights are assigned randomly, each time we run our code we will get different initial weight values.
If we set pretrained to True, on the other hand, PyTorch will use weights that have already been trained; our model takes it from there and continues fine-tuning by updating those weights on our own dataset.
How to Display the Weights of a Specific Module?
You can easily see your model's architecture by typing the model's name. Check the end of the article to see squeezenet1_0's.
From this, you can see that SqueezeNet has two main parts: features and classifier.
If you would like to see the details of individual modules, you can inspect them as below:
Generate New Weight Values Using Gaussian Distribution
We can use the torch.distributions package to generate samples from various probability distributions. Here, we will give an example of a Normal (Gaussian) distribution:
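A small example of drawing a tensor of samples with `torch.distributions.Normal` (the shape `(2, 3)` here is just an illustrative choice):

```python
import torch
from torch.distributions import Normal

# A Normal distribution with mean 0 and standard deviation 0.05.
dist = Normal(loc=0.0, scale=0.05)
sample = dist.sample((2, 3))  # draw a 2x3 tensor of values
print(sample.shape)  # torch.Size([2, 3])
```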
Assign New Weights to the Module
Through model_wu.features[0].weight, we can access the weights of the first Conv2d module in features.
Assume that we would like to assign values drawn from a normal distribution with mean 0 and standard deviation 0.05. We can then generate a sample with the same size as this module's weight tensor:
We can assign the new values to the module now:
model_wu.features[0].weight = torch.nn.Parameter(sample_norm)
We can also see the modules by:
Not all modules have weights. We can list the modules that do, as below, and update whichever one we would like:
Also, if we would like to use specific initializations for specific modules, we can use the below code:
Below we can see the whole structure of the model: