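When working with huggingface transformers, I often need to save a model, or more precisely, the model's parameters. I read a blog post on Medium that I found very useful and that cleared up a lot of my confusion. Here are its key points: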
Applying named_parameters() to an nn.Module object (e.g. model, model.layer2, or model.fc) returns an iterator over the names of all its parameters together with the parameters themselves. These parameters are nn.Parameter objects (a subclass of torch.Tensor), so they have shape and requires_grad attributes. The requires_grad attribute of an nn.Parameter (a learnable parameter) decides whether that parameter is trained or frozen.

Applying named_children() to any nn.Module object returns an iterator over all its immediate children (which are also nn.Module objects).
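To make the difference between the two iterators concrete, here is a minimal sketch. TinyNet and its attribute names (layer1, layer2, fc) are hypothetical and chosen only to mirror the names used in the text above; they are not taken from the original blog post.

```python
import torch
from torch import nn


# Hypothetical toy model; the layer names are illustrative only.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 8)
        self.layer2 = nn.Linear(8, 8)
        self.fc = nn.Linear(8, 2)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = torch.relu(self.layer2(x))
        return self.fc(x)


model = TinyNet()

# named_parameters() yields (name, nn.Parameter) pairs and recurses into submodules.
param_names = [name for name, _ in model.named_parameters()]

# named_children() yields only the immediate submodules.
child_names = [name for name, _ in model.named_children()]

# Freeze layer1 by switching off requires_grad on each of its parameters.
for param in model.layer1.parameters():
    param.requires_grad = False

# Only parameters with requires_grad=True would receive gradients during training.
trainable = [name for name, p in model.named_parameters() if p.requires_grad]
```

After running this, trainable no longer contains the layer1 entries, which is the usual way to freeze part of a network before fine-tuning.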
The state_dict() of any nn.Module object (e.g. model, model.layer2, or model.fc) is simply a Python ordered dictionary that maps each parameter name to its tensor (a torch.Tensor object); persistent buffers, such as BatchNorm running statistics, are included as well. The keys of this dictionary are the parameter names, which can be used to access the respective tensors. In state_dict(), the parameters change from the richer nn.Parameter objects into plain torch.Tensor objects that carry only the values.

Saving an nn.Module object's state_dict saves only the weights of that object, not the model architecture. Nor does it record the requires_grad attribute of the weights. So before loading a state_dict, one must define the model first.
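The state_dict round trip can be sketched as follows. build_model() is a hypothetical helper standing in for "define the model first", and the file name weights.pt is arbitrary; neither comes from the original post.

```python
import os
import tempfile
from collections import OrderedDict

import torch
from torch import nn


def build_model():
    # Hypothetical architecture; it has to be rebuilt in code before loading weights.
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


model = build_model()
model[0].weight.requires_grad = False  # freeze one parameter in the original model

state = model.state_dict()
is_ordered_dict = isinstance(state, OrderedDict)
keys = list(state.keys())

# Entries are detached tensors, so they carry values but no requires_grad flag.
entry_requires_grad = state["2.weight"].requires_grad

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "weights.pt")
    torch.save(state, path)    # only names and tensors are written
    loaded = torch.load(path)

new_model = build_model()      # the architecture is defined again first
new_model.load_state_dict(loaded)

weights_match = torch.equal(new_model[0].weight, model[0].weight)
# The frozen flag was not stored, so the fresh model falls back to the default.
flag_after_reload = new_model[0].weight.requires_grad
```

Note that the freeze applied to the original model does not survive the round trip: if frozen layers matter, requires_grad has to be set again after load_state_dict().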
The entire model (the nn.Module object) can also be saved, which includes the model architecture as well as its weights. Because this pickles the nn.Module object itself, the requires_grad attributes are saved too. The architecture does not need to be rebuilt by hand before loading, but note that pickle stores only a reference to the model's class, so the class definition must still be importable when the file is loaded.
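A minimal sketch of saving the whole module, for comparison. It uses nn.Sequential so that every class involved is importable from torch itself; a custom model class would need to be importable at load time, since pickling stores a reference to the class rather than its code. The weights_only=False argument is an assumption about the installed version: recent PyTorch releases default torch.load to weights_only=True, which refuses to unpickle full modules, and that flag should only be turned off for files you trust.

```python
import os
import tempfile

import torch
from torch import nn

# Built only from torch's own classes so that unpickling can locate them.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))
model[0].weight.requires_grad = False  # freeze one parameter

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")  # arbitrary file name
    torch.save(model, path)                   # pickles the whole nn.Module
    # Assumption: weights_only=False is required on recent PyTorch versions.
    restored = torch.load(path, weights_only=False)

restored_type = type(restored).__name__
frozen_flag = restored[0].weight.requires_grad     # restored from the pickle
trainable_flag = restored[2].weight.requires_grad
weights_match = torch.equal(restored[2].weight, model[2].weight)
```

Here the module comes back without being rebuilt in code, and the frozen parameter stays frozen.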
Saving the state_dict saves only the weights of the model; it does not save the requires_grad flags. Saving the entire model saves the model architecture, its weights, and the requires_grad attributes of all its parameters.
Either a saved state_dict or a saved entire model can later be loaded to run inference.
Implementation
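In huggingface transformers, calling model.save_pretrained() is all it takes to save a model. It uses the first approach: the weights are written as a state_dict, alongside the model's configuration file.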
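As a rough illustration, the sketch below saves and reloads a small, randomly initialised BERT model. The tiny BertConfig values are assumptions chosen only to keep the example fast and offline; a real workflow would usually start from a pretrained checkpoint. The name of the weights file written by save_pretrained() depends on the installed transformers version, so the example only checks for config.json.

```python
import os
import tempfile

import torch
from transformers import BertConfig, BertModel

# Hypothetical tiny configuration so that nothing has to be downloaded.
config = BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=2,
    intermediate_size=64,
)
model = BertModel(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)                 # writes config.json plus a weights file
    saved_files = sorted(os.listdir(tmp_dir))
    reloaded = BertModel.from_pretrained(tmp_dir)  # rebuilds from the config, then loads weights

original_state = model.state_dict()
reloaded_state = reloaded.state_dict()
weights_match = all(
    torch.equal(tensor, reloaded_state[name])
    for name, tensor in original_state.items()
)
```

from_pretrained() reads the saved configuration to rebuild the architecture and then loads the weights into it, which matches the "define the model first, then load the state_dict" pattern described above.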