Everything You Need To Know About Saving Weights In PyTorch

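When working with huggingface transformers I often need to save a model, or more precisely the model's parameters. I came across a blog post on Medium that cleared up a lot of my confusion; below are its key takeaways: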

  1. Applying named_parameters() to an nn.Module object, e.g. model, model.layer2 or model.fc, yields (name, parameter) pairs for every learnable parameter of that module, including those of its submodules. These parameters are nn.Parameter objects (a subclass of torch.Tensor), so they have attributes such as shape and requires_grad.

  2. The requires_grad attribute of an nn.Parameter object (a learnable parameter) decides whether that parameter is trained or frozen: with requires_grad=False no gradient is computed for it, so the optimizer does not update it.

  3. Applying named_children() to any nn.Module object yields its immediate (direct) children, which are also nn.Module objects, together with their names; grandchildren are not included.
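
The three points above can be sketched with a small model. This is an illustration under assumptions, not code from the original post: `ToyNet` and its layer names `layer1`, `layer2` and `fc` are hypothetical, chosen to match the names used in the list, and freezing everything except `fc` is just one example of using `requires_grad`.

```python
import torch
import torch.nn as nn


class ToyNet(nn.Module):
    # Hypothetical model; layer names mirror model.layer2 / model.fc from the text.
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 8)
        self.layer2 = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc(self.layer2(torch.relu(self.layer1(x))))


model = ToyNet()

# 1. named_parameters() yields (name, nn.Parameter) pairs, recursing into submodules.
for name, param in model.named_parameters():
    print(name, tuple(param.shape), param.requires_grad)

# On a submodule, the names are relative to that submodule.
print([name for name, _ in model.layer2.named_parameters()])  # ['0.weight', '0.bias']

# 2. Freeze everything except the final classifier by flipping requires_grad.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("fc.")

loss = model(torch.randn(3, 4)).sum()
loss.backward()
print(model.layer1.weight.grad is None)  # True: frozen, no gradient computed
print(model.fc.weight.grad is not None)  # True: still trainable

# 3. named_children() yields only the immediate submodules, not grandchildren.
child_names = [name for name, _ in model.named_children()]
print(child_names)  # ['layer1', 'layer2', 'fc']
```

Note that the two modules inside `layer2` (`0` and `1`) do not appear in `child_names`; you would reach them with `model.layer2.named_children()` or with `named_modules()`.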

  1. The state_dict() of any nn.Module object, e.g. model, model.layer2 or model.fc, is simply a Python OrderedDict that maps each parameter name to its tensor (a torch.Tensor object). Its keys are the parameter names, which can be used to look up the corresponding tensors. Registered buffers, such as BatchNorm's running statistics, are stored in it as well.

  2. Saving an nn.Module object's state_dict saves only the values of that object's parameters, not the model architecture. It does not store the requires_grad flags either. Therefore, before loading a state_dict, you must first create an instance of the model with the same architecture and then call load_state_dict() on it.

  3. The entire model (the nn.Module object) can also be saved, which stores the architecture together with its weights. Because the nn.Module object itself is saved, the requires_grad attributes are preserved as well, and you don't need to instantiate the model before loading the file. Note, however, that PyTorch pickles the object, so the file records a reference to the model's class rather than its source code: the class definition must still be importable when the file is loaded.

  4. In short, saving the state_dict saves only the model's weights and not the requires_grad flags, whereas saving the entire model saves the architecture, its weights, and the requires_grad attributes of all its parameters.

  5. Either the state_dict or the entire model can be saved and later loaded for inference (call model.eval() before running inference).
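
A minimal sketch of the two saving styles, assuming a hypothetical `make_model()` helper and throwaway file names (`weights.pt`, `model.pt`). One parameter is frozen on purpose so you can see which approach keeps the `requires_grad` flag. The `weights_only` argument of `torch.load` is available in PyTorch 1.13 and later, and since 2.6 it defaults to `True`, which refuses to unpickle arbitrary modules; passing `weights_only=False` for the full-model file assumes such a version and a file you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical architecture; any nn.Module behaves the same way.
    return nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))


model = make_model()
model[0].weight.requires_grad = False  # freeze one parameter on purpose

# state_dict(): an OrderedDict mapping parameter/buffer names to tensors.
sd = model.state_dict()
print(list(sd.keys()))  # BatchNorm buffers such as '1.running_mean' are included

with tempfile.TemporaryDirectory() as save_dir:
    # Option A: save only the weights. No architecture, no requires_grad flags.
    weights_path = os.path.join(save_dir, "weights.pt")
    torch.save(model.state_dict(), weights_path)

    restored = make_model()  # the architecture has to be rebuilt first
    restored.load_state_dict(torch.load(weights_path))
    restored.eval()  # inference mode for BatchNorm/Dropout
    print(restored[0].weight.requires_grad)  # True: the frozen flag was not saved

    # Option B: pickle the whole nn.Module (architecture + weights + flags).
    model_path = os.path.join(save_dir, "model.pt")
    torch.save(model, model_path)

    # Only unpickle files you trust; the model's classes must be importable here.
    whole = torch.load(model_path, weights_only=False)
    whole.eval()
    print(whole[0].weight.requires_grad)  # False: the flag travelled with the module
```

Option A is the approach the PyTorch documentation recommends, since the file contains only tensors and does not depend on how the code was laid out when it was saved.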


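In huggingface transformers, model.save_pretrained() saves a model using the first approach, the state_dict one: it writes the weights from the state_dict together with a config.json describing the architecture, which from_pretrained() uses to rebuild the model before loading the weights.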
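
A minimal sketch of that round trip, assuming `transformers` and `torch` are installed. A tiny, randomly initialised BERT built from a hand-written `BertConfig` (the sizes are arbitrary) is used so nothing has to be downloaded. The exact weight file name depends on the library version (`model.safetensors` in recent releases, `pytorch_model.bin` in older ones).

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertModel

# Tiny, randomly initialised model so the example runs offline (sizes are arbitrary).
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)

with tempfile.TemporaryDirectory() as save_dir:
    # Writes config.json (the architecture) plus the weights taken from the state_dict.
    model.save_pretrained(save_dir)
    saved_files = sorted(os.listdir(save_dir))
    print(saved_files)

    # from_pretrained() rebuilds the model from config.json, then loads the weights.
    reloaded = BertModel.from_pretrained(save_dir)

# Every tensor in the original state_dict should match the reloaded one.
same = all(
    torch.equal(tensor, reloaded.state_dict()[name])
    for name, tensor in model.state_dict().items()
)
print(same)
```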


  1. Everything You Need To Know About Saving Weights In PyTorch

Author: CarlYoung
Copyright notice: Unless otherwise stated, all articles on this blog are licensed under CC BY 4.0. Please credit the source, CarlYoung, when reposting.