Remove unnecessary generator (#8154)

Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Authored by deepsource-autofix[bot] on 2021-06-30 11:40:13 +00:00; committed by GitHub
parent 74eb6cc7e9
commit c0782ffd1f
3 changed files with 6 additions and 5 deletions
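As context for the change (and not part of the patch itself), the following is a minimal, self-contained sketch of the pattern this commit applies: a generator expression wrapped in set() or dict() builds the same result as the corresponding comprehension, but the comprehension avoids creating an intermediate generator object and states the intent more directly. The callbacks list below is purely illustrative and does not come from the Lightning codebase.

# Illustrative only: hypothetical data, not taken from the patched files.
callbacks = ["EarlyStopping", "ModelCheckpoint", "EarlyStopping"]

# Before: generator expressions fed to the set() and dict() constructors.
types_before = set(type(c) for c in callbacks)
lengths_before = dict((c, len(c)) for c in callbacks)

# After: equivalent set and dict comprehensions.
types_after = {type(c) for c in callbacks}
lengths_after = {c: len(c) for c in callbacks}

assert types_before == types_after
assert lengths_before == lengths_after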


@@ -153,8 +153,8 @@ class CallbackConnector:
         model_callbacks = model.configure_callbacks()
         if not model_callbacks:
             return
-        model_callback_types = set(type(c) for c in model_callbacks)
-        trainer_callback_types = set(type(c) for c in trainer.callbacks)
+        model_callback_types = {type(c) for c in model_callbacks}
+        trainer_callback_types = {type(c) for c in trainer.callbacks}
         override_types = model_callback_types.intersection(trainer_callback_types)
         if override_types:
             rank_zero_info(


@@ -46,7 +46,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
     # we only want to pass in valid Trainer args, the rest may be user specific
     valid_kwargs = inspect.signature(cls.__init__).parameters
-    trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
+    trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
     trainer_kwargs.update(**kwargs)
     return cls(**trainer_kwargs)


@@ -823,7 +823,7 @@ def test_model_checkpoint_topk_all(tmpdir):
     assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt"
     assert checkpoint_callback.best_model_score == epochs - 1
     assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs
-    assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs))
+    assert set(checkpoint_callback.best_k_models.keys()) == {str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)}
     assert checkpoint_callback.kth_best_model_path == tmpdir / 'epoch=0.ckpt'
@@ -1275,7 +1275,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir):
     assert {Path(f).name for f in mc.best_k_models} == expected
     # check created ckpts
-    assert set(f.basename for f in tmpdir.listdir()) == {
+    actual = {f.basename for f in tmpdir.listdir()}
+    assert actual == {
         "epoch=0.ckpt",
         "epoch=1.ckpt",
         "epoch=0-v1.ckpt",