在PyTorch中讀取圖片有多種方法,常用的方法是使用torchvision
庫中的ImageFolder
和DataLoader
類。首先,需要將圖片數據集組織成以下格式:一個文件夾包含所有的類別文件夾,每個類別文件夾包含該類別的圖片。
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
# 定義數據轉換
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# 讀取圖片數據集
dataset = ImageFolder(root='path_to_dataset', transform=transform)
# 創建數據加載器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 遍歷數據加載器
for images, labels in dataloader:
# 進行模型訓練或其他操作
pass
在上面的代碼中,首先定義了一個數據轉換transform
,然后使用ImageFolder
類加載圖片數據集,最后創建了一個數據加載器dataloader
用于批量加載數據。通過遍歷數據加載器,可以獲得每個batch的圖片數據和對應的標簽。