您好,登錄后才能下訂單哦!
這篇文章主要講解了“pytorch加載模型遇到的問題怎么解決”,文中的講解內容簡單清晰,易于學習與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學習“pytorch加載模型遇到的問題怎么解決”吧!
pretrained_dict1 = torch.load(model_path2, map_location='cpu')['state_dict']#預訓練文件后綴是.tarpretrained_dict2 = torch.load(model_path3)#預訓練文件后綴是.pth#1.查看預訓練網絡參數for key ,value in pretrained_dict1.items():#pretrained_dict1,pretrained_dict2就是上面的東西count+=1print(key)print(count)#2.查看model的網絡參數for key ,value in model.state_dict.items():print(key,value)
1. 模型的鍵不匹配
以下兩代碼,解決了鍵不匹配問題,一個是刪除鍵的某一部分,一是添加鍵的某一部分。
例:
下面的錯誤是因為模型的model.state_dict().items()的鍵是conv1.weight,預訓練的鍵是module.conv1.weight,導致不匹配。所以下面的代碼是讓module. 去掉
1.刪除鍵的頭部 pretrained_dict = { k.replace('module.', ''): v for k, v in pretrained_dict2.items()}
當然有時候自己model的鍵需要改進,如下
2.補齊鍵的頭部 checkpoint={ 'module.'+k:v for k,v in pretrained_dict.items()}
2. 預訓練模型和自己的model長度不一樣
# 刪除pretrained_dict.items()中model所沒有的東西model_dict = model.state_dict()pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict} # 只保留預訓練模型中,自己建的model有的參數model_dict.update(pretrained_dict) # 將預訓練的值,更新到自己模型的dict中model.load_state_dict(model_dict) # model加載dict中的數據,更新網絡的初始值
for value1 ,value2 in zip(checkpoint.items(), model.state_dict().items()):print(value1,value2)
如下所示,model的參數和預訓練的參數是一樣的
(這里處理的只是針對本人的model加載的情況,要想正確加載,還需遵守上面3步)
def load_param(self, model_path):#這里的self就是modelmodel_dict = self.state_dict()pretrained_dict = torch.load(model_path)#這里model_path的后綴是.pth可直接讀取# pretrained_dict = {k.replace('module.', ''): v for k, v in# pretrained_dict.items()} # 因為pretrained_dict得到module.conv1.weight,但是自己建的model無module,只是conv1.weight,所以改寫下pretrained_dict = { k: v for k, v in pretrained_dict.items() if k in model_dict} # 只保留預訓練模型中,自己建的model有的參數model_dict.update(pretrained_dict) # 將預訓練的值,更新到自己模型的dict中self.load_state_dict(model_dict) # model加載dict中的數據,更新網絡的初始值
感謝各位的閱讀,以上就是“pytorch加載模型遇到的問題怎么解決”的內容了,經過本文的學習后,相信大家對pytorch加載模型遇到的問題怎么解決這一問題有了更深刻的體會,具體使用情況還需要大家實踐驗證。這里是億速云,小編將為大家推送更多相關知識點的文章,歡迎關注!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。