This article walks through an example of batch training (batch_train) operations in PyTorch. The explanation is fairly detailed and should be a useful reference; if you are interested, read it through to the end!
import torch
import torch.utils.data as Data

torch.manual_seed(1)    # reproducible

# BATCH_SIZE = 5
BATCH_SIZE = 8          # feed 8 samples into the network at a time

x = torch.linspace(1, 10, 10)   # this is x data (torch tensor)
y = torch.linspace(10, 1, 10)   # this is y data (torch tensor)

torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,      # torch TensorDataset format
    batch_size=BATCH_SIZE,      # mini batch size
    shuffle=False,              # do not randomly shuffle the data for training
    num_workers=2,              # use two subprocesses for loading data
)

def show_batch():
    for epoch in range(3):      # train the entire dataset 3 times
        for step, (batch_x, batch_y) in enumerate(loader):  # for each training step
            # train your data...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())

if __name__ == '__main__':
    show_batch()
Epoch:  0 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch:  0 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  1 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch:  1 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
Epoch:  2 | Step:  0 | batch x:  [1. 2. 3. 4. 5. 6. 7. 8.] | batch y:  [10. 9. 8. 7. 6. 5. 4. 3.]
Epoch:  2 | Step:  1 | batch x:  [ 9. 10.] | batch y:  [2. 1.]
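Because the 10 samples do not divide evenly by BATCH_SIZE = 8, the last step of every epoch only gets the remaining 2 samples. As a minimal sketch (not part of the original example), if you want shuffled batches and want to discard that short final batch, you can use the standard DataLoader arguments shuffle and drop_last:

# Sketch only: same dataset as above, but with shuffled batches and the
# incomplete final batch dropped.
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,       # reshuffle the 10 samples at every epoch
    drop_last=True,     # discard the final batch of 2 samples
    num_workers=2,
)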
Addendum: a PyTorch batch-training bug
When batch-training a PyTorch neural network, you may sometimes run into this error:
TypeError: batch must contain tensors, numbers, dicts or lists; found <class 'torch.autograd.variable.Variable'>
Check this (the key point!):
train_dataset = Data.TensorDataset(train_x, train_y)
train_x and train_y must be tensors. My first error happened precisely because I passed in Variables.
You can convert the data to tensors like this:
train_x = torch.FloatTensor(train_x)
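For completeness, here is a minimal sketch of preparing both inputs before building the dataset; the variable names follow the snippet above, and the Variable-unwrapping lines are only an assumption about how your data might be stored:

# Sketch: make sure both inputs are plain tensors before building the dataset.
# If train_x / train_y are numpy arrays or Python lists:
train_x = torch.FloatTensor(train_x)
train_y = torch.FloatTensor(train_y)

# If they are old-style autograd Variables, unwrap the underlying tensor instead:
# train_x = train_x.data
# train_y = train_y.data

train_dataset = Data.TensorDataset(train_x, train_y)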
train_loader = Data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True
)
This instantiates a DataLoader object.
for epoch in range(epochs):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        batch_x, batch_y = Variable(batch_x), Variable(batch_y)
With that, batch training works.
Note that train_loader yields tensors; when training the network, you need to turn them into Variables.
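To make that concrete, here is a minimal sketch of a full training step built around the loop above. The net, optimizer, and loss_func objects are hypothetical placeholders (a small nn.Sequential model, torch.optim.SGD, and nn.MSELoss), not part of the original post, and the .view(-1, 1) reshaping assumes 1-D samples like the x data in the first example. On PyTorch 0.4 and later, Variable was merged into Tensor, so the wrapping is a harmless no-op there.

import torch
import torch.nn as nn
from torch.autograd import Variable

# Hypothetical model, optimizer, and loss — placeholders for your own network.
net = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)
loss_func = nn.MSELoss()

for epoch in range(epochs):
    for step, (batch_x, batch_y) in enumerate(train_loader):
        # train_loader yields tensors; wrap them (no-op on PyTorch >= 0.4)
        batch_x, batch_y = Variable(batch_x), Variable(batch_y)

        prediction = net(batch_x.view(-1, 1))            # forward pass
        loss = loss_func(prediction, batch_y.view(-1, 1))

        optimizer.zero_grad()   # clear gradients from the previous step
        loss.backward()         # backpropagate
        optimizer.step()        # update parameters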
That is everything in "An example analysis of batch training (batch_train) operations in PyTorch". Thanks for reading! I hope the content shared here is helpful; for more on related topics, follow the Yisu Cloud industry news channel.