Fix comments for metrics_to_scalars (#8782)
metrics_to_scalars can return non-float values, such as int or complex, depending on the dtype of the tensor.
This commit is contained in:
parent
e541803636
commit
f3442db3f0
|
@ -27,13 +27,13 @@ def metrics_to_scalars(metrics: Any) -> Any:
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
MisconfigurationException:
|
MisconfigurationException:
|
||||||
If ``value`` contains multiple elements, hence preventing conversion to ``float``
|
If tensors inside ``metrics`` contains multiple elements, hence preventing conversion to a scalar.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def to_item(value: torch.Tensor) -> numbers.Number:
|
def to_item(value: torch.Tensor) -> numbers.Number:
|
||||||
if value.numel() != 1:
|
if value.numel() != 1:
|
||||||
raise MisconfigurationException(
|
raise MisconfigurationException(
|
||||||
f"The metric `{value}` does not contain a single element" f" thus it cannot be converted to float."
|
f"The metric `{value}` does not contain a single element, thus it cannot be converted to a scalar."
|
||||||
)
|
)
|
||||||
return value.item()
|
return value.item()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue