要在C++中調用PyTorch模型,需要使用PyTorch C++ API。以下是一個基本的C++代碼示例,展示了如何加載模型、輸入數據并運行模型。
#include <torch/script.h>
#include <torch/torch.h>
int main() {
// 加載模型
std::string model_path = "path_to_model.pt";
torch::jit::script::Module module = torch::jit::load(model_path);
// 創建輸入張量
std::vector<float> input_data = {1.0, 2.0, 3.0, 4.0};
torch::Tensor inputs = torch::from_blob(input_data.data(), {1, 4});
// 將輸入張量傳遞給模型
std::vector<torch::jit::IValue> inputs_list;
inputs_list.push_back(inputs);
torch::jit::IValue output = module.forward(inputs_list);
// 提取輸出張量
torch::Tensor result = output.toTensor();
std::cout << result << std::endl;
return 0;
}
請確保已正確安裝PyTorch C++ API,并將path_to_model.pt
替換為實際模型的路徑。在代碼中,我們首先使用torch::jit::load()
加載模型,然后創建輸入張量,將其傳遞給模型的forward
方法,并通過output.toTensor()
獲取輸出張量。
有關更多詳細信息和示例,請參考PyTorch官方文檔:https://pytorch.org/cppdocs/