中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

如何利用pytorch自定義一個數據集

發布時間:2020-11-11 15:05:40 來源:億速云 閱讀:586 作者:Leah 欄目:開發技術

今天就跟大家聊聊有關如何利用pytorch自定義一個數據集,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結了以下內容,希望大家根據這篇文章可以有所收獲。

自定義數據集

在訓練深度學習模型之前,樣本集的制作非常重要。在pytorch中,提供了一些接口和類,方便我們定義自己的數據集合,下面完整的試驗自定義樣本集的整個流程。

開發環境

  • Ubuntu 18.04
  • pytorch 1.0
  • pycharm
     

實驗目的

  1. 掌握pytorch中數據集相關的API接口和類
  2. 熟悉數據集制作的整個流程
     

實驗過程

1.收集圖像樣本

以簡單的貓狗二分類為例,可以在網上下載一些貓狗圖片。創建以下目錄:

  • data-------------根目錄
  • data/test-------測試集
  • data/train------訓練集
  • data/val--------驗證集
     

如何利用pytorch自定義一個數據集

在test/train/val之下在校分別創建2個文件夾,dog, cat

如何利用pytorch自定義一個數據集

cat, dog文件夾下分別存放2類圖像:

如何利用pytorch自定義一個數據集

標簽

種類標簽
cat0
dog1

之后寫一個簡單的python腳本,生成txt文件,用于指明每個圖像和標簽的對應關系。

格式: /cat/1.jpg 0 \n dog/1.jpg 1 \n .....

如圖:

如何利用pytorch自定義一個數據集

至此,樣本集的收集以及簡單歸類完成,下面將開始采用pytorch的數據集相關API和類。

2. 使用pytorch相關類,API對數據集進行封裝

2.1 pytorch中數據集相關的類,接口

pytorch中數據集相關的類位于torch.utils.data package中。

https://pytorch.org/docs/stable/data.html

如何利用pytorch自定義一個數據集

本次實驗,主要使用以下類:

torch.utils.data.Dataset
torch.utils.data.DataLoader

如何利用pytorch自定義一個數據集

Dataset類的使用: 所有的類都應該是此類的子類(也就是說應該繼承該類)。 所有的子類都要重寫(override) __len()__, __getitem()__ 這兩個方法。

方法作用
__len()__此方法應該提供數據集的大小(容量)
__getitem()__此方法應該提供支持下標索方式引訪問數據集

這里和Java抽象類很相似,在抽象類abstract class中,一般會定義一些抽象方法abstract method,抽象方法:只有方法名沒有方法的具體實現。如果一個子類繼承于該抽象類,要重寫(overrode)父類的抽象方法。

DataLoader類的使用:

如何利用pytorch自定義一個數據集

2.2 實現

使用到的python package

python package目的
numpy矩陣操作,對圖像進行轉置
skimage圖像處理,圖像I/O,圖像變換
matplotlib圖像的顯示,可視化
os一些文件查找操作
torchpytorch
torvisionpytorch

源碼

導入python包

import numpy as np
from skimage import io
from skimage import transform
import matplotlib.pyplot as plt
import os
import torch
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms
from torchvision.utils import make_grid

第一步:

定義一個子類,繼承Dataset類, 重寫 __len()__, __getitem()__ 方法。

細節:

1.數據集中一個一樣的表示:采用字典的形式sample = {'image': image, 'label': label}。

2.圖像的讀取:采用skimage.io進行讀取,讀取之后的結果為numpy.ndarray形式。

3.圖像變換:transform參數

# step1: 定義MyDataset類, 繼承Dataset, 重寫抽象方法:__len()__, __getitem()__
class MyDataset(Dataset):

 def __init__(self, root_dir, names_file, transform=None):
 self.root_dir = root_dir
 self.names_file = names_file
 self.transform = transform
 self.size = 0
 self.names_list = []

 if not os.path.isfile(self.names_file):
  print(self.names_file + 'does not exist!')
 file = open(self.names_file)
 for f in file:
  self.names_list.append(f)
  self.size += 1

 def __len__(self):
 return self.size

 def __getitem__(self, idx):
 image_path = self.root_dir + self.names_list[idx].split(' ')[0]
 if not os.path.isfile(image_path):
  print(image_path + 'does not exist!')
  return None
 image = io.imread(image_path) # use skitimage
 label = int(self.names_list[idx].split(' ')[1])

 sample = {'image': image, 'label': label}
 if self.transform:
  sample = self.transform(sample)

 return sample

第二步

實例化一個對象,并讀取和顯示數據集

train_dataset = MyDataset(root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=None)

plt.figure()
for (cnt,i) in enumerate(train_dataset):
 image = i['image']
 label = i['label']

 ax = plt.subplot(4, 4, cnt+1)
 ax.axis('off')
 ax.imshow(image)
 ax.set_title('label {}'.format(label))
 plt.pause(0.001)

 if cnt == 15:
 break

只顯示了部分數據,前部分全是cat

如何利用pytorch自定義一個數據集

第三步(可選 optional)

對數據集進行變換:一般收集到的圖像大小尺寸,亮度等存在差異,變換的目的就是使得數據歸一化。另一方面,可以通過變換進行數據增加data argument

關于pytorch中的變換transforms,請參考該系列之前的文章

由于數據集中樣本采用字典dicts形式表示。 因此不能直接調用torchvision.transofrms中的方法。

