在 PyTorch 中,可以使用 torchvision.datasets
模塊來加載常見的數據集,如 MNIST、CIFAR-10 等。這些數據集通常會被下載到本地,并返回一個 Dataset
對象,可以通過 DataLoader
對象來對數據集進行批量加載和隨機打亂。
以下是一個加載 MNIST 數據集的示例代碼:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定義數據預處理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加載 MNIST 訓練集和測試集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 創建 DataLoader 對象
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
# 遍歷數據集
for inputs, labels in train_loader:
# 在這里進行模型訓練
pass
上面的代碼首先定義了數據預處理的方法 transform
,然后使用 datasets.MNIST
加載了 MNIST 數據集的訓練集和測試集,并創建了對應的 DataLoader
對象 train_loader
和 test_loader
。最后,可以通過遍歷 train_loader
來逐批獲取訓練數據和標簽,并進行模型訓練。