ref: organize args 3/n (#3447)
* ref: organize args 2/n
parent deb82d9c08
commit 541c4ab01d
@@ -82,6 +82,8 @@ class AcceleratorConnector:
         # NVIDIA setup
         self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)
 
+        self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
+
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks
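For context: `on_colab_kaggle` simply records whether one of the two notebook environment variables is set. A minimal sketch of that check, assuming nothing beyond the variables named in the line added above:

    import os

    # Truthy when running on Colab or Kaggle, otherwise None; mirrors the added line.
    on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
    if on_colab_kaggle:
        print('notebook environment detected')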
@@ -345,10 +345,3 @@ class TrainerDataLoadingMixin(ABC):
             hvd.join()
 
         return dataloader
-
-    def determine_data_use_amount(self, overfit_batches: float) -> None:
-        """Use less data for debugging purposes"""
-        if overfit_batches > 0:
-            self.limit_train_batches = overfit_batches
-            self.limit_val_batches = overfit_batches
-            self.limit_test_batches = overfit_batches
@@ -0,0 +1,109 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from typing import Union
+from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
+
+
+class DebuggingConnector:
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def on_init_start(
+        self,
+        overfit_pct,
+        val_percent_check,
+        test_percent_check,
+        train_percent_check,
+        limit_train_batches,
+        limit_val_batches,
+        limit_test_batches,
+        val_check_interval,
+        overfit_batches,
+        fast_dev_run
+    ):
+        self.trainer.fast_dev_run = fast_dev_run
+        if self.trainer.fast_dev_run:
+            limit_train_batches = 1
+            limit_val_batches = 1
+            limit_test_batches = 1
+            self.trainer.num_sanity_val_steps = 0
+            self.trainer.max_epochs = 1
+            rank_zero_info(
+                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
+            )
+
+        # how much of the data to use
+        # TODO: remove in 0.10.0
+        if overfit_pct is not None:
+            rank_zero_warn(
+                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
+                " and this argument will be removed in v0.10.0",
+                DeprecationWarning,
+            )
+            overfit_batches = overfit_pct
+
+        # TODO: remove in 0.10.0
+        if val_percent_check is not None:
+            rank_zero_warn(
+                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
+                " and this argument will be removed in v0.10.0",
+                DeprecationWarning,
+            )
+            limit_val_batches = val_percent_check
+
+        # TODO: remove in 0.10.0
+        if test_percent_check is not None:
+            rank_zero_warn(
+                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
+                " and this argument will be removed in v0.10.0",
+                DeprecationWarning,
+            )
+            limit_test_batches = test_percent_check
+
+        # TODO: remove in 0.10.0
+        if train_percent_check is not None:
+            rank_zero_warn(
+                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
+                " and this argument will be removed in v0.10.0",
+                DeprecationWarning,
+            )
+            limit_train_batches = train_percent_check
+
+        self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
+        self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
+        self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
+        self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
+        self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
+        self.determine_data_use_amount(self.trainer.overfit_batches)
+
+    def determine_data_use_amount(self, overfit_batches: float) -> None:
+        """Use less data for debugging purposes"""
+        if overfit_batches > 0:
+            self.trainer.limit_train_batches = overfit_batches
+            self.trainer.limit_val_batches = overfit_batches
+            self.trainer.limit_test_batches = overfit_batches
+
+
+def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
+    if 0 <= batches <= 1:
+        return batches
+    elif batches > 1 and batches % 1.0 == 0:
+        return int(batches)
+    else:
+        raise MisconfigurationException(
+            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
+        )
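For context: `_determine_batch_limits` accepts either a fraction of the dataloader in [0.0, 1.0] or a whole number of batches, and rejects anything else. A minimal sketch of the expected behaviour, assuming the function defined in the new file above is importable:

    from pytorch_lightning.trainer.debugging_connector import _determine_batch_limits

    assert _determine_batch_limits(0.25, 'limit_val_batches') == 0.25   # fraction of the dataloader
    assert _determine_batch_limits(10, 'limit_train_batches') == 10     # whole number of batches
    assert _determine_batch_limits(1.0, 'overfit_batches') == 1.0       # boundary value stays a float
    # _determine_batch_limits(2.5, 'limit_test_batches') raises MisconfigurationException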
@@ -11,16 +11,26 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
 
 
-class Initializer:
+class PrecisionConnector:
 
     def __init__(self, trainer):
         self.trainer = trainer
 
+    def on_trainer_init(self, precision, amp_level, amp_backend):
+        # AMP init
+        # These are the only lines needed after v0.8.0
+        # we wrap the user's forward with autocast and give it back at the end of fit
+        self.trainer.autocast_original_forward = None
+        self.trainer.precision = precision
+        self.trainer.scaler = None
+
+        self.trainer.amp_level = amp_level
+        self.init_amp(amp_backend)
+
     def init_amp(self, amp_type: str):
         assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
         self.trainer.amp_backend = None
@@ -43,15 +43,16 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
 from pytorch_lightning.trainer.training_loop import TrainLoop
 from pytorch_lightning.trainer.data_connector import DataConnector
 from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
 from pytorch_lightning.trainer.logger_connector import LoggerConnector
 from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
 from pytorch_lightning.trainer.callback_connector import CallbackConnector
 from pytorch_lightning.trainer.model_connector import ModelConnector
+from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
 from pytorch_lightning import _logger as log
 from pytorch_lightning.tuner.tuning import Tuner
-from pytorch_lightning.trainer.initializer import Initializer
+from pytorch_lightning.trainer.precision_connector import PrecisionConnector
 from pytorch_lightning.trainer.data_connector import DataConnector
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer import docstrings
 from pytorch_lightning.trainer.properties import TrainerProperties
@@ -173,8 +174,10 @@ class Trainer(
         self.accelerator_connector = AcceleratorConnector(self)
         self.logger_connector = LoggerConnector(self)
         self.model_connector = ModelConnector(self)
-        self.initializer = Initializer(self)
+        self.precision_connector = PrecisionConnector(self)
         self.callback_connector = CallbackConnector(self)
+        self.debugging_connector = DebuggingConnector(self)
 
         self.tuner = Tuner(self)
         self.accelerator_backend = None
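For context: every connector instantiated above follows the same shape. The sketch below is illustrative only (the class and argument names are made up, not library code); it shows the pattern the real connectors implement:

    class ExampleConnector:
        """Hypothetical connector: owns no state of its own, just initializes a slice of the trainer."""

        def __init__(self, trainer):
            # keep a back-reference so attributes still live on the Trainer itself
            self.trainer = trainer

        def on_trainer_init(self, some_flag):
            # Trainer.__init__ delegates here instead of setting the attribute inline
            self.trainer.some_flag = some_flag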
@@ -253,15 +256,8 @@ class Trainer(
         # -------------------
         self.weights_summary = weights_summary
 
-        self.max_epochs = max_epochs
-        self.min_epochs = min_epochs
-        self.max_steps = max_steps
-        self.min_steps = min_steps
-
-        if num_sanity_val_steps == -1:
-            self.num_sanity_val_steps = float('inf')
-        else:
-            self.num_sanity_val_steps = num_sanity_val_steps
+        # init train loop related flags
+        self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
 
         self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
@@ -275,17 +271,6 @@ class Trainer(
         self.terminate_on_nan = terminate_on_nan
         self.shown_warnings = set()
 
-        self.fast_dev_run = fast_dev_run
-        if self.fast_dev_run:
-            limit_train_batches = 1
-            limit_val_batches = 1
-            limit_test_batches = 1
-            self.num_sanity_val_steps = 0
-            self.max_epochs = 1
-            rank_zero_info(
-                'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
-            )
-
         # configure profiler
         if profiler is True:
             profiler = SimpleProfiler()
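For context: the block removed here is not gone, it now lives in `DebuggingConnector.on_init_start` (see the new file above), so user-facing behaviour is unchanged. A rough usage sketch, assuming a standard install of this version of the library:

    from pytorch_lightning import Trainer

    # Clamps every loop to a single batch, a single epoch, and no sanity check,
    # exactly as the DebuggingConnector code above does when fast_dev_run is truthy.
    trainer = Trainer(fast_dev_run=True)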
@@ -300,61 +285,22 @@ class Trainer(
         self.log_save_interval = log_save_interval
         self.row_log_interval = row_log_interval
 
-        # how much of the data to use
-        # TODO: remove in 0.10.0
-        if overfit_pct is not None:
-            rank_zero_warn(
-                "Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
-                " and this argument will be removed in v0.10.0",
-                DeprecationWarning,
-            )
-            overfit_batches = overfit_pct
+        # init debugging flags
+        self.debugging_connector.on_init_start(
+            overfit_pct,
+            val_percent_check,
+            test_percent_check,
+            train_percent_check,
+            limit_train_batches,
+            limit_val_batches,
+            limit_test_batches,
+            val_check_interval,
+            overfit_batches,
+            fast_dev_run
+        )
-
-        # TODO: remove in 0.10.0
-        if val_percent_check is not None:
-            rank_zero_warn(
-                "Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
-                " and this argument will be removed in v0.10.0",
-                DeprecationWarning,
-            )
-            limit_val_batches = val_percent_check
-
-        # TODO: remove in 0.10.0
-        if test_percent_check is not None:
-            rank_zero_warn(
-                "Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
-                " and this argument will be removed in v0.10.0",
-                DeprecationWarning,
-            )
-            limit_test_batches = test_percent_check
-
-        # TODO: remove in 0.10.0
-        if train_percent_check is not None:
-            rank_zero_warn(
-                "Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
-                " and this argument will be removed in v0.10.0",
-                DeprecationWarning,
-            )
-            limit_train_batches = train_percent_check
-
-        self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
-        self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
-        self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
-        self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
-        self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
-        self.determine_data_use_amount(self.overfit_batches)
-
-        # AMP init
-        # These are the only lines needed after v0.8.0
-        # we wrap the user's forward with autocast and give it back at the end of fit
-        self.autocast_original_forward = None
-        self.precision = precision
-        self.scaler = None
-
-        self.amp_level = amp_level
-        self.initializer.init_amp(amp_backend)
-
-        self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
+        # set precision
+        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
 
         # Callback system
         self.on_init_end()
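For context: the deprecated `*_percent_check` / `overfit_pct` arguments are still accepted after this change; `DebuggingConnector.on_init_start` remaps them and emits the same `DeprecationWarning` as before. A rough sketch of both spellings, assuming this version of the library:

    from pytorch_lightning import Trainer

    # old spelling: still works, but warns that it will be removed in v0.10.0
    trainer = Trainer(val_percent_check=0.1)

    # preferred spelling since v0.8.0, handled by the same connector code path
    trainer = Trainer(limit_val_batches=0.1)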
@@ -862,18 +808,6 @@ class Trainer(
 
         return output
 
-
-def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
-    if 0 <= batches <= 1:
-        return batches
-    elif batches > 1 and batches % 1.0 == 0:
-        return int(batches)
-    else:
-        raise MisconfigurationException(
-            f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
-        )
-
-
 # add docstrings
 Trainer.__init__.__doc__ = docstrings.trainer.init
 Trainer.fit.__doc__ = docstrings.trainer.fit
@@ -38,6 +38,17 @@ class TrainLoop:
         self._teardown_already_run = False
         self.running_loss = TensorRunningAccum(window_length=20)
 
+    def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
+        self.trainer.max_epochs = max_epochs
+        self.trainer.min_epochs = min_epochs
+        self.trainer.max_steps = max_steps
+        self.trainer.min_steps = min_steps
+
+        if num_sanity_val_steps == -1:
+            self.trainer.num_sanity_val_steps = float('inf')
+        else:
+            self.trainer.num_sanity_val_steps = num_sanity_val_steps
+
     @property
     def num_optimizers(self):
         num_optimizers = len(self.get_optimizers_iterable())
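For context: `on_init_start` above treats `num_sanity_val_steps=-1` as "run the sanity validation loop with no step limit". A minimal sketch of that normalization in isolation (the helper name is made up for illustration, not library code):

    def _normalize_sanity_steps(num_sanity_val_steps: int) -> float:
        # -1 means "no limit"; mirrors the branch in TrainLoop.on_init_start above
        return float('inf') if num_sanity_val_steps == -1 else num_sanity_val_steps

    assert _normalize_sanity_steps(-1) == float('inf')
    assert _normalize_sanity_steps(2) == 2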