提问者:小点点

你如何看待TensorFlow图像分类模型对每个类的准确性?


以下是tensor flow网站上的图像分类教程:https://www.tensorflow.org/tutorials/images/classification

该模型将花卉分为5类:雏菊、蒲公英、玫瑰、向日葵和郁金香。

我可以看出总体准确度是多少,但有没有办法知道每门课的准确度?

例如,我的模型可以很好地预测雏菊、蒲公英、玫瑰和向日葵(接近100%的准确率),郁金香(接近0%)很差,我想我仍然会看到80%的整体准确率(假设类是平衡的)。我需要知道单个类的准确度,以便将性能与预测所有类的模型区分为大约相等的80%准确度。


共2个答案

匿名用户

您只需在sklearn中使用分类报告就可以做到这一点。

参考文件

匿名用户

当我问这个问题时,我没有足够的python(或scikit)知识来回答。分类报告(如prashant0598所建议的)接近于我所需要的,尽管它实际上没有准确性。以下是如何使用分类报告:

from sklearn.metrics import classification_report
import pandas as pd

y_pred = model.predict(val_ds)
y_pred = np.argmax(y_pred, axis=1)

y_true = np.concatenate([y for x, y in val_ds], axis=0)

cr = classification_report(y_true, y_pred, output_dict=True, target_names=class_names)
pd.DataFrame.from_dict(cr)

分类报告输出(除其他外)准召,这有助于。

为了得到类的准确性,我们必须更手动地这样做。这里有一个方法:

from sklearn.metrics import accuracy_score

def class_accuracy(class_no):
  pred_filter = y_true==class_no
  acc = accuracy_score(y_true[pred_filter], y_pred[pred_filter])
  return acc

{class_name: class_accuracy(i) for i, class_name in enumerate(class_names)}

{'daisy':0.6589147286821705,
蒲公英:0.75,
玫瑰:0.6,
向日葵:0.868421052631579,
郁金香:0.6942675159235668}

所以现在我知道了,向日葵是最容易预测的,而玫瑰则特别狡猾!