`num_labels` is needed only when `multi_label` is True.
Package:
tensorflow

Exception Class:
ValueError
Raise code
    ._built = False
    if self.multi_label:
      if num_labels:
        shape = tensor_shape.TensorShape([None, num_labels])
        self._build(shape)
    else:
      if num_labels:
        raise ValueError(
            '`num_labels` is needed only when `multi_label` is True.')
      self._build(None)

  @property
  def thresholds(self):
    """The thresholds used for evaluating AUC."""
    return list(self._thresholds)

  def _
Links to the raise (1)
https://github.com/tensorflow/tensorflow/blob/289d93bc1260ba92e0a3360f1edafe4f2e10a248/tensorflow/python/keras/metrics.py#L2165
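The branch in the raise code can be mirrored in plain Python to see exactly which argument combinations reach the `raise`. This is a hedged sketch, not the TensorFlow source: `validate_auc_args` is a hypothetical name, and it returns a short description of what the snippet does instead of creating any variables. The "build deferred" case only reflects that the snippet skips `_build` when `multi_label` is True and `num_labels` is not given; when the real metric builds in that case is an assumption not shown in the snippet.

```python
from typing import Optional


def validate_auc_args(multi_label: bool = False,
                      num_labels: Optional[int] = None) -> str:
    """Hypothetical mirror of the multi_label/num_labels branch in AUC.__init__.

    Returns a description of the outcome instead of creating TensorFlow state.
    """
    if multi_label:
        if num_labels:
            # Label count known up front: build with shape [None, num_labels].
            return f"build with shape [None, {num_labels}]"
        # The snippet makes no _build call here, so building is deferred (assumed).
        return "build deferred"
    if num_labels:
        # Same condition that triggers the ValueError in metrics.py.
        raise ValueError(
            '`num_labels` is needed only when `multi_label` is True.')
    # Flattened (single-label) case: _build(None) is called immediately.
    return "build with shape None"


if __name__ == "__main__":
    print(validate_auc_args(multi_label=True, num_labels=2))
    print(validate_auc_args())
    try:
        validate_auc_args(num_labels=2)  # multi_label defaults to False
    except ValueError as err:
        print("ValueError:", err)
```

Only the combination `multi_label=False` with a truthy `num_labels` raises; every other combination is accepted.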
Ways to fix
If `multi_label` is False (the default), `num_labels` must not be given.
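One way to respect this rule when metric arguments are assembled programmatically is a small helper that only forwards `num_labels` together with `multi_label=True`. The helper `auc_kwargs` and its `per_label` flag are hypothetical; the returned dictionary is meant to be unpacked into `tf.keras.metrics.AUC(**...)`, using the `num_thresholds`, `multi_label` and `num_labels` keywords that appear in the `__init__` signature in the traceback below. The default of 200 thresholds matches the Keras default.

```python
from typing import Any, Dict, Optional


def auc_kwargs(num_thresholds: int = 200,
               per_label: bool = False,
               num_labels: Optional[int] = None) -> Dict[str, Any]:
    """Hypothetical helper: keyword arguments for tf.keras.metrics.AUC that
    never pass num_labels without multi_label=True."""
    kwargs: Dict[str, Any] = {"num_thresholds": num_thresholds}
    if per_label:
        # Per-label AUC: multi_label must be True; num_labels stays optional.
        kwargs["multi_label"] = True
        if num_labels:
            kwargs["num_labels"] = num_labels
    # Flattened AUC: num_labels is deliberately dropped, because passing it
    # while multi_label is False raises the ValueError described here.
    return kwargs


if __name__ == "__main__":
    # e.g. tf.keras.metrics.AUC(**auc_kwargs(3, per_label=True, num_labels=2))
    print(auc_kwargs(3, per_label=True, num_labels=2))
    print(auc_kwargs(3, num_labels=2))
```

Silently dropping `num_labels` in the flattened case is a design choice; raising early, as Keras does, is the stricter alternative.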
Reproducing the error:
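import tensorflow as tf

# multi_label is left at its default (False), so passing num_labels raises ValueError here.
m = tf.keras.metrics.AUC(num_thresholds=3, num_labels=2)
m.update_state([[0, 1], [1, 0], [1, 0], [0, 1]], [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]])
print(m.result().numpy())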
The error output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-36-9666877749e9> in <module>()
----> 1 m = tf.keras.metrics.AUC(num_thresholds=3,num_labels=2)
2 m.update_state([0, 0, 1, 1], [0, 0.5, 0.3, 0.9])
3 print( m.result().numpy())
/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/metrics.py in __init__(self, num_thresholds, curve, summation_method, name, dtype, thresholds, multi_label, num_labels, label_weights, from_logits)
2134 if num_labels:
2135 raise ValueError(
-> 2136 '`num_labels` is needed only when `multi_label` is True.')
2137 self._build(None)
2138
ValueError: `num_labels` is needed only when `multi_label` is True.
Fix:
`multi_label` is False by default, so set it explicitly to True whenever `num_labels` is given.
import tensorflow as tf

# multi_label=True makes num_labels valid: AUC is computed per label, then averaged.
m = tf.keras.metrics.AUC(num_thresholds=3, num_labels=2, multi_label=True)
m.update_state([[0, 1], [1, 0], [1, 0], [0, 1]], [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]])
print(m.result().numpy())
Output:
0.5
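With `multi_label=True`, the AUC is computed separately for each label (column) and then averaged across labels. The sketch below recomputes the printed value by hand in plain Python under stated assumptions: `num_thresholds=3` yields the thresholds `[-1e-7, 0.5, 1 + 1e-7]`, a prediction counts as positive when it is strictly greater than the threshold, and the ROC curve is integrated with the trapezoidal rule (assumed to match the default `summation_method`). It is a simplified re-implementation for illustration, not the Keras code path.

```python
from typing import List, Sequence

EPSILON = 1e-7  # assumed offset that pushes the end thresholds just outside [0, 1]


def make_thresholds(num_thresholds: int) -> List[float]:
    """Evenly spaced interior thresholds plus one point just outside each end."""
    inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
    return [0.0 - EPSILON] + inner + [1.0 + EPSILON]


def roc_auc(y_true: Sequence[int], y_pred: Sequence[float],
            thresholds: Sequence[float]) -> float:
    """Trapezoidal area under the ROC curve sampled at the given thresholds."""
    pos = sum(y_true)
    neg = len(y_true) - pos
    tpr, fpr = [], []
    for t in thresholds:
        tp = sum(1 for y, p in zip(y_true, y_pred) if y == 1 and p > t)
        fp = sum(1 for y, p in zip(y_true, y_pred) if y == 0 and p > t)
        tpr.append(tp / pos)
        fpr.append(fp / neg)
    # Thresholds increase, so FPR decreases; add up the trapezoids between points.
    return sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2
               for i in range(len(thresholds) - 1))


def multi_label_auc(y_true, y_pred, num_thresholds: int) -> float:
    """Compute the AUC of each label (column), then average across labels."""
    thresholds = make_thresholds(num_thresholds)
    num_labels = len(y_true[0])
    per_label = [
        roc_auc([row[j] for row in y_true], [row[j] for row in y_pred], thresholds)
        for j in range(num_labels)
    ]
    return sum(per_label) / num_labels


y_true = [[0, 1], [1, 0], [1, 0], [0, 1]]
y_pred = [[0, 1], [0.5, 0.5], [0.3, 0.7], [1, 0.9]]
print(multi_label_auc(y_true, y_pred, num_thresholds=3))  # → 0.5
```

Under these assumptions the first label scores 0.25 and the second 0.75, and their mean is the 0.5 printed above.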