本實驗只進行尺寸歸一化Resize, 數據類型變換ToTensor操作。

Resize

# # 變換Resize
class Resize(object):

 def __init__(self, output_size: tuple):
 self.output_size = output_size

 def __call__(self, sample):
 # 圖像
 image = sample['image']
 # 使用skitimage.transform對圖像進行縮放
 image_new = transform.resize(image, self.output_size)
 return {'image': image_new, 'label': sample['label']}

ToTensor

# # 變換ToTensor
class ToTensor(object):

 def __call__(self, sample):
 image = sample['image']
 image_new = np.transpose(image, (2, 0, 1))
 return {'image': torch.from_numpy(image_new),
  'label': sample['label']}

第四步: 對整個數據集應用變換

細節: transformers.Compose() 將不同的幾個組合起來。先進行Resize, 再進行ToTensor

# 對原始的訓練數據集進行變換
transformed_trainset = MyDataset(root_dir='./data/train',
    names_file='./data/train/train.txt',
    transform=transforms.Compose(
    [Resize((224,224)),
    ToTensor()]
    ))

第五步: 使用DataLoader進行包裝

為何要使用DataLoader?

① 深度學習的輸入是mini_batch形式

② 樣本加載時候可能需要隨機打亂順序,shuffle操作

③ 樣本加載需要采用多線程

pytorch提供的DataLoader封裝了上述的功能,這樣使用起來更方便。

# 使用DataLoader可以利用多線程,batch,shuffle等
trainset_dataloader = DataLoader(dataset=transformed_trainset,
     batch_size=4,
     shuffle=True,
     num_workers=4)

可視化:

def show_images_batch(sample_batched):
 images_batch, labels_batch = \
 sample_batched['image'], sample_batched['label']
 grid = make_grid(images_batch)
 plt.imshow(grid.numpy().transpose(1, 2, 0))


# sample_batch: Tensor , NxCxHxW
plt.figure()
for i_batch, sample_batch in enumerate(trainset_dataloader):
 show_images_batch(sample_batch)
 plt.axis('off')
 plt.ioff()
 plt.show()


plt.show()

通過DataLoader包裝之后,樣本以min_batch形式輸出,而且進行了隨機打亂順序。

如何利用pytorch自定義一個數據集

如何利用pytorch自定義一個數據集

如何利用pytorch自定義一個數據集

如何利用pytorch自定義一個數據集

至此,自定義數據集的完整流程已實現,test, val集只需要改路徑即可。

補充

更簡單的方法

上述繼承Dataset, 重寫 __len()__, __getitem() 是通用的方法,過程相對繁瑣。對于簡單的分類數據集,pytorch中提供了更簡便的方式——ImageFolder。

如果每種類別的樣本放在各自的文件夾中,則可以直接使用ImageFolder。

仍然以cat, dog 二分類數據集為例:

文件結構:

如何利用pytorch自定義一個數據集
如何利用pytorch自定義一個數據集
如何利用pytorch自定義一個數據集

Code

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np


# https://pytorch.org/tutorials/beginner/data_loading_tutorial.html

# data_transform = transforms.Compose([
#  transforms.RandomResizedCrop(224),
#  transforms.RandomHorizontalFlip(),
#  transforms.ToTensor(),
#  transforms.Normalize(mean=[0.485, 0.456, 0.406],
#       std=[0.229, 0.224, 0.225])
# ])

data_transform = transforms.Compose([
 transforms.Resize((224,224)),
 transforms.RandomHorizontalFlip(),
 transforms.ToTensor(),

])

train_dataset = datasets.ImageFolder(root='./data/train',transform=data_transform)
train_dataloader = DataLoader(dataset=train_dataset,
        batch_size=4,
        shuffle=True,
        num_workers=4)


def show_batch_images(sample_batch):
 labels_batch = sample_batch[1]
 images_batch = sample_batch[0]

 for i in range(4):
  label_ = labels_batch[i].item()
  image_ = np.transpose(images_batch[i], (1, 2, 0))
  ax = plt.subplot(1, 4, i + 1)
  ax.imshow(image_)
  ax.set_title(str(label_))
  ax.axis('off')
  plt.pause(0.01)


plt.figure()
for i_batch, sample_batch in enumerate(train_dataloader):
 show_batch_images(sample_batch)

 plt.show()

由于 train 目錄下只有2個文件夾,分別為cat, dog, 因此ImageFolder安裝順序對cat使用標簽0, dog使用標簽1。

如何利用pytorch自定義一個數據集

如何利用pytorch自定義一個數據集

如何利用pytorch自定義一個數據集

看完上述內容,你們對如何利用pytorch自定義一個數據集有進一步的了解嗎?如果還想了解更多知識或者相關內容,請關注億速云行業資訊頻道,感謝大家的支持。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

会泽县| 荣昌县| 连州市| 青阳县| 子长县| 大理市| 民权县| 马鞍山市| 涞源县| 屏边| 临高县| 潢川县| 印江| 靖安县| 武安市| 黄冈市| 克拉玛依市| 阜阳市| 买车| 全南县| 新泰市| 静宁县| 司法| 河北省| 正镶白旗| 弋阳县| 湖口县| 牟定县| 宁明县| 益阳市| 磐安县| 通化县| 浮山县| 特克斯县| 周至县| 沅陵县| 唐海县| 思南县| 吉安县| 罗甸县| 盐城市|