在PyTorch中,可以使用torch.save()
函數來實現模型的持久化。torch.save()
函數可以將模型的權重、結構和其他參數保存到文件中,以便在以后加載和使用。以下是一個簡單的示例:
import torch
import torch.nn as nn
#定義一個簡單的神經網絡模型
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
model = SimpleModel()
#保存模型
torch.save(model.state_dict(), 'model.pth')
#加載模型
model_load = SimpleModel()
model_load.load_state_dict(torch.load('model.pth'))
在上面的示例中,首先定義了一個簡單的神經網絡模型SimpleModel
,然后通過torch.save()
函數將模型的參數保存到文件model.pth
中。最后使用torch.load()
函數加載模型參數,并將其應用到新的模型中。通過這種方法,可以實現模型的持久化和加載。