在PyTorch中,數據加載器可以通過torch.utils.data.DataLoader
來實現。數據加載器可以幫助用戶批量加載數據,并可以在訓練過程中對數據進行隨機排列、并行加載等操作。
下面是一個簡單的示例,演示如何使用數據加載器來加載一個簡單的數據集:
import torch
from torch.utils.data import Dataset, DataLoader
# 創建一個自定義的數據集類
class CustomDataset(Dataset):
def __init__(self):
self.data = torch.randn(100, 3) # 100個3維的隨機數據
self.targets = torch.randint(0, 2, (100,)) # 100個隨機目標標簽
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.targets[idx]
# 創建數據集實例
dataset = CustomDataset()
# 創建數據加載器實例
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
# 遍歷數據加載器
for i, (data, target) in enumerate(data_loader):
print(f'Batch {i}:')
print('Data:', data)
print('Target:', target)
在上述示例中,首先定義了一個自定義的數據集類CustomDataset
,然后創建了一個數據集實例dataset
。接著利用DataLoader
類來創建一個數據加載器實例data_loader
,并指定了批量大小為32且開啟了數據隨機排列。最后通過對數據加載器進行遍歷,便可以逐批次地獲取數據和標簽。