Sort Trainer arguments based on importance (#17022)
This commit is contained in:
parent
6fcccea3fa
commit
b9591d91ee
|
@ -89,19 +89,15 @@ class Trainer:
|
|||
@_defaults_from_env_vars
|
||||
def __init__(
|
||||
self,
|
||||
logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
|
||||
enable_checkpointing: Optional[bool] = None,
|
||||
callbacks: Optional[Union[List[Callback], Callback]] = None,
|
||||
default_root_dir: Optional[_PATH] = None,
|
||||
gradient_clip_val: Optional[Union[int, float]] = None,
|
||||
gradient_clip_algorithm: Optional[str] = None,
|
||||
num_nodes: int = 1,
|
||||
*,
|
||||
accelerator: Union[str, Accelerator] = "auto",
|
||||
strategy: Union[str, Strategy] = "auto",
|
||||
devices: Union[List[int], str, int] = "auto",
|
||||
enable_progress_bar: Optional[bool] = None,
|
||||
overfit_batches: Union[int, float] = 0.0,
|
||||
check_val_every_n_epoch: Optional[int] = 1,
|
||||
num_nodes: int = 1,
|
||||
precision: _PRECISION_INPUT = "32-true",
|
||||
logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
|
||||
callbacks: Optional[Union[List[Callback], Callback]] = None,
|
||||
fast_dev_run: Union[int, bool] = False,
|
||||
accumulate_grad_batches: int = 1,
|
||||
max_epochs: Optional[int] = None,
|
||||
min_epochs: Optional[int] = None,
|
||||
max_steps: int = -1,
|
||||
|
@ -111,23 +107,28 @@ class Trainer:
|
|||
limit_val_batches: Optional[Union[int, float]] = None,
|
||||
limit_test_batches: Optional[Union[int, float]] = None,
|
||||
limit_predict_batches: Optional[Union[int, float]] = None,
|
||||
overfit_batches: Union[int, float] = 0.0,
|
||||
val_check_interval: Optional[Union[int, float]] = None,
|
||||
log_every_n_steps: Optional[int] = None,
|
||||
accelerator: Union[str, Accelerator] = "auto",
|
||||
strategy: Union[str, Strategy] = "auto",
|
||||
sync_batchnorm: bool = False,
|
||||
precision: _PRECISION_INPUT = "32-true",
|
||||
enable_model_summary: Optional[bool] = None,
|
||||
check_val_every_n_epoch: Optional[int] = 1,
|
||||
num_sanity_val_steps: Optional[int] = None,
|
||||
profiler: Optional[Union[Profiler, str]] = None,
|
||||
benchmark: Optional[bool] = None,
|
||||
log_every_n_steps: Optional[int] = None,
|
||||
enable_checkpointing: Optional[bool] = None,
|
||||
enable_progress_bar: Optional[bool] = None,
|
||||
enable_model_summary: Optional[bool] = None,
|
||||
accumulate_grad_batches: int = 1,
|
||||
gradient_clip_val: Optional[Union[int, float]] = None,
|
||||
gradient_clip_algorithm: Optional[str] = None,
|
||||
deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
|
||||
reload_dataloaders_every_n_epochs: int = 0,
|
||||
use_distributed_sampler: bool = True,
|
||||
detect_anomaly: bool = False,
|
||||
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
|
||||
benchmark: Optional[bool] = None,
|
||||
inference_mode: bool = True,
|
||||
use_distributed_sampler: bool = True,
|
||||
profiler: Optional[Union[Profiler, str]] = None,
|
||||
detect_anomaly: bool = False,
|
||||
barebones: bool = False,
|
||||
plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
|
||||
sync_batchnorm: bool = False,
|
||||
reload_dataloaders_every_n_epochs: int = 0,
|
||||
default_root_dir: Optional[_PATH] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Customize every aspect of training via flags.
|
||||
|
|
|
@ -102,7 +102,7 @@ def test_rich_progress_bar_import_error(monkeypatch):
|
|||
|
||||
|
||||
@RunIf(rich=True)
|
||||
def test_rich_progress_bar_custom_theme(tmpdir):
|
||||
def test_rich_progress_bar_custom_theme():
|
||||
"""Test to ensure that custom theme styles are used."""
|
||||
with mock.patch.multiple(
|
||||
"lightning.pytorch.callbacks.progress.rich_progress",
|
||||
|
@ -114,7 +114,7 @@ def test_rich_progress_bar_custom_theme(tmpdir):
|
|||
theme = RichProgressBarTheme()
|
||||
|
||||
progress_bar = RichProgressBar(theme=theme)
|
||||
progress_bar.on_train_start(Trainer(tmpdir), BoringModel())
|
||||
progress_bar.on_train_start(Trainer(), BoringModel())
|
||||
|
||||
assert progress_bar.theme == theme
|
||||
args, kwargs = mocks["CustomBarColumn"].call_args
|
||||
|
|
|
@ -43,7 +43,7 @@ def test_passing_env_variables_only():
|
|||
@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
|
||||
def test_passing_env_variables_defaults():
|
||||
"""Testing overwriting trainer arguments."""
|
||||
trainer = Trainer(False, max_steps=42)
|
||||
trainer = Trainer(logger=False, max_steps=42)
|
||||
assert trainer.logger is None
|
||||
assert trainer.max_steps == 42
|
||||
|
||||
|
|
Loading…
Reference in New Issue