中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

Pytorch如何實現手寫數字mnist識別功能

發布時間:2021-05-24 13:47:44 來源:億速云 閱讀:173 作者:小新 欄目:開發技術

這篇文章給大家分享的是有關Pytorch如何實現手寫數字mnist識別功能的內容。小編覺得挺實用的,因此分享給大家做個參考,一起跟隨小編過來看看吧。

本文實例講述了Pytorch實現的手寫數字mnist識別功能。分享給大家供大家參考,具體如下:

import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import argparse
# 定義是否使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 定義網絡結構
class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(   #input_size=(1*28*28)
      nn.Conv2d(1, 6, 5, 1, 2), #padding=2保證輸入輸出尺寸相同
      nn.ReLU(),   #input_size=(6*28*28)
      nn.MaxPool2d(kernel_size=2, stride=2),#output_size=(6*14*14)
    )
    self.conv2 = nn.Sequential(
      nn.Conv2d(6, 16, 5),
      nn.ReLU(),   #input_size=(16*10*10)
      nn.MaxPool2d(2, 2) #output_size=(16*5*5)
    )
    self.fc1 = nn.Sequential(
      nn.Linear(16 * 5 * 5, 120),
      nn.ReLU()
    )
    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.ReLU()
    )
    self.fc3 = nn.Linear(84, 10)
  # 定義前向傳播過程,輸入為x
  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    # nn.Linear()的輸入輸出都是維度為一的值,所以要把多維度的tensor展平成一維
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x
#使得我們能夠手動輸入命令行參數,就是讓風格變得和Linux命令行差不多
parser = argparse.ArgumentParser()
parser.add_argument('--outf', default='./model/', help='folder to output images and model checkpoints') #模型保存路徑
parser.add_argument('--net', default='./model/net.pth', help="path to netG (to continue training)") #模型加載路徑
opt = parser.parse_args()
# 超參數設置
EPOCH = 8  #遍歷數據集次數
BATCH_SIZE = 64   #批處理尺寸(batch_size)
LR = 0.001    #學習率
# 定義數據預處理方式
transform = transforms.ToTensor()
# 定義訓練數據集
trainset = tv.datasets.MNIST(
  root='./data/',
  train=True,
  download=True,
  transform=transform)
# 定義訓練批處理數據
trainloader = torch.utils.data.DataLoader(
  trainset,
  batch_size=BATCH_SIZE,
  shuffle=True,
  )
# 定義測試數據集
testset = tv.datasets.MNIST(
  root='./data/',
  train=False,
  download=True,
  transform=transform)
# 定義測試批處理數據
testloader = torch.utils.data.DataLoader(
  testset,
  batch_size=BATCH_SIZE,
  shuffle=False,
  )
# 定義損失函數loss function 和優化方式(采用SGD)
net = LeNet().to(device)
criterion = nn.CrossEntropyLoss() # 交叉熵損失函數,通常用于多分類問題上
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)
# 訓練
if __name__ == "__main__":
  for epoch in range(EPOCH):
    sum_loss = 0.0
    # 數據讀取
    for i, data in enumerate(trainloader):
      inputs, labels = data
      inputs, labels = inputs.to(device), labels.to(device)
      # 梯度清零
      optimizer.zero_grad()
      # forward + backward
      outputs = net(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      # 每訓練100個batch打印一次平均loss
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d, %d] loss: %.03f'
           % (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0
    # 每跑完一次epoch測試一下準確率
    with torch.no_grad():
      correct = 0
      total = 0
      for data in testloader:
        images, labels = data
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        # 取得分最高的那個類
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
      print('第%d個epoch的識別準確率為:%d%%' % (epoch + 1, (100 * correct / total)))
  #torch.save(net.state_dict(), '%s/net_%03d.pth' % (opt.outf, epoch + 1))

pytorch的優點

1.PyTorch是相當簡潔且高效快速的框架;2.設計追求最少的封裝;3.設計符合人類思維,它讓用戶盡可能地專注于實現自己的想法;4.與google的Tensorflow類似,FAIR的支持足以確保PyTorch獲得持續的開發更新;5.PyTorch作者親自維護的論壇 供用戶交流和求教問題6.入門簡單

感謝各位的閱讀!關于“Pytorch如何實現手寫數字mnist識別功能”這篇文章就分享到這里了,希望以上內容可以對大家有一定的幫助,讓大家可以學到更多知識,如果覺得文章不錯,可以把它分享出去讓更多的人看到吧!

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

新源县| 花莲县| 天长市| 绩溪县| 桑日县| 灌云县| 玉龙| 忻州市| 图们市| 上虞市| 三台县| 图木舒克市| 秀山| 杭州市| 淮南市| 淄博市| 汝城县| 南充市| 南乐县| 邓州市| 吉林省| 彭阳县| 西吉县| 珠海市| 开化县| 和田县| 无锡市| 略阳县| 齐河县| 沁水县| 溆浦县| 昌黎县| 沅江市| 贵阳市| 正镶白旗| 娱乐| 称多县| 辽宁省| 汤阴县| 永福县| 千阳县|