Make Trainer readable and debuggable (3/n) (#14871)
* clean trainer 3/n
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent fbfcc3d871
commit 047d7088f4
@@ -0,0 +1,204 @@
# Copyright Lightning AI.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Houses the methods used to set up the Trainer."""

from typing import Any, Optional, Union

from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators import (
    CUDAAccelerator,
    HPUAccelerator,
    IPUAccelerator,
    MPSAccelerator,
    TPUAccelerator,
)
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.profilers import (
    AdvancedProfiler,
    PassThroughProfiler,
    Profiler,
    PyTorchProfiler,
    SimpleProfiler,
    XLAProfiler,
)
from pytorch_lightning.utilities import _HPU_AVAILABLE, _IPU_AVAILABLE, _TPU_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn

def init_debugging_flags(
    trainer: Any,
    limit_train_batches: Optional[Union[int, float]],
    limit_val_batches: Optional[Union[int, float]],
    limit_test_batches: Optional[Union[int, float]],
    limit_predict_batches: Optional[Union[int, float]],
    fast_dev_run: Union[int, bool],
    overfit_batches: Union[int, float],
    val_check_interval: Optional[Union[int, float]],
    num_sanity_val_steps: int,
) -> None:
    # init debugging flags
    if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
        )
    trainer.fast_dev_run = fast_dev_run

    # set fast_dev_run=True when it is 1, used while logging
    if fast_dev_run == 1:
        trainer.fast_dev_run = True

    trainer.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
    overfit_batches_enabled = overfit_batches > 0

    if fast_dev_run:
        num_batches = int(fast_dev_run)
        if not overfit_batches_enabled:
            trainer.limit_train_batches = num_batches
            trainer.limit_val_batches = num_batches

        trainer.limit_test_batches = num_batches
        trainer.limit_predict_batches = num_batches
        trainer.fit_loop.max_steps = num_batches
        trainer.num_sanity_val_steps = 0
        trainer.fit_loop.max_epochs = 1
        trainer.val_check_interval = 1.0
        trainer.check_val_every_n_epoch = 1
        trainer.loggers = [DummyLogger()] if trainer.loggers else []
        rank_zero_info(
            f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
            "Logging and checkpointing is suppressed."
        )
    else:
        if not overfit_batches_enabled:
            trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
            trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
        trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
        trainer.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
        trainer.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
        trainer.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

    if overfit_batches_enabled:
        trainer.limit_train_batches = overfit_batches
        trainer.limit_val_batches = overfit_batches
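
For readers skimming the diff: when `fast_dev_run` is truthy, the function above collapses every loop to the same small batch count, disables the sanity check, and swaps real loggers for a `DummyLogger`. The standalone sketch below (a hypothetical helper on a plain namespace, not part of this PR or of the Lightning API) restates that mapping so it can be inspected in isolation:

from types import SimpleNamespace

def sketch_fast_dev_run_flags(fast_dev_run: int) -> SimpleNamespace:
    # Mirrors the assignments `init_debugging_flags` makes when `fast_dev_run` is truthy.
    n = int(fast_dev_run)
    return SimpleNamespace(
        limit_train_batches=n,
        limit_val_batches=n,
        limit_test_batches=n,
        limit_predict_batches=n,
        max_steps=n,
        max_epochs=1,
        num_sanity_val_steps=0,
        val_check_interval=1.0,
        check_val_every_n_epoch=1,
    )

print(vars(sketch_fast_dev_run_flags(2)))  # every loop runs 2 batches within a single epoch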

def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # batches is optional to know if the user passed a value so that we can show the above info messages only to the
        # users that set a value explicitly
        return 1.0

    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
    if isinstance(batches, int) and batches == 1:
        if name == "limit_train_batches":
            message = "1 batch per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run after every batch."
        else:
            message = "1 batch will be used."
        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
    elif isinstance(batches, float) and batches == 1.0:
        if name == "limit_train_batches":
            message = "100% of the batches per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run at the end of the training epoch."
        else:
            message = "100% of the batches will be used."
        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}.")

    if 0 <= batches <= 1:
        return batches
    if batches > 1 and batches % 1.0 == 0:
        return int(batches)
    raise MisconfigurationException(
        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
    )
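
The int-versus-float distinction above is the part users most often trip over: a float in [0.0, 1.0] is a fraction of the available batches, an integer is an absolute batch count, so `1` and `1.0` behave very differently. A few illustrative calls (assuming `_determine_batch_limits` is in scope, e.g. inside this module; the expected values simply restate the branches above):

assert _determine_batch_limits(0.25, "limit_train_batches") == 0.25  # 25% of the train batches
assert _determine_batch_limits(10, "limit_train_batches") == 10      # exactly 10 batches
assert _determine_batch_limits(1, "val_check_interval") == 1         # validate after every train batch
assert _determine_batch_limits(1.0, "val_check_interval") == 1.0     # validate once, at the end of the epoch
assert _determine_batch_limits(None, "limit_val_batches") == 1.0     # not set by the user -> use everything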

def init_profiler(trainer: Any, profiler: Optional[Union[Profiler, str]]) -> None:
    if isinstance(profiler, str):
        PROFILERS = {
            "simple": SimpleProfiler,
            "advanced": AdvancedProfiler,
            "pytorch": PyTorchProfiler,
            "xla": XLAProfiler,
        }
        profiler = profiler.lower()
        if profiler not in PROFILERS:
            raise MisconfigurationException(
                "When passing string value for the `profiler` parameter of `Trainer`,"
                f" it can only be one of {list(PROFILERS.keys())}"
            )
        profiler_class = PROFILERS[profiler]
        profiler = profiler_class()
    trainer.profiler = profiler or PassThroughProfiler()
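
The net effect is that `Trainer(profiler="simple")` and `Trainer(profiler=SimpleProfiler())` resolve to the same thing, and passing nothing falls back to a `PassThroughProfiler` so downstream code never has to branch on `None`. A small sketch of that resolution on a bare object, assuming `init_profiler` from this new module is importable; the `SimpleNamespace` stand-in for a half-built `Trainer` is illustrative only:

from types import SimpleNamespace

from pytorch_lightning.profilers import PassThroughProfiler, SimpleProfiler

stub = SimpleNamespace()           # stands in for a Trainer that is still being constructed
init_profiler(stub, "simple")      # string values are looked up in the PROFILERS table
assert isinstance(stub.profiler, SimpleProfiler)

init_profiler(stub, None)          # nothing requested -> no-op profiler
assert isinstance(stub.profiler, PassThroughProfiler)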

def log_device_info(trainer: Any) -> None:

    if CUDAAccelerator.is_available():
        gpu_available = True
        gpu_type = " (cuda)"
    elif MPSAccelerator.is_available():
        gpu_available = True
        gpu_type = " (mps)"
    else:
        gpu_available = False
        gpu_type = ""

    gpu_used = isinstance(trainer.accelerator, (CUDAAccelerator, MPSAccelerator))
    rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

    num_tpu_cores = trainer.num_devices if isinstance(trainer.accelerator, TPUAccelerator) else 0
    rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

    num_ipus = trainer.num_devices if isinstance(trainer.accelerator, IPUAccelerator) else 0
    rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

    num_hpus = trainer.num_devices if isinstance(trainer.accelerator, HPUAccelerator) else 0
    rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")

    # TODO: Integrate MPS Accelerator here, once gpu maps to both
    if CUDAAccelerator.is_available() and not isinstance(trainer.accelerator, CUDAAccelerator):
        rank_zero_warn(
            "GPU available but not used. Set `accelerator` and `devices` using"
            f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
            category=PossibleUserWarning,
        )

    if _TPU_AVAILABLE and not isinstance(trainer.accelerator, TPUAccelerator):
        rank_zero_warn(
            "TPU available but not used. Set `accelerator` and `devices` using"
            f" `Trainer(accelerator='tpu', devices={TPUAccelerator.auto_device_count()})`."
        )

    if _IPU_AVAILABLE and not isinstance(trainer.accelerator, IPUAccelerator):
        rank_zero_warn(
            "IPU available but not used. Set `accelerator` and `devices` using"
            f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`."
        )

    if _HPU_AVAILABLE and not isinstance(trainer.accelerator, HPUAccelerator):
        rank_zero_warn(
            "HPU available but not used. Set `accelerator` and `devices` using"
            f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
        )

    if MPSAccelerator.is_available() and not isinstance(trainer.accelerator, MPSAccelerator):
        rank_zero_warn(
            "MPS available but not used. Set `accelerator` and `devices` using"
            f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`."
        )
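
On a single-GPU CUDA machine with `Trainer(accelerator="gpu", devices=1)`, the info messages above would render roughly as follows (values are illustrative, taken directly from the f-strings in `log_device_info`):

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs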
@@ -11,6 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES.
# DO NOT OBSCURE THE TRAINING LOOP
# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING
# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN
# DO NOT REMOVE THIS NOTICE
# - WILLIAM FALCON

"""Trainer to automate the training."""
import inspect
import logging
@@ -39,20 +47,12 @@ from lightning_lite.utilities.cloud_io import get_filesystem
from lightning_lite.utilities.data import _auto_add_worker_init_fn
from lightning_lite.utilities.types import _PATH
from lightning_lite.utilities.warnings import PossibleUserWarning
from pytorch_lightning.accelerators import (
    Accelerator,
    CUDAAccelerator,
    HPUAccelerator,
    IPUAccelerator,
    MPSAccelerator,
    TPUAccelerator,
)
from pytorch_lightning.accelerators import Accelerator, HPUAccelerator, TPUAccelerator
from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loggers import Logger
from pytorch_lightning.loggers.logger import DummyLogger
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loops import PredictionLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
@@ -64,16 +64,9 @@ from pytorch_lightning.plugins import (
    PLUGIN_INPUT,
    PrecisionPlugin,
)
from pytorch_lightning.profilers import (
    AdvancedProfiler,
    PassThroughProfiler,
    Profiler,
    PyTorchProfiler,
    SimpleProfiler,
    XLAProfiler,
)
from pytorch_lightning.profilers import Profiler
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.trainer import teardown
from pytorch_lightning.trainer import setup, teardown
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
@@ -87,14 +80,7 @@ from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.tuner.tuning import _TunerResult, Tuner
from pytorch_lightning.utilities import (
    _HPU_AVAILABLE,
    _IPU_AVAILABLE,
    _TPU_AVAILABLE,
    AMPType,
    GradClipAlgorithmType,
    parsing,
)
from pytorch_lightning.utilities import AMPType, GradClipAlgorithmType, parsing
from pytorch_lightning.utilities.argparse import (
    _defaults_from_env_vars,
    add_argparse_args,
@@ -523,7 +509,7 @@ class Trainer(
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.__init_profiler(profiler)
        setup.init_profiler(self, profiler)

        # init logger flags
        self._loggers: List[Logger]
@@ -531,7 +517,8 @@ class Trainer(

        # init debugging flags
        self.val_check_interval: Union[int, float]
        self._init_debugging_flags(
        setup.init_debugging_flags(
            self,
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
@@ -545,64 +532,8 @@ class Trainer(
        # Callback system
        self._call_callback_hooks("on_init_end")

    def _init_debugging_flags(
        self,
        limit_train_batches: Optional[Union[int, float]],
        limit_val_batches: Optional[Union[int, float]],
        limit_test_batches: Optional[Union[int, float]],
        limit_predict_batches: Optional[Union[int, float]],
        fast_dev_run: Union[int, bool],
        overfit_batches: Union[int, float],
        val_check_interval: Optional[Union[int, float]],
        num_sanity_val_steps: int,
    ):
        # init debugging flags
        if isinstance(fast_dev_run, int) and (fast_dev_run < 0):
            raise MisconfigurationException(
                f"fast_dev_run={fast_dev_run!r} is not a valid configuration. It should be >= 0."
            )
        self.fast_dev_run = fast_dev_run

        # set fast_dev_run=True when it is 1, used while logging
        if fast_dev_run == 1:
            self.fast_dev_run = True

        self.overfit_batches = _determine_batch_limits(overfit_batches, "overfit_batches")
        overfit_batches_enabled = overfit_batches > 0

        if fast_dev_run:
            num_batches = int(fast_dev_run)
            if not overfit_batches_enabled:
                self.limit_train_batches = num_batches
                self.limit_val_batches = num_batches

            self.limit_test_batches = num_batches
            self.limit_predict_batches = num_batches
            self.fit_loop.max_steps = num_batches
            self.num_sanity_val_steps = 0
            self.fit_loop.max_epochs = 1
            self.val_check_interval = 1.0
            self.check_val_every_n_epoch = 1
            self.loggers = [DummyLogger()] if self.loggers else []
            rank_zero_info(
                f"Running in `fast_dev_run` mode: will run the requested loop using {num_batches} batch(es). "
                "Logging and checkpointing is suppressed."
            )
        else:
            if not overfit_batches_enabled:
                self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches")
                self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches")
            self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches")
            self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches")
            self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps
            self.val_check_interval = _determine_batch_limits(val_check_interval, "val_check_interval")

        if overfit_batches_enabled:
            self.limit_train_batches = overfit_batches
            self.limit_val_batches = overfit_batches

    def _setup_on_init(self) -> None:
        self._log_device_info()
        setup.log_device_info(self)

        self.should_stop = False
        self.state = TrainerState()
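
This hunk and the two before it show the whole shape of the refactor: name-mangled `Trainer.__init_*` helpers are removed from the class and replaced by calls into plain functions in the new `setup` module that take the trainer as an explicit argument. A toy, self-contained illustration of that pattern (the class and function names below are hypothetical, not Lightning code):

class Before:
    def __init__(self) -> None:
        self.__init_thing()              # behaviour hidden behind a name-mangled method

    def __init_thing(self) -> None:
        self.thing = "configured"


def init_thing(obj: "After") -> None:
    # free function: easy to find, read, and test in isolation
    obj.thing = "configured"


class After:
    def __init__(self) -> None:
        init_thing(self)                 # same behaviour, one obvious place to look


assert Before().thing == After().thing == "configured"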
@@ -1591,85 +1522,11 @@ class Trainer(
    def _log_api_event(event: str) -> None:
        torch._C._log_api_usage_once("lightning.trainer." + event)

    def __init_profiler(self, profiler: Optional[Union[Profiler, str]]) -> None:
        if isinstance(profiler, str):
            PROFILERS = {
                "simple": SimpleProfiler,
                "advanced": AdvancedProfiler,
                "pytorch": PyTorchProfiler,
                "xla": XLAProfiler,
            }
            profiler = profiler.lower()
            if profiler not in PROFILERS:
                raise MisconfigurationException(
                    "When passing string value for the `profiler` parameter of `Trainer`,"
                    f" it can only be one of {list(PROFILERS.keys())}"
                )
            profiler_class = PROFILERS[profiler]
            profiler = profiler_class()
        self.profiler: Profiler = profiler or PassThroughProfiler()

    def __setup_profiler(self) -> None:
        local_rank = self.local_rank if self.world_size > 1 else None
        self.profiler._lightning_module = proxy(self.lightning_module)
        self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

    def _log_device_info(self) -> None:

        if CUDAAccelerator.is_available():
            gpu_available = True
            gpu_type = " (cuda)"
        elif MPSAccelerator.is_available():
            gpu_available = True
            gpu_type = " (mps)"
        else:
            gpu_available = False
            gpu_type = ""

        gpu_used = isinstance(self.accelerator, (CUDAAccelerator, MPSAccelerator))
        rank_zero_info(f"GPU available: {gpu_available}{gpu_type}, used: {gpu_used}")

        num_tpu_cores = self.num_devices if isinstance(self.accelerator, TPUAccelerator) else 0
        rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

        num_ipus = self.num_devices if isinstance(self.accelerator, IPUAccelerator) else 0
        rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

        num_hpus = self.num_devices if isinstance(self.accelerator, HPUAccelerator) else 0
        rank_zero_info(f"HPU available: {_HPU_AVAILABLE}, using: {num_hpus} HPUs")

        # TODO: Integrate MPS Accelerator here, once gpu maps to both
        if CUDAAccelerator.is_available() and not isinstance(self.accelerator, CUDAAccelerator):
            rank_zero_warn(
                "GPU available but not used. Set `accelerator` and `devices` using"
                f" `Trainer(accelerator='gpu', devices={CUDAAccelerator.auto_device_count()})`.",
                category=PossibleUserWarning,
            )

        if _TPU_AVAILABLE and not isinstance(self.accelerator, TPUAccelerator):
            rank_zero_warn(
                "TPU available but not used. Set `accelerator` and `devices` using"
                f" `Trainer(accelerator='tpu', devices={TPUAccelerator.auto_device_count()})`."
            )

        if _IPU_AVAILABLE and not isinstance(self.accelerator, IPUAccelerator):
            rank_zero_warn(
                "IPU available but not used. Set `accelerator` and `devices` using"
                f" `Trainer(accelerator='ipu', devices={IPUAccelerator.auto_device_count()})`."
            )

        if _HPU_AVAILABLE and not isinstance(self.accelerator, HPUAccelerator):
            rank_zero_warn(
                "HPU available but not used. Set `accelerator` and `devices` using"
                f" `Trainer(accelerator='hpu', devices={HPUAccelerator.auto_device_count()})`."
            )

        if MPSAccelerator.is_available() and not isinstance(self.accelerator, MPSAccelerator):
            rank_zero_warn(
                "MPS available but not used. Set `accelerator` and `devices` using"
                f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`."
            )

    """
    Data loading methods
    """
@@ -2542,36 +2399,3 @@ def _evaluation_context(accelerator: Accelerator) -> Generator:
    )
    with context_manager_class():
        yield


def _determine_batch_limits(batches: Optional[Union[int, float]], name: str) -> Union[int, float]:
    if batches is None:
        # batches is optional to know if the user passed a value so that we can show the above info messages only to the
        # users that set a value explicitly
        return 1.0

    # differentiating based on the type can be error-prone for users. show a message describing the chosen behaviour
    if isinstance(batches, int) and batches == 1:
        if name == "limit_train_batches":
            message = "1 batch per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run after every batch."
        else:
            message = "1 batch will be used."
        rank_zero_info(f"`Trainer({name}=1)` was configured so {message}")
    elif isinstance(batches, float) and batches == 1.0:
        if name == "limit_train_batches":
            message = "100% of the batches per epoch will be used."
        elif name == "val_check_interval":
            message = "validation will run at the end of the training epoch."
        else:
            message = "100% of the batches will be used."
        rank_zero_info(f"`Trainer({name}=1.0)` was configured so {message}.")

    if 0 <= batches <= 1:
        return batches
    if batches > 1 and batches % 1.0 == 0:
        return int(batches)
    raise MisconfigurationException(
        f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int."
    )