您好,登錄后才能下訂單哦!
本篇內容主要講解“PyTorch深度學習模型的保存和加載流程是什么”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學習“PyTorch深度學習模型的保存和加載流程是什么”吧!
torch.save(module.state_dict(), path)
:使用module.state_dict()
函數獲取各層已經訓練好的參數和緩沖區,然后將參數和緩沖區保存到path
所指定的文件存放路徑(常用文件格式為.pt
、.pth
或.pkl
)。
torch.nn.Module.load_state_dict(state_dict)
:從state_dict
中加載參數和緩沖區到Module
及其子類中 。
torch.nn.Module.state_dict()
函數返回python
中的一個OrderedDict
類型字典對象,該對象將每一層與它的對應參數和緩沖區建立映射關系,字典的鍵值是參數或緩沖區的名稱。只有那些參數可以訓練的層才會被保存到OrderedDict
中,例如:卷積層、線性層等。
Python
中的字典類以“鍵:值
”方式存取數據,OrderedDict
是它的一個子類,實現了對字典對象中元素的排序(OrderedDict
根據放入元素的先后順序進行排序)。由于進行了排序,所以順序不同的兩個OrderedDict
字典對象會被當做是兩個不同的對象。
示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化網絡 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 獲取state_dict state_dict = net.state_dict() # 字典的遍歷默認是遍歷key,所以param_tensor實際上是鍵值 for param_tensor in state_dict: print(param_tensor,':\n',state_dict[param_tensor]) # 保存模型參數 torch.save(state_dict,"net_params.pth") # 通過加載state_dict獲取模型參數 net.load_state_dict(state_dict)
二、完整模型的保存和加載
torch.save(module, path)
:將訓練完的整個網絡模型module
保存到path
所指定的文件存放路徑(常用文件格式為.pt
或.pth
)。
torch.load(path)
:加載保存到path
中的整個神經網絡模型。
示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化網絡 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 保存整個網絡 torch.save(net,"net.pth") # 加載網絡 net = torch.load("net.pth")
到此,相信大家對“PyTorch深度學習模型的保存和加載流程是什么”有了更深的了解,不妨來實際操作一番吧!這里是億速云網站,更多相關內容可以進入相關頻道進行查詢,關注我們,繼續學習!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。