Fix dangerous default argument (#8164)
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
This commit is contained in:
parent
9bd3747c71
commit
67f7e1318f
|
@ -66,7 +66,7 @@ def _class_test(
|
|||
metric_class: Metric,
|
||||
sk_metric: Callable,
|
||||
dist_sync_on_step: bool,
|
||||
metric_args: dict = {},
|
||||
metric_args: dict = None,
|
||||
check_dist_sync_on_step: bool = True,
|
||||
check_batch: bool = True,
|
||||
atol: float = 1e-8,
|
||||
|
@ -89,6 +89,8 @@ def _class_test(
|
|||
check_batch: bool, if true will check if the metric is also correctly
|
||||
calculated across devices for each batch (and not just at the end)
|
||||
"""
|
||||
if metric_args is None:
|
||||
metric_args = {}
|
||||
# Instanciate lightning metric
|
||||
metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
|
||||
|
||||
|
@ -130,7 +132,7 @@ def _functional_test(
|
|||
target: torch.Tensor,
|
||||
metric_functional: Callable,
|
||||
sk_metric: Callable,
|
||||
metric_args: dict = {},
|
||||
metric_args: dict = None,
|
||||
atol: float = 1e-8,
|
||||
):
|
||||
"""Utility function doing the actual comparison between lightning functional metric
|
||||
|
@ -143,6 +145,8 @@ def _functional_test(
|
|||
sk_metric: callable function that is used for comparison
|
||||
metric_args: dict with additional arguments used for class initialization
|
||||
"""
|
||||
if metric_args is None:
|
||||
metric_args = {}
|
||||
metric = partial(metric_functional, **metric_args)
|
||||
|
||||
for i in range(NUM_BATCHES):
|
||||
|
@ -185,7 +189,7 @@ class MetricTester:
|
|||
target: torch.Tensor,
|
||||
metric_functional: Callable,
|
||||
sk_metric: Callable,
|
||||
metric_args: dict = {},
|
||||
metric_args: dict = None,
|
||||
):
|
||||
"""Main method that should be used for testing functions. Call this inside
|
||||
testing method
|
||||
|
@ -197,6 +201,8 @@ class MetricTester:
|
|||
sk_metric: callable function that is used for comparison
|
||||
metric_args: dict with additional arguments used for class initialization
|
||||
"""
|
||||
if metric_args is None:
|
||||
metric_args = {}
|
||||
_functional_test(
|
||||
preds=preds,
|
||||
target=target,
|
||||
|
@ -214,7 +220,7 @@ class MetricTester:
|
|||
metric_class: Metric,
|
||||
sk_metric: Callable,
|
||||
dist_sync_on_step: bool,
|
||||
metric_args: dict = {},
|
||||
metric_args: dict = None,
|
||||
check_dist_sync_on_step: bool = True,
|
||||
check_batch: bool = True,
|
||||
):
|
||||
|
@ -235,6 +241,8 @@ class MetricTester:
|
|||
check_batch: bool, if true will check if the metric is also correctly
|
||||
calculated across devices for each batch (and not just at the end)
|
||||
"""
|
||||
if metric_args is None:
|
||||
metric_args = {}
|
||||
if ddp:
|
||||
if sys.platform == "win32":
|
||||
pytest.skip("DDP not supported on windows")
|
||||
|
|
Loading…
Reference in New Issue