要自定義一個Dataset類,可以繼承自torch.utils.data.Dataset,并實現其中的__len__和__getitem__方法來定義數據集的長度和獲取數據的方式。
下面是一個簡單的例子:
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
sample = self.data[idx]
return sample
# 創建一個數據集實例
data = [1, 2, 3, 4, 5]
custom_dataset = CustomDataset(data)
# 獲取數據集的長度
print(len(custom_dataset))
# 獲取數據集中第一個樣本
print(custom_dataset[0])
在上面的例子中,我們定義了一個CustomDataset類,它接受一個數據列表作為輸入,并實現了__len__方法和__getitem__方法。通過實例化CustomDataset類,我們可以獲取數據集的長度并獲取數據集中的樣本。