Sort Trainer arguments based on importance (#17022)
commit b9591d91ee (parent 6fcccea3fa)
@@ -89,19 +89,15 @@ class Trainer:
     @_defaults_from_env_vars
     def __init__(
         self,
-        logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
-        enable_checkpointing: Optional[bool] = None,
-        callbacks: Optional[Union[List[Callback], Callback]] = None,
-        default_root_dir: Optional[_PATH] = None,
-        gradient_clip_val: Optional[Union[int, float]] = None,
-        gradient_clip_algorithm: Optional[str] = None,
-        num_nodes: int = 1,
+        *,
+        accelerator: Union[str, Accelerator] = "auto",
+        strategy: Union[str, Strategy] = "auto",
         devices: Union[List[int], str, int] = "auto",
-        enable_progress_bar: Optional[bool] = None,
-        overfit_batches: Union[int, float] = 0.0,
-        check_val_every_n_epoch: Optional[int] = 1,
+        num_nodes: int = 1,
+        precision: _PRECISION_INPUT = "32-true",
+        logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
+        callbacks: Optional[Union[List[Callback], Callback]] = None,
         fast_dev_run: Union[int, bool] = False,
-        accumulate_grad_batches: int = 1,
         max_epochs: Optional[int] = None,
         min_epochs: Optional[int] = None,
         max_steps: int = -1,
@@ -111,23 +107,28 @@ class Trainer:
         limit_val_batches: Optional[Union[int, float]] = None,
         limit_test_batches: Optional[Union[int, float]] = None,
         limit_predict_batches: Optional[Union[int, float]] = None,
+        overfit_batches: Union[int, float] = 0.0,
         val_check_interval: Optional[Union[int, float]] = None,
-        log_every_n_steps: Optional[int] = None,
-        accelerator: Union[str, Accelerator] = "auto",
-        strategy: Union[str, Strategy] = "auto",
-        sync_batchnorm: bool = False,
-        precision: _PRECISION_INPUT = "32-true",
-        enable_model_summary: Optional[bool] = None,
+        check_val_every_n_epoch: Optional[int] = 1,
         num_sanity_val_steps: Optional[int] = None,
-        profiler: Optional[Union[Profiler, str]] = None,
-        benchmark: Optional[bool] = None,
+        log_every_n_steps: Optional[int] = None,
+        enable_checkpointing: Optional[bool] = None,
+        enable_progress_bar: Optional[bool] = None,
+        enable_model_summary: Optional[bool] = None,
+        accumulate_grad_batches: int = 1,
+        gradient_clip_val: Optional[Union[int, float]] = None,
+        gradient_clip_algorithm: Optional[str] = None,
         deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
-        reload_dataloaders_every_n_epochs: int = 0,
-        use_distributed_sampler: bool = True,
-        detect_anomaly: bool = False,
-        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+        benchmark: Optional[bool] = None,
         inference_mode: bool = True,
+        use_distributed_sampler: bool = True,
+        profiler: Optional[Union[Profiler, str]] = None,
+        detect_anomaly: bool = False,
         barebones: bool = False,
+        plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
+        sync_batchnorm: bool = False,
+        reload_dataloaders_every_n_epochs: int = 0,
+        default_root_dir: Optional[_PATH] = None,
     ) -> None:
         r"""
        Customize every aspect of training via flags.
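The bare `*,` inserted after `self` makes every Trainer argument keyword-only, which is what lets the arguments be re-sorted by importance without silently breaking positional callers. A minimal sketch of the calling convention this implies (the exact TypeError wording varies by Python version):

    from lightning.pytorch import Trainer

    # Keyword arguments work no matter where they sit in the signature.
    trainer = Trainer(accelerator="auto", devices=1, max_epochs=3)

    # Positional arguments are rejected by the keyword-only signature:
    # Trainer(False)  # TypeError: __init__() takes 1 positional argument but 2 were given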
@@ -102,7 +102,7 @@ def test_rich_progress_bar_import_error(monkeypatch):


 @RunIf(rich=True)
-def test_rich_progress_bar_custom_theme(tmpdir):
+def test_rich_progress_bar_custom_theme():
     """Test to ensure that custom theme styles are used."""
     with mock.patch.multiple(
         "lightning.pytorch.callbacks.progress.rich_progress",
@@ -114,7 +114,7 @@ def test_rich_progress_bar_custom_theme(tmpdir):
     theme = RichProgressBarTheme()

     progress_bar = RichProgressBar(theme=theme)
-    progress_bar.on_train_start(Trainer(tmpdir), BoringModel())
+    progress_bar.on_train_start(Trainer(), BoringModel())

     assert progress_bar.theme == theme
     args, kwargs = mocks["CustomBarColumn"].call_args
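For reference, a hedged usage sketch of the progress bar this test exercises; the specific theme fields (`description`, `progress_bar`) and style strings are assumptions for illustration, not taken from this diff:

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import RichProgressBar
    from lightning.pytorch.callbacks.progress.rich_progress import RichProgressBarTheme

    # Each theme field takes a rich style string; these values are illustrative.
    theme = RichProgressBarTheme(description="green_yellow", progress_bar="green1")
    trainer = Trainer(callbacks=[RichProgressBar(theme=theme)])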
@@ -43,7 +43,7 @@ def test_passing_env_variables_only():
 @mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"})
 def test_passing_env_variables_defaults():
     """Testing overwriting trainer arguments."""
-    trainer = Trainer(False, max_steps=42)
+    trainer = Trainer(logger=False, max_steps=42)
     assert trainer.logger is None
     assert trainer.max_steps == 42

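A sketch of the behavior this test pins down, assuming `@_defaults_from_env_vars` reads `PL_TRAINER_*` variables as defaults at construction time and lets explicit keyword arguments override them:

    import os
    from lightning.pytorch import Trainer

    os.environ["PL_TRAINER_MAX_STEPS"] = "7"  # read when Trainer(...) is called

    assert Trainer().max_steps == 7               # env var supplies the default
    assert Trainer(max_steps=42).max_steps == 42  # explicit keyword wins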