您好,登錄后才能下訂單哦!
小編給大家分享一下Pytorch中的Dataset和DataLoader怎么用,相信大部分人都還不怎么了解,因此分享這篇文章給大家參考一下,希望大家閱讀完這篇文章后大有收獲,下面讓我們一起去了解一下吧!
確保安裝
scikit-image
numpy
一個例子:
# 導入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
# 編造數據
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 數據[1,2],對應的標簽是[0],數據[3,4],對應的標簽是[1]
#創建子類
class subDataset(Dataset.Dataset):
#初始化,定義數據內容和標簽
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回數據集大小
def __len__(self):
return len(self.Data)
#得到數據內容和標簽
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
return data, label
# 主函數
if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset大小為:', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])
輸出的結果
我們有了對Dataset的一個整體的把握,再來分析里面的細節:
#創建子類
class subDataset(Dataset.Dataset):
創建子類時,繼承的時Dataset.Dataset,不是一個Dataset。因為Dataset是module模塊,不是class類,所以需要調用module里的class才行,因此是Dataset.Dataset!
len和getitem這兩個函數,前者給出數據集的大小**,后者是用于查找數據和標簽。是最重要的兩個函數,我們后續如果要對數據做一些操作基本上都是再這兩個函數的基礎上進行。
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_works=0,
clollate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
功能:構建可迭代的數據裝載器;
dataset:Dataset類,決定數據從哪里讀取及如何讀取;數據集的路徑
batchsize:批大小;
num_works:是否多進程讀取數據;只對于CPU
shuffle:每個epoch是否打亂;
drop_last:當樣本數不能被batchsize整除時,是否舍棄最后一批數據;
Epoch:所有訓練樣本都已輸入到模型中,稱為一個Epoch;
Iteration:一批樣本輸入到模型中,稱之為一個Iteration;
Batchsize:批大小,決定一個Epoch中有多少個Iteration;
還是舉一個實例:
import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#創建子類
class subDataset(Dataset.Dataset):
#初始化,定義數據內容和標簽
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回數據集大小
def __len__(self):
return len(self.Data)
#得到數據內容和標簽
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
return data, label
if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset大小為:', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])
#創建DataLoader迭代器,相當于我們要先定義好前面說的Dataset,然后再用Dataloader來對數據進行一些操作,比如是否需要打亂,則shuffle=True,是否需要多個進程讀取數據num_workers=4,就是四個進程
dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
for i, item in enumerate(dataloader): #可以用enumerate來提取出里面的數據
print('i:', i)
data, label = item #數據是一個元組
print('data:', data)
print('label:', label)
總結下來時有兩種方法解決
1.如果在創建Dataset的類時,定義__getitem__方法的時候,將數據轉變為GPU類型。則需要將Dataloader里面的參數num_workers設置為0,因為這個參數是對于CPU而言的。如果數據改成了GPU,則只能單進程。如果是在Dataloader的部分,先多個子進程讀取,再轉變為GPU,則num_wokers不用修改。就是上述__getitem__部分的代碼,移到Dataloader部分。
2.不過一般來講,數據集和標簽不會像我們上述編輯的那么簡單。一般再kaggle上的標簽都是存在CSV這種文件中。需要pandas的配合。
以上是“Pytorch中的Dataset和DataLoader怎么用”這篇文章的所有內容,感謝各位的閱讀!相信大家都有了一定的了解,希望分享的內容對大家有所幫助,如果還想學習更多知識,歡迎關注億速云行業資訊頻道!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。