在PyTorch中,可以通過繼承torch.utils.data.Dataset
類來創建自己的數據集。以下是一個簡單的示例代碼:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
return sample
# 創建自己的數據集
data = [1, 2, 3, 4, 5]
custom_dataset = CustomDataset(data)
# 創建數據加載器
data_loader = torch.utils.data.DataLoader(custom_dataset, batch_size=2, shuffle=True)
# 遍歷數據加載器
for batch in data_loader:
print(batch)
在上面的示例中,首先定義了一個自定義的數據集CustomDataset
,該數據集繼承自torch.utils.data.Dataset
類,并實現了__init__
、__len__
和__getitem__
方法。然后創建了一個包含一些數據的實例data
,并使用它來實例化CustomDataset
類得到custom_dataset
。最后,使用DataLoader
將自定義數據集包裝成數據加載器,并遍歷數據加載器來獲取數據。