要調用訓練好的模型,需要按照以下步驟進行:
import torch
import torch.nn as nn
class YourModel(nn.Module):
def __init__(self):
super(YourModel, self).__init__()
# 定義模型的結構
def forward(self, x):
# 定義模型的前向傳播邏輯
return x
model = YourModel()
model.load_state_dict(torch.load('path/to/your/trained/model.pth'))
確保將’path/to/your/trained/model.pth’替換為實際訓練好的模型參數文件的路徑。
model.eval()
現在,模型已經加載并準備好進行推理了。你可以使用模型進行預測,例如:
input_data = torch.randn(1, 3, 224, 224) # 模擬輸入數據
output = model(input_data)
請注意,為了正確預測,輸入數據的尺寸和模型的輸入尺寸應該匹配。根據你的具體模型和任務,你可能需要進行適當的數據預處理。
希望能幫助到你!