首页 资讯 热点 行情 地区 推荐 民宿 酒店 家居 度假 滚动
首页 >  热点 > 正文

全球视讯!第二十五篇—加载预训练权重

2023-04-22 12:12:31来源:哔哩哔哩


(资料图片仅供参考)

在Pytorch中,只有可学习参数的层(卷积层、线性层、BN层等)才有state_dict,model.state_dict()会以有序字典OrderedDict形式返回模型训练过程中学习的权重weight和偏置bias参数(参考第二十四篇—模型中的parameter和buffer),如下述代码所示:

上述代码定义的模型中,只有卷积层和BN层具有可学习参数,所以net.state_dict()只会保存这两层的参数,而激活函数层的参数则不会保存。BN层除了权重weight和偏置bias参数,还会保存训练阶段统计的均值(running_mean)、训练阶段统计的方差(running_val)、训练阶段的batch数目(num_batches_tracked),其中,weight和bias属于可学习参数,需要进行训练,而running_mean、running_val和num_batches_tracked三个参数则不需要训练,只是训练阶段的统计值。,如下图所示:

当我们对网络模型结构进行优化改进时,如果改进的部分不包含可学习的层,那么可以直接加载预训练权重。如:如果我们对上述代码的Conv模型进行改进,将激活函数层改为nn.Hardswish(),因为不包含可学习的参数,所以改进的模型的state_dict()没有改变,仍然可以直接加载Conv模型的权重文件,如下代码所示:

结果如下:

当我们改进的部分改变了可学习的参数时,如果直接加载预训练权重就会发生不匹配的错误,如下代码所示:

结果如下:

这时我们需要遍历预训练文件的每一层参数,将能够匹配成功的参数提取出来,再进行加载就可以了,如下代码所示:

结果如下:

标签: