您好,登錄后才能下訂單哦!
這篇文章將為大家詳細講解有關如何計算pytorch標準化Normalize所需要數據集的均值和方差,小編覺得挺實用的,因此分享給大家做個參考,希望大家閱讀完這篇文章后可以有所收獲。
pytorch做標準化利用transforms.Normalize(mean_vals, std_vals),其中常用數據集的均值方差有:
if 'coco' in args.dataset: mean_vals = [0.471, 0.448, 0.408] std_vals = [0.234, 0.239, 0.242] elif 'imagenet' in args.dataset: mean_vals = [0.485, 0.456, 0.406] std_vals = [0.229, 0.224, 0.225]
計算自己數據集圖像像素的均值方差:
import numpy as np import cv2 import random # calculate means and std train_txt_path = './train_val_list.txt' CNum = 10000 # 挑選多少圖片進行計算 img_h, img_w = 32, 32 imgs = np.zeros([img_w, img_h, 3, 1]) means, stdevs = [], [] with open(train_txt_path, 'r') as f: lines = f.readlines() random.shuffle(lines) # shuffle , 隨機挑選圖片 for i in tqdm_notebook(range(CNum)): img_path = os.path.join('./train', lines[i].rstrip().split()[0]) img = cv2.imread(img_path) img = cv2.resize(img, (img_h, img_w)) img = img[:, :, :, np.newaxis] imgs = np.concatenate((imgs, img), axis=3) # print(i) imgs = imgs.astype(np.float32)/255. for i in tqdm_notebook(range(3)): pixels = imgs[:,:,i,:].ravel() # 拉成一行 means.append(np.mean(pixels)) stdevs.append(np.std(pixels)) # cv2 讀取的圖像格式為BGR,PIL/Skimage讀取到的都是RGB不用轉 means.reverse() # BGR --> RGB stdevs.reverse() print("normMean = {}".format(means)) print("normStd = {}".format(stdevs)) print('transforms.Normalize(normMean = {}, normStd = {})'.format(means, stdevs))
關于“如何計算pytorch標準化Normalize所需要數據集的均值和方差”這篇文章就分享到這里了,希望以上內容可以對大家有一定的幫助,使各位可以學到更多知識,如果覺得文章不錯,請把它分享出去讓更多的人看到。
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。