您好,登錄后才能下訂單哦!
這篇文章主要講解了“pytorch中計算準確率,召回率和F1值的方法”,文中的講解內容簡單清晰,易于學習與理解,下面請大家跟著小編的思路慢慢深入,一起來研究和學習“pytorch中計算準確率,召回率和F1值的方法”吧!
predict = output.argmax(dim = 1)
confusion_matrix =torch.zeros(2,2)
for t, p in zip(predict.view(-1), target.view(-1)):
confusion_matrix[t.long(), p.long()] += 1
a_p =(confusion_matrix.diag() / confusion_matrix.sum(1))[0]
b_p = (confusion_matrix.diag() / confusion_matrix.sum(1))[1]
a_r =(confusion_matrix.diag() / confusion_matrix.sum(0))[0]
b_r = (confusion_matrix.diag() / confusion_matrix.sum(0))[1]
補充:pytorch 查全率 recall 查準率 precision F1調和平均 準確率 accuracy
def eval():
net.eval()
test_loss = 0
correct = 0
total = 0
classnum = 9
target_num = torch.zeros((1,classnum))
predict_num = torch.zeros((1,classnum))
acc_num = torch.zeros((1,classnum))
for batch_idx, (inputs, targets) in enumerate(testloader):
if use_cuda:
inputs, targets = inputs.cuda(), targets.cuda()
inputs, targets = Variable(inputs, volatile=True), Variable(targets)
outputs = net(inputs)
loss = criterion(outputs, targets)
# loss is variable , if add it(+=loss) directly, there will be a bigger ang bigger graph.
test_loss += loss.data[0]
_, predicted = torch.max(outputs.data, 1)
total += targets.size(0)
correct += predicted.eq(targets.data).cpu().sum()
pre_mask = torch.zeros(outputs.size()).scatter_(1, predicted.cpu().view(-1, 1), 1.)
predict_num += pre_mask.sum(0)
tar_mask = torch.zeros(outputs.size()).scatter_(1, targets.data.cpu().view(-1, 1), 1.)
target_num += tar_mask.sum(0)
acc_mask = pre_mask*tar_mask
acc_num += acc_mask.sum(0)
recall = acc_num/target_num
precision = acc_num/predict_num
F1 = 2*recall*precision/(recall+precision)
accuracy = acc_num.sum(1)/target_num.sum(1)
#精度調整
recall = (recall.numpy()[0]*100).round(3)
precision = (precision.numpy()[0]*100).round(3)
F1 = (F1.numpy()[0]*100).round(3)
accuracy = (accuracy.numpy()[0]*100).round(3)
# 打印格式方便復制
print('recall'," ".join('%s' % id for id in recall))
print('precision'," ".join('%s' % id for id in precision))
print('F1'," ".join('%s' % id for id in F1))
print('accuracy',accuracy)
補充:Python scikit-learn,分類模型的評估,精確率和召回率,classification_report
分類模型的評估標準一般最常見使用的是準確率(estimator.score()),即預測結果正確的百分比。
準確率是相對所有分類結果;精確率、召回率、F1-score是相對于某一個分類的預測評估標準。
精確率(Precision):預測結果為正例樣本中真實為正例的比例(查的準)( )
召回率(Recall):真實為正例的樣本中預測結果為正例的比例(查的全)( )
分類的其他評估標準:F1-score,反映了模型的穩健型
demo.py(分類評估,精確率、召回率、F1-score,classification_report):
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import classification_report
# 加載數據集 從scikit-learn官網下載新聞數據集(共20個類別)
news = fetch_20newsgroups(subset='all') # all表示下載訓練集和測試集
# 進行數據分割 (劃分訓練集和測試集)
x_train, x_test, y_train, y_test = train_test_split(news.data, news.target, test_size=0.25)
# 對數據集進行特征抽取 (進行特征提取,將新聞文檔轉化成特征詞重要性的數字矩陣)
tf = TfidfVectorizer() # tf-idf表示特征詞的重要性
# 以訓練集數據統計特征詞的重要性 (從訓練集數據中提取特征詞)
x_train = tf.fit_transform(x_train)
print(tf.get_feature_names()) # ["condensed", "condescend", ...]
x_test = tf.transform(x_test) # 不需要重新fit()數據,直接按照訓練集提取的特征詞進行重要性統計。
# 進行樸素貝葉斯算法的預測
mlt = MultinomialNB(alpha=1.0) # alpha表示拉普拉斯平滑系數,默認1
print(x_train.toarray()) # toarray() 將稀疏矩陣以稠密矩陣的形式顯示。
'''
[[ 0. 0. 0. ..., 0.04234873 0. 0. ]
[ 0. 0. 0. ..., 0. 0. 0. ]
...,
[ 0. 0.03934786 0. ..., 0. 0. 0. ]
'''
mlt.fit(x_train, y_train) # 填充訓練集數據
# 預測類別
y_predict = mlt.predict(x_test)
print("預測的文章類別為:", y_predict) # [4 18 8 ..., 15 15 4]
# 準確率
print("準確率為:", mlt.score(x_test, y_test)) # 0.853565365025
print("每個類別的精確率和召回率:", classification_report(y_test, y_predict, target_names=news.target_names))
'''
precision recall f1-score support
alt.atheism 0.86 0.66 0.75 207
comp.graphics 0.85 0.75 0.80 238
sport.baseball 0.96 0.94 0.95 253
...,
'''
召回率的意義(應用場景):產品的不合格率(不想漏掉任何一個不合格的產品,查全);癌癥預測(不想漏掉任何一個癌癥患者)
感謝各位的閱讀,以上就是“pytorch中計算準確率,召回率和F1值的方法”的內容了,經過本文的學習后,相信大家對pytorch中計算準確率,召回率和F1值的方法這一問題有了更深刻的體會,具體使用情況還需要大家實踐驗證。這里是億速云,小編將為大家推送更多相關知識點的文章,歡迎關注!
免責聲明:本站發布的內容(圖片、視頻和文字)以原創、轉載和分享為主,文章觀點不代表本網站立場,如果涉及侵權請聯系站長郵箱:is@yisu.com進行舉報,并提供相關證據,一經查實,將立刻刪除涉嫌侵權內容。