要加載訓練好的PyTorch模型,可以使用torch.load()函數來加載模型的參數和狀態字典。以下是一個加載并使用訓練好的模型的示例代碼:
import torch
import torchvision.models as models
# 實例化模型
model = models.resnet18()
# 加載訓練好的模型參數
model.load_state_dict(torch.load('path_to_saved_model.pth'))
# 設置模型為評估模式
model.eval()
# 使用模型進行推理
inputs = torch.randn(1, 3, 224, 224)
outputs = model(inputs)
# 打印預測結果
print(outputs)
在上述代碼中,首先使用torchvision.models模塊實例化了一個ResNet-18模型。然后使用load_state_dict()函數加載了訓練好的模型參數,需要提供模型參數保存的文件路徑。接著調用eval()方法將模型設置為評估模式,這將關閉模型中的一些訓練特定的操作,如Dropout。最后,將輸入數據傳遞給模型進行推理,并打印預測結果。
需要注意的是,加載模型時,要確保模型的結構與訓練時的結構完全一致,否則加載的模型參數可能會出錯。