在PyTorch中,你可以使用torch.save()
函數將模型保存為文件,使用torch.load()
函數加載保存的模型文件。以下是保存和加載模型的示例代碼:
import torch
import torch.nn as nn
# 定義模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
x = self.fc(x)
return x
model = Net()
# 保存模型
torch.save(model.state_dict(), 'model.pth')
# 加載模型
model.load_state_dict(torch.load('model.pth'))
在上述代碼中,model.state_dict()
函數用于獲取模型的參數狀態字典,然后使用torch.save()
函數將其保存為文件。加載模型時,使用torch.load()
函數加載保存的模型文件,然后使用model.load_state_dict()
函數將模型參數加載到模型中。
注意:保存模型時只保存了模型的參數,而不保存模型的結構。在加載模型時,需要首先創建相同的模型結構,然后再加載參數。