Fix dangerous default argument (#8164)

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
deepsource-autofix[bot] 2021-06-28 09:52:37 +00:00 committed by GitHub
parent 9bd3747c71
commit 67f7e1318f
1 changed file with 12 additions and 4 deletions

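Why this is dangerous: Python evaluates a default argument once, at function definition time, so a `metric_args: dict = {}` default means every call that omits `metric_args` shares one and the same dict object, and any mutation leaks into later calls. A minimal sketch of the pitfall (a hypothetical `add_arg` helper, not code from this repository):

    def add_arg(key, value, metric_args: dict = {}):  # {} is created once, when the def runs
        metric_args[key] = value
        return metric_args

    print(add_arg("atol", 1e-8))  # {'atol': 1e-08}
    print(add_arg("rtol", 1e-5))  # {'atol': 1e-08, 'rtol': 1e-05}  <- first call's state leaked

The diff below applies exactly this fix in four places.
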
@@ -66,7 +66,7 @@ def _class_test(
     metric_class: Metric,
     sk_metric: Callable,
     dist_sync_on_step: bool,
-    metric_args: dict = {},
+    metric_args: dict = None,
     check_dist_sync_on_step: bool = True,
     check_batch: bool = True,
     atol: float = 1e-8,
@@ -89,6 +89,8 @@ def _class_test(
         check_batch: bool, if true will check if the metric is also correctly
             calculated across devices for each batch (and not just at the end)
     """
+    if metric_args is None:
+        metric_args = {}
     # Instantiate lightning metric
     metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args)
@@ -130,7 +132,7 @@ def _functional_test(
     target: torch.Tensor,
     metric_functional: Callable,
     sk_metric: Callable,
-    metric_args: dict = {},
+    metric_args: dict = None,
     atol: float = 1e-8,
 ):
     """Utility function doing the actual comparison between lightning functional metric
@@ -143,6 +145,8 @@ def _functional_test(
         sk_metric: callable function that is used for comparison
         metric_args: dict with additional arguments used for class initialization
     """
+    if metric_args is None:
+        metric_args = {}
     metric = partial(metric_functional, **metric_args)
 
     for i in range(NUM_BATCHES):
@@ -185,7 +189,7 @@ class MetricTester:
         target: torch.Tensor,
         metric_functional: Callable,
         sk_metric: Callable,
-        metric_args: dict = {},
+        metric_args: dict = None,
     ):
         """Main method that should be used for testing functions. Call this inside
         testing method
@@ -197,6 +201,8 @@ class MetricTester:
             sk_metric: callable function that is used for comparison
             metric_args: dict with additional arguments used for class initialization
         """
+        if metric_args is None:
+            metric_args = {}
         _functional_test(
             preds=preds,
             target=target,
@@ -214,7 +220,7 @@ class MetricTester:
         metric_class: Metric,
         sk_metric: Callable,
         dist_sync_on_step: bool,
-        metric_args: dict = {},
+        metric_args: dict = None,
         check_dist_sync_on_step: bool = True,
         check_batch: bool = True,
     ):
@@ -235,6 +241,8 @@ class MetricTester:
             check_batch: bool, if true will check if the metric is also correctly
                 calculated across devices for each batch (and not just at the end)
         """
+        if metric_args is None:
+            metric_args = {}
         if ddp:
             if sys.platform == "win32":
                 pytest.skip("DDP not supported on windows")
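
The fix used throughout is the standard None-sentinel idiom: default to `None` and create a fresh dict inside the function body, so each call gets its own object. The same hypothetical helper from above, rewritten the way this commit rewrites `_class_test`, `_functional_test`, and the `MetricTester` methods:

    from typing import Optional

    def add_arg(key, value, metric_args: Optional[dict] = None):
        if metric_args is None:
            metric_args = {}  # fresh dict on every call; nothing shared between calls
        metric_args[key] = value
        return metric_args

    print(add_arg("atol", 1e-8))  # {'atol': 1e-08}
    print(add_arg("rtol", 1e-5))  # {'rtol': 1e-05}  <- independent of the first call

One nuance: the diff keeps the annotation as `dict` even though the default is now `None`; `Optional[dict]` (as in the sketch) would be the more precise hint, but the runtime behavior is fixed either way.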