
`num_labels` is needed only when `multi_label` is True.

Package: tensorflow
Exception Class: ValueError

Raise code

    self._built = False
    if self.multi_label:
      if num_labels:
        shape = tensor_shape.TensorShape([None, num_labels])
        self._build(shape)
    else:
      if num_labels:
        raise ValueError(
            '`num_labels` is needed only when `multi_label` is True.')
      self._build(None)

  @property
  def thresholds(self):
    """The thresholds used for evaluating AUC."""
    return list(self._thresholds)

  def _
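The branch above can be mirrored in plain Python to see which argument combinations the constructor accepts. This is a minimal sketch, not TensorFlow code: `resolve_auc_build_shape` is a hypothetical name, and it only reproduces the `if`/`else` logic visible in the excerpt.

```python
def resolve_auc_build_shape(multi_label=False, num_labels=None):
    """Hypothetical plain-Python mirror of the check in AUC.__init__.

    Returns (build_now, shape): whether the metric's state is built in the
    constructor, and with which shape (None means no fixed shape).
    """
    if multi_label:
        if num_labels:
            # Shape is known up front: any batch size, num_labels columns.
            return True, (None, num_labels)
        # Not built in the constructor (presumably built later, once shapes are known).
        return False, None
    if num_labels:
        # Same message TensorFlow raises for this combination.
        raise ValueError('`num_labels` is needed only when `multi_label` is True.')
    return True, None


print(resolve_auc_build_shape(multi_label=True, num_labels=2))  # (True, (None, 2))
try:
    resolve_auc_build_shape(num_labels=2)
except ValueError as err:
    print("ValueError:", err)
```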

Ways to fix


If `multi_label` is False (the default), `num_labels` must not be passed.
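If the arguments come from a config file or a hyperparameter search, one option is to normalize them before constructing the metric. This is a minimal sketch: `auc_kwargs` is a hypothetical helper, not part of TensorFlow, and it only builds a dict, so it needs no TensorFlow import.

```python
def auc_kwargs(num_labels=None, **extra):
    """Hypothetical helper: keyword arguments for tf.keras.metrics.AUC that
    keep num_labels and multi_label consistent."""
    kwargs = dict(extra)
    if num_labels:
        # num_labels is only accepted together with multi_label=True.
        if kwargs.get("multi_label") is False:
            raise ValueError("num_labels was given but multi_label=False")
        kwargs["multi_label"] = True
        kwargs["num_labels"] = num_labels
    return kwargs


print(auc_kwargs(num_thresholds=3))
print(auc_kwargs(num_labels=2, num_thresholds=3))
```

The result would then be unpacked into the constructor, e.g. tf.keras.metrics.AUC(**auc_kwargs(num_labels=2, num_thresholds=3)).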

Reproducing the error:

import tensorflow as tf

# multi_label defaults to False, so passing num_labels raises ValueError here
m = tf.keras.metrics.AUC(num_thresholds=3, num_labels=2)
m.update_state([[0, 1], [1, 0], [1, 0], [0, 1]], [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]])
print(m.result().numpy())

The error output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-36-9666877749e9> in <module>()
----> 1 m = tf.keras.metrics.AUC(num_thresholds=3,num_labels=2)
      2 m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
      3 print( m.result().numpy())

/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py in __init__(self, num_thresholds, curve, summation_method, name, dtype, thresholds, multi_label, num_labels, label_weights, from_logits)
   2134       if num_labels:
   2135         raise ValueError(
-> 2136             '`num_labels` is needed only when `multi_label` is True.')
   2137       self._build(None)
   2138 

ValueError: `num_labels` is needed only when `multi_label` is True.

Fix:

`multi_label` defaults to False, so set it explicitly to True whenever you pass `num_labels`. If you only need a single-label AUC, drop `num_labels` instead.

import tensorflow as tf

# num_labels is accepted because multi_label=True is set explicitly
m = tf.keras.metrics.AUC(num_thresholds=3, num_labels=2, multi_label=True)
m.update_state([[0, 1], [1, 0], [1, 0], [0, 1]], [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]])
print(m.result().numpy())


Output:

0.5
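To see where the 0.5 comes from, the sketch below recomputes it in plain Python without TensorFlow. It assumes Keras-like semantics: for num_thresholds=3 the thresholds are [0 - 1e-7, 0.5, 1 + 1e-7], a prediction counts as positive when it is strictly greater than a threshold, the ROC area uses the trapezoidal rule, and labels are averaged without weights. All helper names are hypothetical.

```python
# Plain-Python recomputation of the multi-label ROC AUC from the fix above.
# Assumed (Keras-like) semantics: thresholds [0 - eps, ..., 1 + eps],
# "positive" means prediction > threshold, trapezoidal summation,
# unweighted mean over labels. Helper names are hypothetical.

EPS = 1e-7  # assumed to match the Keras backend epsilon


def make_thresholds(num_thresholds):
    inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    return [0.0 - EPS] + inner + [1.0 + EPS]


def safe_div(a, b):
    # Return 0.0 instead of failing when a label has no positives/negatives.
    return a / b if b else 0.0


def roc_points(y_true, y_pred, thresholds):
    """Return (fpr, tpr) lists with one entry per threshold."""
    fpr, tpr = [], []
    for t in thresholds:
        tp = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p > t)
        fn = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p <= t)
        fp = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p > t)
        tn = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p <= t)
        tpr.append(safe_div(tp, tp + fn))
        fpr.append(safe_div(fp, fp + tn))
    return fpr, tpr


def trapezoid_auc(fpr, tpr):
    # Thresholds increase, so FPR is non-increasing and each width is >= 0.
    return sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2
               for i in range(len(fpr) - 1))


def multi_label_auc(y_true, y_pred, num_thresholds=3):
    thresholds = make_thresholds(num_thresholds)
    num_labels = len(y_true[0])
    per_label = []
    for j in range(num_labels):
        col_true = [row[j] for row in y_true]
        col_pred = [row[j] for row in y_pred]
        per_label.append(trapezoid_auc(*roc_points(col_true, col_pred, thresholds)))
    return sum(per_label) / num_labels, per_label


y_true = [[0, 1], [1, 0], [1, 0], [0, 1]]
y_pred = [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]]
mean_auc, per_label = multi_label_auc(y_true, y_pred)
print(per_label, mean_auc)  # [0.25, 0.75] 0.5
```

Under these assumptions label 0 scores 0.25 and label 1 scores 0.75, and their mean matches the 0.5 printed by the metric.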

Answered Jul 09, 2021 by kellemnegasi
