Remove unnecessary generator (#8154)
Co-authored-by: deepsource-autofix[bot] <62050782+deepsource-autofix[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
parent
74eb6cc7e9
commit
c0782ffd1f
|
@ -153,8 +153,8 @@ class CallbackConnector:
|
|||
model_callbacks = model.configure_callbacks()
|
||||
if not model_callbacks:
|
||||
return
|
||||
model_callback_types = set(type(c) for c in model_callbacks)
|
||||
trainer_callback_types = set(type(c) for c in trainer.callbacks)
|
||||
model_callback_types = {type(c) for c in model_callbacks}
|
||||
trainer_callback_types = {type(c) for c in trainer.callbacks}
|
||||
override_types = model_callback_types.intersection(trainer_callback_types)
|
||||
if override_types:
|
||||
rank_zero_info(
|
||||
|
|
|
@ -46,7 +46,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
|
|||
|
||||
# we only want to pass in valid Trainer args, the rest may be user specific
|
||||
valid_kwargs = inspect.signature(cls.__init__).parameters
|
||||
trainer_kwargs = dict((name, params[name]) for name in valid_kwargs if name in params)
|
||||
trainer_kwargs = {name: params[name] for name in valid_kwargs if name in params}
|
||||
trainer_kwargs.update(**kwargs)
|
||||
|
||||
return cls(**trainer_kwargs)
|
||||
|
|
|
@ -823,7 +823,7 @@ def test_model_checkpoint_topk_all(tmpdir):
|
|||
assert checkpoint_callback.best_model_path == tmpdir / "epoch=2.ckpt"
|
||||
assert checkpoint_callback.best_model_score == epochs - 1
|
||||
assert len(os.listdir(tmpdir)) == len(checkpoint_callback.best_k_models) == epochs
|
||||
assert set(checkpoint_callback.best_k_models.keys()) == set(str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs))
|
||||
assert set(checkpoint_callback.best_k_models.keys()) == {str(tmpdir / f"epoch={i}.ckpt") for i in range(epochs)}
|
||||
assert checkpoint_callback.kth_best_model_path == tmpdir / 'epoch=0.ckpt'
|
||||
|
||||
|
||||
|
@ -1275,7 +1275,8 @@ def test_ckpt_version_after_rerun_new_trainer(tmpdir):
|
|||
assert {Path(f).name for f in mc.best_k_models} == expected
|
||||
|
||||
# check created ckpts
|
||||
assert set(f.basename for f in tmpdir.listdir()) == {
|
||||
actual = {f.basename for f in tmpdir.listdir()}
|
||||
assert actual == {
|
||||
"epoch=0.ckpt",
|
||||
"epoch=1.ckpt",
|
||||
"epoch=0-v1.ckpt",
|
||||
|
|
Loading…
Reference in New Issue