中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

tensorflow入門:TFRecordDataset變長數據的batch讀取詳解

發布時間:2020-10-22 19:35:31 來源:腳本之家 閱讀:418 作者:yeqiustu 欄目:開發技術

在上一篇文章tensorflow入門:tfrecord 和tf.data.TFRecordDataset的使用里,講到了使用如何使用tf.data.TFRecordDatase來對tfrecord文件進行batch讀取,即使用dataset的batch方法進行;但如果每條數據的長度不一樣(常見于語音、視頻、NLP等領域),則不能直接用batch方法獲取數據,這時則有兩個解決辦法:

1.在把數據寫入tfrecord時,先把數據pad到統一的長度再寫入tfrecord;這個方法的問題在于:若是有大量數據的長度都遠遠小于最大長度,則會造成存儲空間的大量浪費。

2.使用dataset中的padded_batch方法來進行,參數padded_shapes #指明每條記錄中各成員要pad成的形狀,成員若是scalar,則用[],若是list,則用[mx_length],若是array,則用[d1,...,dn],假如各成員的順序是scalar數據、list數據、array數據,則padded_shapes=([], [mx_length], [d1,...,dn]);該方法的函數說明如下:

padded_batch(
 batch_size,
 padded_shapes,
 padding_values=None #默認使用各類型數據的默認值,一般使用時可忽略該項
)

使用mnist數據來舉例說明,首先在把mnist寫入tfrecord之前,把mnist數據進行更改,以使得每個mnist圖像的大小不等,如下:

import tensorflow as tf
from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets
 
mnist = read_data_sets("MNIST_data/", one_hot=True)
 
 
def get_tfrecords_example(feature, label):
 tfrecords_features = {}
 feat_shape = feature.shape
 tfrecords_features['feature'] = tf.train.Feature(float_list=tf.train.FloatList(value=feature))
 tfrecords_features['shape'] = tf.train.Feature(int64_list=tf.train.Int64List(value=list(feat_shape)))
 tfrecords_features['label'] = tf.train.Feature(float_list=tf.train.FloatList(value=label))
 return tf.train.Example(features=tf.train.Features(feature=tfrecords_features))
 
 
def make_tfrecord(data, outf_nm='mnist-train'):
 feats, labels = data
 outf_nm += '.tfrecord'
 tfrecord_wrt = tf.python_io.TFRecordWriter(outf_nm)
 ndatas = len(labels)
 print(feats[0].dtype, feats[0].shape, ndatas)
 assert len(labels[0]) > 1
 for inx in range(ndatas):
 ed = random.randint(0,3) #隨機丟掉幾個數據點,以使長度不等
 exmp = get_tfrecords_example(feats[inx][:-ed], labels[inx])
 exmp_serial = exmp.SerializeToString()
 tfrecord_wrt.write(exmp_serial)
 tfrecord_wrt.close()
 
import random
nDatas = len(mnist.train.labels)
inx_lst = range(nDatas)
random.shuffle(inx_lst)
random.shuffle(inx_lst)
ntrains = int(0.85*nDatas)
 
# make training set
data = ([mnist.train.images[i] for i in inx_lst[:ntrains]], \
 [mnist.train.labels[i] for i in inx_lst[:ntrains]])
make_tfrecord(data, outf_nm='mnist-train')
 
# make validation set
data = ([mnist.train.images[i] for i in inx_lst[ntrains:]], \
 [mnist.train.labels[i] for i in inx_lst[ntrains:]])
make_tfrecord(data, outf_nm='mnist-val')
 
# make test set
data = (mnist.test.images, mnist.test.labels)
make_tfrecord(data, outf_nm='mnist-test')

用dataset加載批量數據,在解析數據時用到tf.VarLenFeature(tf.datatype),而非tf.FixedLenFeature([], tf.datatype)},且要配合tf.sparse_tensor_to_dense函數使用,如下:

import tensorflow as tf
 
train_f, val_f, test_f = ['mnist-%s.tfrecord'%i for i in ['train', 'val', 'test']]
 
def parse_exmp(serial_exmp):
 feats = tf.parse_single_example(serial_exmp, features={'feature':tf.VarLenFeature(tf.float32),\
 'label':tf.FixedLenFeature([10],tf.float32), 'shape':tf.FixedLenFeature([], tf.int64)})
 image = tf.sparse_tensor_to_dense(feats['feature']) #使用VarLenFeature讀入的是一個sparse_tensor,用該函數進行轉換
 label = tf.reshape(feats['label'],[2,5]) #把label變成[2,5],以說明array數據如何padding
 shape = tf.cast(feats['shape'], tf.int32)
 return image, label, shape
 
def get_dataset(fname):
 dataset = tf.data.TFRecordDataset(fname)
 return dataset.map(parse_exmp) # use padded_batch method if padding needed
 
epochs = 16
batch_size = 50 
padded_shapes = ([784],[3,5],[]) #把image pad至784,把label pad至[3,5],shape是一個scalar,不輸入數字
# training dataset
dataset_train = get_dataset(train_f)
dataset_train = dataset_train.repeat(epochs).shuffle(1000).padded_batch(batch_size, padded_shapes=padded_shapes)

以上這篇tensorflow入門:TFRecordDataset變長數據的batch讀取詳解就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支持億速云。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

伊宁市| 丰台区| 焉耆| 望奎县| 东乡县| 松阳县| 宜阳县| 桓仁| 和平区| 搜索| 固始县| 万载县| 镇坪县| 海阳市| 米脂县| 抚顺县| 宝丰县| 凤山县| 柯坪县| 湟中县| 建平县| 秭归县| 阳江市| 武陟县| 灵宝市| 长乐市| 宜城市| 泰兴市| 柏乡县| 密云县| 聂拉木县| 肇庆市| 仙居县| 离岛区| 达日县| 株洲县| 博客| 万年县| 广河县| 保康县| 安新县|