From c0782ffd1febbbe9267cba090e84950260ab61a0 Mon Sep 17 00:00:00 2001 From: "deepsource-autofix[bot]" <62050782+deepsource-autofix[bot]@users.noreply.github.com> Date: Wed, 30 Jun 2021 11:40:13 +0000 Subject: [PATCH] Remove unnecessary generator (#8154) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli --- pytorch_lightning/trainer/connectors/callback_connector.py | 4 ++-- pytorch_lightning/utilities/argparse.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 5 +++-- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 75cd74b307..2b14a229ce 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -153,8 +153,8 @@ class CallbackConnector: model_callbacks = model.configure_callbacks() if not model_callbacks: return - model_callback_types = set(type(c) for c in model_callbacks) - trainer_callback_types = set(type(c) for c in trainer.callbacks) + model_callback_types = {type(c) for c in model_callbacks} + trainer_callback_types = {type(c) for c in trainer.callbacks} override_types = model_callback_types.intersection(trainer_callback_types) if override_types: rank_zero_info( diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index b6bc30b8ea..aebbcb41ac 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -46,7 +46,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs): # we only want to pass in valid Trainer args, the rest may be user specific valid_kwargs = inspect.signature(cls.__init__).parameters - trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params) + trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params} trainer_kwargs.update(**kwargs) return cls(**trainer_kwargs) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 8e2e52b55e..23e63d3883 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -823,7 +823,7 @@ def test_model_checkpoint_topk_all(tmpdir): assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt" assert checkpoint_callback.best_model_score == epochs - 1 assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs - assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)) + assert set(checkpoint_callback.best_k_models.keys()) == {str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)} assert checkpoint_callback.kth_best_model_path == tmpdir / 'epoch=0.ckpt' @@ -1275,7 +1275,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir): assert {Path(f).name for f in mc.best_k_models} == expected # check created ckpts - assert set(f.basename for f in tmpdir.listdir()) == { + actual = {f.basename for f in tmpdir.listdir()} + assert actual == { "epoch=0.ckpt", "epoch=1.ckpt", "epoch=0-v1.ckpt",