在Torch中加載和處理數據集通常通過使用torch.utils.data.Dataset
和torch.utils.data.DataLoader
類來實現。以下是一個簡單的示例代碼:
import torch
from torch.utils.data import Dataset, DataLoader
# 定義自定義數據集類
class CustomDataset(Dataset):
def __init__(self):
# 初始化數據集
self.data = torch.randn(100, 10)
self.labels = torch.randint(0, 2, (100,))
def __len__(self):
# 返回數據集大小
return len(self.data)
def __getitem__(self, idx):
# 獲取數據集中的一個樣本
return self.data[idx], self.labels[idx]
# 創建數據集實例
dataset = CustomDataset()
# 創建數據加載器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# 遍歷數據集
for data, labels in dataloader:
# 處理每個批次的數據
print(data.shape, labels.shape)
在上面的示例中,定義了一個自定義的數據集類CustomDataset
,其中實現了__init__
、__len__
和__getitem__
方法。然后創建了dataset
實例和dataloader
對象,并使用for
循環遍歷數據加載器,獲取每個批次的數據。