中文字幕av专区_日韩电影在线播放_精品国产精品久久一区免费式_av在线免费观看网站

溫馨提示×

溫馨提示×

您好,登錄后才能下訂單哦!

密碼登錄×
登錄注冊×
其他方式登錄
點擊 登錄注冊 即表示同意《億速云用戶服務條款》

Pytorch中怎么利用ResNet50實現圖像分類

發布時間:2021-08-10 15:15:58 來源:億速云 閱讀:382 作者:Leah 欄目:編程語言

這期內容當中小編將會給大家帶來有關Pytorch中怎么利用ResNet50實現圖像分類,文章內容豐富且以專業的角度為大家分析和敘述,閱讀完這篇文章希望大家可以有所收獲。


模型


Torchvision.models包里面包含了常見的各種基礎模型架構,主要包括:

AlexNet
VGG
ResNet
SqueezeNet
DenseNet
Inception v3
GoogLeNet
ShuffleNet v2
MobileNet v2
ResNeXt
Wide ResNet
MNASNet

這里我選擇了ResNet50,基于ImageNet訓練的基礎網絡來實現圖像分類, 網絡模型下載與加載如下:


  • model = torchvision.models.resnet50(pretrained=True).eval().cuda()

  • tf = transforms.Compose([

  •             transforms.Resize(256),

  •             transforms.CenterCrop(224),

  •             transforms.ToTensor(),

  •             transforms.Normalize(

  •             mean=[0.485, 0.456, 0.406],

  •             std=[0.229, 0.224, 0.225]

  •         )])


使用模型實現圖像分類


這里首先需要加載ImageNet的分類標簽,目的是最后顯示分類的文本標簽時候使用。然后對輸入圖像完成預處理,使用ResNet50模型實現分類預測,對預測結果解析之后,顯示標簽文本,完整的代碼演示如下:

 1with open('imagenet_classes.txt') as f:
2    labels = [line.strip() for line in f.readlines()]
3
4src = cv.imread("D:/images/space_shuttle.jpg") # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:,:,] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:,:,] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9image = image.transpose((2, 0, 1))
10input_x = torch.from_numpy(image).unsqueeze(0)
11print(input_x.size())
12pred = model(input_x.cuda())
13pred_index = torch.argmax(pred, 1).cpu().detach().numpy()
14print(pred_index)
15print("current predict class name : %s"%labels[pred_index[0]])
16cv.putText(src, labels[pred_index[0]], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
17cv.imshow("input", src)
18cv.waitKey(0)
19cv.destroyAllWindows()

運行結果如下:

Pytorch中怎么利用ResNet50實現圖像分類


轉ONNX支持


在torchvision中的模型基本上都可以轉換為ONNX格式,而且被OpenCV DNN模塊所支持,所以,很方便的可以對torchvision自帶的模型轉為ONNX,實現OpenCV DNN的調用,首先轉為ONNX模型,直接使用torch.onnx.export即可轉換(還不知道怎么轉,快點看前面的例子)。轉換之后使用OpenCV DNN調用的代碼如下:

 1with open('imagenet_classes.txt') as f:
2    labels = [line.strip() for line in f.readlines()]
3net = cv.dnn.readNetFromONNX("resnet.onnx")
4src = cv.imread("D:/images/messi.jpg")  # aeroplane.jpg
5image = cv.resize(src, (224, 224))
6image = np.float32(image) / 255.0
7image[:, :, ] -= (np.float32(0.485), np.float32(0.456), np.float32(0.406))
8image[:, :, ] /= (np.float32(0.229), np.float32(0.224), np.float32(0.225))
9blob = cv.dnn.blobFromImage(image, 1.0, (224, 224), (0, 0, 0), False)
10net.setInput(blob)
11probs = net.forward()
12index = np.argmax(probs)
13cv.putText(src, labels[index], (50, 50), cv.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
14cv.imshow("input", src)
15cv.waitKey(0)
16cv.destroyAllWindows()


上述就是小編為大家分享的Pytorch中怎么利用ResNet50實現圖像分類了,如果剛好有類似的疑惑,不妨參照上述分析進行理解。如果想知道更多相關知識,歡迎關注億速云行業資訊頻道。

向AI問一下細節

免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。

AI

澎湖县| 天镇县| 驻马店市| 石台县| 彝良县| 弋阳县| 德钦县| 新泰市| 望江县| 徐水县| 保靖县| 阿拉善左旗| 讷河市| 南汇区| 皋兰县| 聂荣县| 万全县| 南丰县| 泸州市| 泾源县| 巴里| 互助| 湖州市| 丹凤县| 施秉县| 广河县| 福海县| 韩城市| 额敏县| 梧州市| 厦门市| 浙江省| 阳原县| 定结县| 德州市| 甘德县| 新兴县| 育儿| 武隆县| 锡林郭勒盟| 若羌县|