diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index f4f2ed79f6..0bf14ea4d3 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -89,19 +89,15 @@ class Trainer: @_defaults_from_env_vars def __init__( self, - logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, - enable_checkpointing: Optional[bool] = None, - callbacks: Optional[Union[List[Callback], Callback]] = None, - default_root_dir: Optional[_PATH] = None, - gradient_clip_val: Optional[Union[int, float]] = None, - gradient_clip_algorithm: Optional[str] = None, - num_nodes: int = 1, + *, + accelerator: Union[str, Accelerator] = "auto", + strategy: Union[str, Strategy] = "auto", devices: Union[List[int], str, int] = "auto", - enable_progress_bar: Optional[bool] = None, - overfit_batches: Union[int, float] = 0.0, - check_val_every_n_epoch: Optional[int] = 1, + num_nodes: int = 1, + precision: _PRECISION_INPUT = "32-true", + logger: Optional[Union[Logger, Iterable[Logger], bool]] = None, + callbacks: Optional[Union[List[Callback], Callback]] = None, fast_dev_run: Union[int, bool] = False, - accumulate_grad_batches: int = 1, max_epochs: Optional[int] = None, min_epochs: Optional[int] = None, max_steps: int = -1, @@ -111,23 +107,28 @@ class Trainer: limit_val_batches: Optional[Union[int, float]] = None, limit_test_batches: Optional[Union[int, float]] = None, limit_predict_batches: Optional[Union[int, float]] = None, + overfit_batches: Union[int, float] = 0.0, val_check_interval: Optional[Union[int, float]] = None, - log_every_n_steps: Optional[int] = None, - accelerator: Union[str, Accelerator] = "auto", - strategy: Union[str, Strategy] = "auto", - sync_batchnorm: bool = False, - precision: _PRECISION_INPUT = "32-true", - enable_model_summary: Optional[bool] = None, + check_val_every_n_epoch: Optional[int] = 1, num_sanity_val_steps: Optional[int] = None, - profiler: Optional[Union[Profiler, str]] = None, - benchmark: Optional[bool] = None, + log_every_n_steps: Optional[int] = None, + enable_checkpointing: Optional[bool] = None, + enable_progress_bar: Optional[bool] = None, + enable_model_summary: Optional[bool] = None, + accumulate_grad_batches: int = 1, + gradient_clip_val: Optional[Union[int, float]] = None, + gradient_clip_algorithm: Optional[str] = None, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, - reload_dataloaders_every_n_epochs: int = 0, - use_distributed_sampler: bool = True, - detect_anomaly: bool = False, - plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + benchmark: Optional[bool] = None, inference_mode: bool = True, + use_distributed_sampler: bool = True, + profiler: Optional[Union[Profiler, str]] = None, + detect_anomaly: bool = False, barebones: bool = False, + plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None, + sync_batchnorm: bool = False, + reload_dataloaders_every_n_epochs: int = 0, + default_root_dir: Optional[_PATH] = None, ) -> None: r""" Customize every aspect of training via flags. diff --git a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py index 6800d622e8..ea87ba508a 100644 --- a/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_rich_progress_bar.py @@ -102,7 +102,7 @@ def test_rich_progress_bar_import_error(monkeypatch): @RunIf(rich=True) -def test_rich_progress_bar_custom_theme(tmpdir): +def test_rich_progress_bar_custom_theme(): """Test to ensure that custom theme styles are used.""" with mock.patch.multiple( "lightning.pytorch.callbacks.progress.rich_progress", @@ -114,7 +114,7 @@ def test_rich_progress_bar_custom_theme(tmpdir): theme = RichProgressBarTheme() progress_bar = RichProgressBar(theme=theme) - progress_bar.on_train_start(Trainer(tmpdir), BoringModel()) + progress_bar.on_train_start(Trainer(), BoringModel()) assert progress_bar.theme == theme args, kwargs = mocks["CustomBarColumn"].call_args diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py index 7145fa6ae9..62c94d4cc2 100644 --- a/tests/tests_pytorch/trainer/flags/test_env_vars.py +++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py @@ -43,7 +43,7 @@ def test_passing_env_variables_only(): @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) def test_passing_env_variables_defaults(): """Testing overwriting trainer arguments.""" - trainer = Trainer(False, max_steps=42) + trainer = Trainer(logger=False, max_steps=42) assert trainer.logger is None assert trainer.max_steps == 42