ref: organize args 3/n (#3447)

* ref: organize args 2/n

* ref: organize args 2/n

* ref: organize args 2/n

* ref: organize args 2/n
This commit is contained in:
William Falcon 2020-09-10 08:55:30 -04:00 committed by GitHub
parent deb82d9c08
commit 541c4ab01d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 157 additions and 98 deletions

View File

@ -82,6 +82,8 @@ class AcceleratorConnector:
# NVIDIA setup
self.trainer.set_nvidia_flags(self.trainer.is_slurm_managing_tasks, self.trainer.data_parallel_device_ids)
self.trainer.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
def select_accelerator(self):
# SLURM ddp
use_slurm_ddp = self.trainer.use_ddp and self.trainer.is_slurm_managing_tasks

View File

@ -345,10 +345,3 @@ class TrainerDataLoadingMixin(ABC):
hvd.join()
return dataloader
def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes"""
if overfit_batches > 0:
self.limit_train_batches = overfit_batches
self.limit_val_batches = overfit_batches
self.limit_test_batches = overfit_batches

View File

@ -0,0 +1,109 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from typing import Union
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_info
class DebuggingConnector:
def __init__(self, trainer):
self.trainer = trainer
def on_init_start(
self,
overfit_pct,
val_percent_check,
test_percent_check,
train_percent_check,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
):
self.trainer.fast_dev_run = fast_dev_run
if self.trainer.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
self.trainer.num_sanity_val_steps = 0
self.trainer.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
)
# how much of the data to use
# TODO: remove in 0.10.0
if overfit_pct is not None:
rank_zero_warn(
"Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
overfit_batches = overfit_pct
# TODO: remove in 0.10.0
if val_percent_check is not None:
rank_zero_warn(
"Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_val_batches = val_percent_check
# TODO: remove in 0.10.0
if test_percent_check is not None:
rank_zero_warn(
"Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_test_batches = test_percent_check
# TODO: remove in 0.10.0
if train_percent_check is not None:
rank_zero_warn(
"Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_train_batches = train_percent_check
self.trainer.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.trainer.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.trainer.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.trainer.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.trainer.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.trainer.overfit_batches)
def determine_data_use_amount(self, overfit_batches: float) -> None:
"""Use less data for debugging purposes"""
if overfit_batches > 0:
self.trainer.limit_train_batches = overfit_batches
self.trainer.limit_val_batches = overfit_batches
self.trainer.limit_test_batches = overfit_batches
def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)

View File

@ -11,16 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
class Initializer:
class PrecisionConnector:
def __init__(self, trainer):
self.trainer = trainer
def on_trainer_init(self, precision, amp_level, amp_backend):
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.trainer.autocast_original_forward = None
self.trainer.precision = precision
self.trainer.scaler = None
self.trainer.amp_level = amp_level
self.init_amp(amp_backend)
def init_amp(self, amp_type: str):
assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
self.trainer.amp_backend = None

View File

@ -43,15 +43,16 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.initializer import Initializer
from pytorch_lightning.trainer.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
@ -173,8 +174,10 @@ class Trainer(
self.accelerator_connector = AcceleratorConnector(self)
self.logger_connector = LoggerConnector(self)
self.model_connector = ModelConnector(self)
self.initializer = Initializer(self)
self.precision_connector = PrecisionConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.tuner = Tuner(self)
self.accelerator_backend = None
@ -253,15 +256,8 @@ class Trainer(
# -------------------
self.weights_summary = weights_summary
self.max_epochs = max_epochs
self.min_epochs = min_epochs
self.max_steps = max_steps
self.min_steps = min_steps
if num_sanity_val_steps == -1:
self.num_sanity_val_steps = float('inf')
else:
self.num_sanity_val_steps = num_sanity_val_steps
# init train loop related flags
self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
@ -275,17 +271,6 @@ class Trainer(
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()
self.fast_dev_run = fast_dev_run
if self.fast_dev_run:
limit_train_batches = 1
limit_val_batches = 1
limit_test_batches = 1
self.num_sanity_val_steps = 0
self.max_epochs = 1
rank_zero_info(
'Running in fast_dev_run mode: will run a full train,' ' val and test loop using a single batch'
)
# configure profiler
if profiler is True:
profiler = SimpleProfiler()
@ -300,61 +285,22 @@ class Trainer(
self.log_save_interval = log_save_interval
self.row_log_interval = row_log_interval
# how much of the data to use
# TODO: remove in 0.10.0
if overfit_pct is not None:
rank_zero_warn(
"Argument `overfit_pct` is now set by `overfit_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
overfit_batches = overfit_pct
# init debugging flags
self.debugging_connector.on_init_start(
overfit_pct,
val_percent_check,
test_percent_check,
train_percent_check,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
)
# TODO: remove in 0.10.0
if val_percent_check is not None:
rank_zero_warn(
"Argument `val_percent_check` is now set by `limit_val_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_val_batches = val_percent_check
# TODO: remove in 0.10.0
if test_percent_check is not None:
rank_zero_warn(
"Argument `test_percent_check` is now set by `limit_test_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_test_batches = test_percent_check
# TODO: remove in 0.10.0
if train_percent_check is not None:
rank_zero_warn(
"Argument `train_percent_check` is now set by `limit_train_batches` since v0.8.0"
" and this argument will be removed in v0.10.0",
DeprecationWarning,
)
limit_train_batches = train_percent_check
self.limit_train_batches = _determine_batch_limits(limit_train_batches, 'limit_train_batches')
self.limit_val_batches = _determine_batch_limits(limit_val_batches, 'limit_val_batches')
self.limit_test_batches = _determine_batch_limits(limit_test_batches, 'limit_test_batches')
self.val_check_interval = _determine_batch_limits(val_check_interval, 'val_check_interval')
self.overfit_batches = _determine_batch_limits(overfit_batches, 'overfit_batches')
self.determine_data_use_amount(self.overfit_batches)
# AMP init
# These are the only lines needed after v0.8.0
# we wrap the user's forward with autocast and give it back at the end of fit
self.autocast_original_forward = None
self.precision = precision
self.scaler = None
self.amp_level = amp_level
self.initializer.init_amp(amp_backend)
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
# set precision
self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
# Callback system
self.on_init_end()
@ -862,18 +808,6 @@ class Trainer(
return output
def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
if 0 <= batches <= 1:
return batches
elif batches > 1 and batches % 1.0 == 0:
return int(batches)
else:
raise MisconfigurationException(
f'You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.'
)
# add docstrings
Trainer.__init__.__doc__ = docstrings.trainer.init
Trainer.fit.__doc__ = docstrings.trainer.fit

View File

@ -38,6 +38,17 @@ class TrainLoop:
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)
def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
self.trainer.max_epochs = max_epochs
self.trainer.min_epochs = min_epochs
self.trainer.max_steps = max_steps
self.trainer.min_steps = min_steps
if num_sanity_val_steps == -1:
self.trainer.num_sanity_val_steps = float('inf')
else:
self.trainer.num_sanity_val_steps = num_sanity_val_steps
@property
def num_optimizers(self):
num_optimizers = len(self.get_optimizers_iterable())