中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

pytorch加載模型遇到的問題怎么解決

發布時間:2022-03-18 16:59:47 來源:億速云 閱讀:506 作者:iii 欄目:大數據

這篇文章主要講解了“pytorch加載模型遇到的問題怎么解決”,文中的講解內容簡單清晰,易于學習與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學習“pytorch加載模型遇到的問題怎么解決”吧!

1. 查看網絡參數

pretrained_dict1 = torch.load(model_path2, map_location='cpu')['state_dict']#預訓練文件后綴是.tarpretrained_dict2 = torch.load(model_path3)#預訓練文件后綴是.pth#1.查看預訓練網絡參數for key ,value in pretrained_dict1.items():#pretrained_dict1,pretrained_dict2就是上面的東西count+=1print(key)print(count)#2.查看model的網絡參數for key ,value in model.state_dict.items():print(key,value)

2. 加載模型遇到的兩大問題

1. 模型的鍵不匹配

以下兩代碼,解決了鍵不匹配問題,一個是刪除鍵的某一部分,一是添加鍵的某一部分

例:
下面的錯誤是因為模型的model.state_dict().items()的鍵是conv1.weight,預訓練的鍵是module.conv1.weight,導致不匹配。所以下面的代碼是讓module. 去掉
pytorch加載模型遇到的問題怎么解決

1.刪除鍵的頭部
pretrained_dict = {
   
   
   k.replace('module.', ''): v for k, v in pretrained_dict2.items()}

當然有時候自己model的鍵需要改進,如下

2.補齊鍵的頭部
checkpoint={
   
   
   'module.'+k:v for k,v in pretrained_dict.items()}

2. 預訓練模型和自己的model長度不一樣

# 刪除pretrained_dict.items()中model所沒有的東西model_dict = model.state_dict()pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留預訓練模型中,自己建的model有的參數model_dict.update(pretrained_dict)  # 將預訓練的值,更新到自己模型的dict中model.load_state_dict(model_dict)  # model加載dict中的數據,更新網絡的初始值

3. 通過查看加載參數,看是否加載成功

for value1 ,value2 in zip(checkpoint.items(), model.state_dict().items()):print(value1,value2)

如下所示,model的參數和預訓練的參數是一樣的
pytorch加載模型遇到的問題怎么解決

4. 案例

(這里處理的只是針對本人的model加載的情況,要想正確加載,還需遵守上面3步)

    def load_param(self, model_path):#這里的self就是modelmodel_dict = self.state_dict()pretrained_dict = torch.load(model_path)#這里model_path的后綴是.pth可直接讀取# pretrained_dict = {k.replace('module.', ''): v for k, v in#                    pretrained_dict.items()}  # 因為pretrained_dict得到module.conv1.weight,但是自己建的model無module,只是conv1.weight,所以改寫下pretrained_dict = {
   
   
   k: v for k, v in pretrained_dict.items() if k in model_dict}  # 只保留預訓練模型中,自己建的model有的參數model_dict.update(pretrained_dict)  # 將預訓練的值,更新到自己模型的dict中self.load_state_dict(model_dict)  # model加載dict中的數據,更新網絡的初始值

感謝各位的閱讀,以上就是“pytorch加載模型遇到的問題怎么解決”的內容了,經過本文的學習后,相信大家對pytorch加載模型遇到的問題怎么解決這一問題有了更深刻的體會,具體使用情況還需要大家實踐驗證。這里是億速云,小編將為大家推送更多相關知識點的文章,歡迎關注!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

屏东市| 麻城市| 建水县| 左权县| 阜宁县| 乌海市| 江北区| 黄石市| 乐亭县| 民乐县| 淮北市| 边坝县| 新昌县| 乐陵市| 同江市| 新兴县| 秀山| 平潭县| 宁强县| 通榆县| 南充市| 东乌| 寻甸| 灵石县| 调兵山市| 婺源县| 中方县| 资兴市| 邹平县| 济源市| 蒙城县| 新余市| 柯坪县| 玉树县| 遵义县| 峡江县| 和田市| 封开县| 浮山县| 湘阴县| 汉中市|