提问者:小点点

如何计算多标签分类的汉明分数


我试图计算多标签文本分类的汉明损失和汉明分数

def hamming_score(y_true, y_pred, normalize=True, sample_weight=None):
    acc_list = []
    for i in range(y_true.shape[0]):
        set_true = set( np.where(y_true[i])[0] )
        set_pred = set( np.where(y_pred[i])[0] )
        tmp_a = None
        if len(set_true) == 0 and len(set_pred) == 0:
        tmp_a = 1
        else:
            tmp_a = len(set_true.intersection(set_pred))/float(len(set_true.union(set_pred)))
        acc_list.append(tmp_a)
    return np.mean(acc_list)

def print_score(y_pred, clf):
    print("Clf: ", clf.__class__.__name__)
    print("Hamming loss: {}".format(hamming_loss(y_pred, y_test)))
    print("Hamming score: {}".format(hamming_score(y_pred, y_test)))
    print("---")   

nb_clf = MultinomialNB()
sgd = SGDClassifier(loss='hinge', penalty='l2', alpha=1e-3, random_state=42, max_iter=6, tol=None)
lr = LogisticRegression()
mn = MultinomialNB()

for classifier in [nb_clf, sgd, lr, mn]:
    clf = OneVsRestClassifier(classifier)
    clf.fit(x_train, y_train)
    y_pred = clf.predict(x_test)
    print_score(y_pred, classifier)

汉明损失的结果发生了,但汉明分数出现了错误,有人能帮我解决这个问题吗?非常感谢。

Clf:多项式NB汉明损失:0.01911111111-----------------------------------------------8 Clf中的ValueError回溯(最近一次调用)。安装(x_系列,y_系列)9 y_pred=clf。预测(x_检验)---

<ipython-input-313-60ed43baa4c1> in print_score(y_pred, clf)
     21     print("Clf: ", clf.__class__.__name__)
     22     print("Hamming loss: {}".format(hamming_loss(y_pred, y_test)))
---> 23     print("Hamming score: {}".format(hamming_score(predictions, y_test)))
     24     print("---")

<ipython-input-313-60ed43baa4c1> in hamming_score(y_true, y_pred, normalize, sample_weight)
      8     acc_list = []
      9     for i in range(y_true.shape[0]):
---> 10         set_true = set( np.where(y_true[i])[0] )
     11         set_pred = set( np.where(y_pred[i])[0] )
     12         tmp_a = None

~\Anaconda3\lib\site-packages\scipy\sparse\base.py in __bool__(self)
    285             return self.nnz != 0
    286         else:
--> 287             raise ValueError("The truth value of an array with more than one "
    288                              "element is ambiguous. Use a.any() or a.all().")
    289     __nonzero__ = __bool__

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all().

共1个答案

匿名用户

您好,您可以尝试下面的代码,看看它是否工作

  print("Hamming score: {}".format(hamming_score(y_pred, y_test,normalize=False)))