在PyTorch中,可以使用torch.save()函數來保存模型的參數。下面是保存模型參數的示例代碼:
import torch
# 定義模型
model = torch.nn.Linear(10, 1) # 示例模型
# 保存模型參數
torch.save(model.state_dict(), 'model.pth')
在上面的示例中,首先定義了一個模型(這里使用的是一個簡單的線性模型),然后使用model.state_dict()方法獲取模型的參數,并使用torch.save()函數將參數保存到文件’model.pth’中。
要加載模型參數,可以使用torch.load()函數:
import torch
# 定義模型
model = torch.nn.Linear(10, 1) # 示例模型
# 加載模型參數
model.load_state_dict(torch.load('model.pth'))
在上面的示例中,首先定義了一個模型(與保存模型參數時相同),然后使用torch.load()函數加載保存在’model.pth’文件中的參數,并使用model.load_state_dict()方法將參數加載到模型中。loadModel方法將參數加載到模型。