From 67f7e1318f1f6ed6325c41814bfa304ce73bd881 Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Mon, 28 Jun 2021 09:52:37 +0000 Subject: [PATCH] Fix dangerous default argument (#8164) Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> --- tests/metrics/utils.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index f1f17d0624..29c530953f 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -66,7 +66,7 @@ def _class_test( metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, - metric_args: dict = {}, + metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, atol: float = 1e-8, @@ -89,6 +89,8 @@ def _class_test( check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) """ + if metric_args is None: + metric_args = {} # Instanciate lightning metric metric = metric_class(compute_on_step=True, dist_sync_on_step=dist_sync_on_step, **metric_args) @@ -130,7 +132,7 @@ def _functional_test( target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {}, + metric_args: dict = None, atol: float = 1e-8, ): """Utility function doing the actual comparison between lightning functional metric @@ -143,6 +145,8 @@ def _functional_test( sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization """ + if metric_args is None: + metric_args = {} metric = partial(metric_functional, **metric_args) for i in range(NUM_BATCHES): @@ -185,7 +189,7 @@ class MetricTester: target: torch.Tensor, metric_functional: Callable, sk_metric: Callable, - metric_args: dict = {}, + metric_args: dict = None, ): """Main method that should be used for testing functions. Call this inside testing method @@ -197,6 +201,8 @@ class MetricTester: sk_metric: callable function that is used for comparison metric_args: dict with additional arguments used for class initialization """ + if metric_args is None: + metric_args = {} _functional_test( preds=preds, target=target, @@ -214,7 +220,7 @@ class MetricTester: metric_class: Metric, sk_metric: Callable, dist_sync_on_step: bool, - metric_args: dict = {}, + metric_args: dict = None, check_dist_sync_on_step: bool = True, check_batch: bool = True, ): @@ -235,6 +241,8 @@ class MetricTester: check_batch: bool, if true will check if the metric is also correctly calculated across devices for each batch (and not just at the end) """ + if metric_args is None: + metric_args = {} if ddp: if sys.platform == "win32": pytest.skip("DDP not supported on windows")