Sort Trainer arguments based on importance (#17022)

commit b9591d91ee
parent 6fcccea3fa
Author: Adrian Wälchli
Date:   2023-03-11 00:53:38 +01:00 (committed by GitHub)
3 changed files with 27 additions and 26 deletions

View File

@@ -89,19 +89,15 @@ class Trainer:
     @_defaults_from_env_vars
     def __init__(
         self,
-        logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
-        enable_checkpointing: Optional[bool] = None,
-        callbacks: Optional[Union[List[Callback], Callback]] = None,
-        default_root_dir: Optional[_PATH] = None,
-        gradient_clip_val: Optional[Union[int, float]] = None,
-        gradient_clip_algorithm: Optional[str] = None,
-        num_nodes: int = 1,
+        *,
+        accelerator: Union[str, Accelerator] = "auto",
+        strategy: Union[str, Strategy] = "auto",
         devices: Union[List[int], str, int] = "auto",
-        enable_progress_bar: Optional[bool] = None,
-        overfit_batches: Union[int, float] = 0.0,
-        check_val_every_n_epoch: Optional[int] = 1,
+        num_nodes: int = 1,
+        precision: _PRECISION_INPUT = "32-true",
+        logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
+        callbacks: Optional[Union[List[Callback], Callback]] = None,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: int = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: int = -1,
@@ -111,23 +107,28 @@ class Trainer:
         limit_val_batches: Optional[Union[int, float]] = None,
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
+        overfit_batches: Union[int, float] = 0.0,
         val_check_interval: Optional[Union[int, float]] = None,
-        log_every_n_steps: Optional[int] = None,
-        accelerator: Union[str, Accelerator] = "auto",
-        strategy: Union[str, Strategy] = "auto",
-        sync_batchnorm: bool = False,
-        precision: _PRECISION_INPUT = "32-true",
-        enable_model_summary: Optional[bool] = None,
+        check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
-        profiler: Optional[Union[Profiler, str]] = None,
-        benchmark: Optional[bool] = None,
+        log_every_n_steps: Optional[int] = None,
+        enable_checkpointing: Optional[bool] = None,
+        enable_progress_bar: Optional[bool] = None,
+        enable_model_summary: Optional[bool] = None,
+        accumulate_grad_batches: int = 1,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
         deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
-        reload_dataloaders_every_n_epochs: int = 0,
-        use_distributed_sampler: bool = True,
-        detect_anomaly: bool = False,
-        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+        benchmark: Optional[bool] = None,
         inference_mode: bool = True,
+        use_distributed_sampler: bool = True,
+        profiler: Optional[Union[Profiler, str]] = None,
+        detect_anomaly: bool = False,
         barebones: bool = False,
+        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+        sync_batchnorm: bool = False,
+        reload_dataloaders_every_n_epochs: int = 0,
+        default_root_dir: Optional[_PATH] = None,
     ) -> None:
         r"""
         Customize every aspect of training via flags.
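
Aside from the reordering, note the new `*,` marker after `self`: every Trainer argument becomes keyword-only, which is why the two test diffs below drop their positional calls. A minimal sketch of the caller-visible effect (the exact TypeError wording is illustrative, not taken from this commit):

from lightning.pytorch import Trainer

# Under the old order, `logger` was the first positional parameter, so
# Trainer(False) used to mean Trainer(logger=False). With the `*,` marker,
# any positional argument now raises a TypeError.
try:
    Trainer(False)
except TypeError:
    pass  # e.g. "__init__() takes 1 positional argument but 2 were given"

trainer = Trainer(logger=False)  # arguments must now be spelled out by name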

View File

@@ -102,7 +102,7 @@ def test_rich_progress_bar_import_error(monkeypatch):
 @RunIf(rich=True)
-def test_rich_progress_bar_custom_theme(tmpdir):
+def test_rich_progress_bar_custom_theme():
     """Test to ensure that custom theme styles are used."""
     with mock.patch.multiple(
         "lightning.pytorch.callbacks.progress.rich_progress",
@@ -114,7 +114,7 @@ def test_rich_progress_bar_custom_theme(tmpdir):
         theme = RichProgressBarTheme()
         progress_bar = RichProgressBar(theme=theme)
-        progress_bar.on_train_start(Trainer(tmpdir), BoringModel())
+        progress_bar.on_train_start(Trainer(), BoringModel())
         assert progress_bar.theme == theme
         args, kwargs = mocks["CustomBarColumn"].call_args
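
The positional `Trainer(tmpdir)` call above is no longer accepted (under the old signature it bound `tmpdir` to `logger`), and the test needs no working directory at all, so a bare `Trainer()` suffices. A minimal sketch of the part this test pins down, assuming `rich` is installed:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import RichProgressBar
from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

# The callback keeps the theme it was constructed with; Trainer() falls
# back to its default root dir, so no tmpdir fixture is required.
theme = RichProgressBarTheme()
bar = RichProgressBar(theme=theme)
assert bar.theme == theme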

View File

@@ -43,7 +43,7 @@ def test_passing_env_variables_only():
 @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
 def test_passing_env_variables_defaults():
     """Testing overwriting trainer arguments."""
-    trainer = Trainer(False, max_steps=42)
+    trainer = Trainer(logger=False, max_steps=42)
     assert trainer.logger is None
     assert trainer.max_steps == 42
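
This test exercises the `@_defaults_from_env_vars` decorator visible in the first diff: `PL_TRAINER_<ARG>` environment variables fill in defaults, and explicit keyword arguments override them (hence `max_steps == 42` despite `PL_TRAINER_MAX_STEPS=7`). A minimal standalone sketch of that precedence:

import os
from unittest import mock

from lightning.pytorch import Trainer

with mock.patch.dict(os.environ, {"PL_TRAINER_MAX_STEPS": "7"}):
    assert Trainer().max_steps == 7               # env var supplies the default
    assert Trainer(max_steps=42).max_steps == 42  # explicit kwarg wins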