formatting 4/n: Trainer (#5720)
* yapf trainer
* Apply suggestions from code review
* .
* fix

Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
parent aa03b73e60
commit aba212341a
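Not part of the commit: the hunks below are mostly mechanical yapf reformattings of pytorch_lightning/trainer/* (plus a few review tweaks). A minimal sketch of reproducing a single reformat through yapf's Python API; the inline style string is an assumption, the repository's own [yapf] configuration is authoritative:

    from yapf.yapflib.yapf_api import FormatCode

    source = (
        "rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'\n"
        "               ' We are turning it off for you.')\n"
    )
    # Most yapf releases return a (formatted_source, changed) tuple.
    result = FormatCode(source, style_config='{based_on_style: pep8, column_limit: 120}')
    formatted = result[0] if isinstance(result, tuple) else result
    print(formatted)

Running yapf over the whole package, e.g. `yapf --in-place --recursive pytorch_lightning/trainer`, is the command-line equivalent.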
@@ -28,10 +28,6 @@ pytorch_lightning/plugins/legacy/*
 pytorch_lightning/profiler/*


-# TODO
-pytorch_lightning/trainer/*
-
-
 # TODO
 pytorch_lightning/tuner/*

@@ -95,10 +95,6 @@ class ConfigValidator(object):
         has_step = is_overridden(step_name, model)

         if has_loader and not has_step:
-            rank_zero_warn(
-                f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop'
-            )
+            rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop')
         if has_step and not has_loader:
-            rank_zero_warn(
-                f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop'
-            )
+            rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop')
@@ -25,14 +25,14 @@ class CallbackConnector:
         self.trainer = trainer

     def on_trainer_init(
-            self,
-            callbacks,
-            checkpoint_callback,
-            progress_bar_refresh_rate,
-            process_position,
-            default_root_dir,
-            weights_save_path,
-            resume_from_checkpoint
+        self,
+        callbacks,
+        checkpoint_callback,
+        progress_bar_refresh_rate,
+        process_position,
+        default_root_dir,
+        weights_save_path,
+        resume_from_checkpoint,
     ):
         self.trainer.resume_from_checkpoint = resume_from_checkpoint

@@ -51,9 +51,7 @@ class CallbackConnector:
         self.configure_checkpoint_callbacks(checkpoint_callback)

         # init progress bar
-        self.trainer._progress_bar_callback = self.configure_progress_bar(
-            progress_bar_refresh_rate, process_position
-        )
+        self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position)

     def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]):
         if isinstance(checkpoint_callback, ModelCheckpoint):
@@ -61,8 +59,7 @@ class CallbackConnector:
             rank_zero_warn(
                 "Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)"
                 " is deprecated since v1.1 and will no longer be supported in v1.3."
-                " Use `callbacks` argument instead.",
-                DeprecationWarning
+                " Use `callbacks` argument instead.", DeprecationWarning
             )
             self.trainer.callbacks.append(checkpoint_callback)

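Not part of the diff: the deprecation message above points users at the `callbacks` argument; a minimal usage sketch (argument names as in the Trainer of this era):

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint

    checkpoint_cb = ModelCheckpoint(monitor='val_loss')
    # preferred: pass the instance through `callbacks` instead of `checkpoint_callback`
    trainer = Trainer(callbacks=[checkpoint_cb])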
@@ -230,7 +230,8 @@ class CheckpointConnector:
             if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                 del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
             rank_zero_warn(
-                'warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
+                'warning, `hyper_parameters` dropped from checkpoint.'
+                f' An attribute is not picklable {err}'
             )
             atomic_save(checkpoint, filepath)

@@ -297,9 +298,10 @@ class CheckpointConnector:
         checkpoint['lr_schedulers'] = lr_schedulers

         # dump amp scaling
-        if (self.trainer.amp_backend == AMPType.NATIVE
-                and self.trainer._device_type != DeviceType.TPU
-                and self.trainer.scaler is not None):
+        if (
+            self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU
+            and self.trainer.scaler is not None
+        ):
             checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
         elif self.trainer.amp_backend == AMPType.APEX:
             checkpoint['amp_scaling_state'] = amp.state_dict()
@@ -409,6 +411,7 @@ class CheckpointConnector:
             if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
                 del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
             rank_zero_warn(
-                'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}'
+                'Warning, `hyper_parameters` dropped from checkpoint.'
+                f' An attribute is not picklable {err}'
             )
             atomic_save(checkpoint, filepath)
@@ -36,8 +36,7 @@ class DataConnector(object):

     def get_profiled_train_dataloader(self, train_dataloader):
         profiled_dl = self.trainer.profiler.profile_iterable(
-            enumerate(self._with_is_last(train_dataloader)),
-            "get_train_batch"
+            enumerate(self._with_is_last(train_dataloader)), "get_train_batch"
         )
         return profiled_dl

@@ -25,13 +25,13 @@ class DebuggingConnector:
         self.trainer = trainer

     def on_init_start(
-            self,
-            limit_train_batches,
-            limit_val_batches,
-            limit_test_batches,
-            val_check_interval,
-            overfit_batches,
-            fast_dev_run
+        self,
+        limit_train_batches,
+        limit_val_batches,
+        limit_test_batches,
+        val_check_interval,
+        overfit_batches,
+        fast_dev_run,
     ):
         if not isinstance(fast_dev_run, (bool, int)):
             raise MisconfigurationException(
@@ -24,6 +24,7 @@ def overwrite_by_env_vars(fn: Callable) -> Callable:
     input arguments should be moved automatically to the correct device.

     """
+
     @wraps(fn)
     def overwrite_by_env_vars(self, *args, **kwargs):
         # get the class
@@ -18,8 +18,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException
 class CallbackHookNameValidator:

     @staticmethod
-    def check_logging_in_callbacks(current_hook_fx_name: str = None, on_step: bool = None,
-                                   on_epoch: bool = None) -> None:
+    def check_logging_in_callbacks(
+        current_hook_fx_name: str = None, on_step: bool = None, on_epoch: bool = None
+    ) -> None:
         if current_hook_fx_name is None:
             return

@@ -31,6 +31,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden


 class LoggerConnector:
+
     def __init__(self, trainer):
         self.trainer = trainer
         self._callback_metrics = MetricsHolder()
@@ -76,14 +77,14 @@ class LoggerConnector:

     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
-        return self._cached_results.get(self.trainer._running_stage) # type: ignore
+        return self._cached_results.get(self.trainer._running_stage)  # type: ignore

     def get_metrics(self, key: str) -> Dict:
         metrics_holder = getattr(self, f"_{key}", None)
         model_ref = self.trainer.get_model()
         metrics_holder.convert(
             self.trainer._device_type == DeviceType.TPU,
-            model_ref.device if model_ref is not None else model_ref
+            model_ref.device if model_ref is not None else model_ref,
         )
         return metrics_holder.metrics

@@ -21,7 +21,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE


 class MetricsHolder:
-
     """
     This class acts as a dictonary holder.
     It holds metrics and implements conversion functions.
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """
 Root module for all distributed operations in Lightning.
 Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU.
@@ -21,6 +20,7 @@ from weakref import proxy


 class ModelConnector:
+
     def __init__(self, trainer):
         self.trainer = trainer

@@ -16,6 +16,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class OptimizerConnector:
+
     def __init__(self, trainer):
         self.trainer = trainer

@@ -50,9 +51,8 @@ class OptimizerConnector:
                 if lr_scheduler['reduce_on_plateau']:
                     monitor_key = lr_scheduler['monitor']
                     monitor_val = (
-                        monitor_metrics.get(monitor_key)
-                        if monitor_metrics is not None
-                        else self.trainer.logger_connector.callback_metrics.get(monitor_key)
+                        monitor_metrics.get(monitor_key) if monitor_metrics is not None else
+                        self.trainer.logger_connector.callback_metrics.get(monitor_key)
                     )
                     if monitor_val is None:
                         if lr_scheduler.get('strict', True):
@@ -49,9 +49,11 @@ class PrecisionConnector:
         assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
         if amp_type == 'native':
             if not _NATIVE_AMP_AVAILABLE:
-                rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.'
-                               ' Consider upgrading with `pip install torch>=1.6`.'
-                               ' We will attempt to use NVIDIA Apex for this session.')
+                rank_zero_warn(
+                    'You have asked for native AMP but your PyTorch version does not support it.'
+                    ' Consider upgrading with `pip install torch>=1.6`.'
+                    ' We will attempt to use NVIDIA Apex for this session.'
+                )
                 amp_type = 'apex'
             else:
                 self.trainer.amp_backend = AMPType.NATIVE
@@ -60,8 +62,10 @@ class PrecisionConnector:

         if amp_type == 'apex':
             if not _APEX_AVAILABLE:
-                rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
-                               ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
+                rank_zero_warn(
+                    'You have asked for Apex AMP but you have not installed it yet.'
+                    ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux'
+                )
             else:
                 log.info('Using APEX 16bit precision.')
                 self.trainer.amp_backend = AMPType.APEX
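Not part of the diff: the branches above select between native AMP (PyTorch >= 1.6) and NVIDIA Apex; a minimal sketch of the corresponding Trainer flags (assuming a GPU machine and, for the commented line, an Apex install):

    from pytorch_lightning import Trainer

    trainer = Trainer(gpus=1, precision=16, amp_backend='native')
    # trainer = Trainer(gpus=1, precision=16, amp_backend='apex', amp_level='O2')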
@@ -24,11 +24,7 @@ from pytorch_lightning.profiler import (
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

-PROFILERS = {
-    "simple": SimpleProfiler,
-    "advanced": AdvancedProfiler,
-    "pytorch": PyTorchProfiler
-}
+PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler}


 class ProfilerConnector:
@@ -40,14 +36,17 @@ class ProfilerConnector:

         if profiler and not isinstance(profiler, (bool, str, BaseProfiler)):
             # TODO: Update exception on removal of bool
-            raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`"
-                                            " are valid values for `Trainer`'s `profiler` parameter."
-                                            f" Received {profiler} which is of type {type(profiler)}.")
+            raise MisconfigurationException(
+                "Only None, bool, str and subclasses of `BaseProfiler`"
+                " are valid values for `Trainer`'s `profiler` parameter."
+                f" Received {profiler} which is of type {type(profiler)}."
+            )

         if isinstance(profiler, bool):
-            rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
-                           " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.",
-                           DeprecationWarning)
+            rank_zero_warn(
+                "Passing a bool value as a `profiler` argument to `Trainer` is deprecated"
+                " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", DeprecationWarning
+            )
             if profiler:
                 profiler = SimpleProfiler()
         elif isinstance(profiler, str):
@@ -55,8 +54,10 @@ class ProfilerConnector:
                 profiler_class = PROFILERS[profiler.lower()]
                 profiler = profiler_class()
             else:
-                raise ValueError("When passing string value for the `profiler` parameter of"
-                                 " `Trainer`, it can only be 'simple' or 'advanced'")
+                raise ValueError(
+                    "When passing string value for the `profiler` parameter of"
+                    " `Trainer`, it can only be 'simple' or 'advanced'"
+                )
         self.trainer.profiler = profiler or PassThroughProfiler()

     def on_train_start(self, trainer):
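Not part of the diff: the PROFILERS mapping above backs the string form of the `profiler` argument; a minimal usage sketch:

    from pytorch_lightning import Trainer

    trainer = Trainer(profiler='simple')      # SimpleProfiler
    # trainer = Trainer(profiler='advanced')  # AdvancedProfiler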
@@ -151,9 +151,5 @@ class SLURMConnector:
         torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo"

         if not torch.distributed.is_initialized():
-            log.info(
-                f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}"
-            )
-            torch_distrib.init_process_group(
-                torch_backend, rank=global_rank, world_size=world_size
-            )
+            log.info(f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}")
+            torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size)
@@ -21,12 +21,12 @@ class TrainingTricksConnector:
         self.trainer = trainer

     def on_trainer_init(
-            self,
-            gradient_clip_val,
-            track_grad_norm,
-            accumulate_grad_batches,
-            truncated_bptt_steps,
-            terminate_on_nan
+        self,
+        gradient_clip_val,
+        track_grad_norm,
+        accumulate_grad_batches,
+        truncated_bptt_steps,
+        terminate_on_nan,
     ):

         self.trainer.terminate_on_nan = terminate_on_nan
@@ -37,12 +37,12 @@ class TrainerDataLoadingMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     global_rank: int
-    shown_warnings: ...
+    shown_warnings:...
     val_check_interval: float
     tpu_local_core_rank: int
     train_dataloader: DataLoader
     num_training_batches: Union[int, float]
-    val_check_batch: ...
+    val_check_batch:...
     val_dataloaders: List[DataLoader]
     num_val_batches: List[Union[int, float]]
     test_dataloaders: List[DataLoader]
@@ -65,22 +65,27 @@ class TrainerDataLoadingMixin(ABC):
         using_spawn = self.distributed_backend == "ddp_spawn"
         if is_dataloader and not on_windows:
             if dataloader.num_workers > 0 and using_spawn:
-                rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
-                               ' Your performance might suffer dramatically.'
-                               ' Please consider setting accelerator=ddp to use num_workers > 0'
-                               ' (this is a bottleneck of Python .spawn() and PyTorch')
+                rank_zero_warn(
+                    'Dataloader(num_workers>0) and ddp_spawn do not mix well!'
+                    ' Your performance might suffer dramatically.'
+                    ' Please consider setting accelerator=ddp to use num_workers > 0'
+                    ' (this is a bottleneck of Python .spawn() and PyTorch'
+                )

             elif dataloader.num_workers == 0 and using_spawn:
-                rank_zero_warn('You are using `accelerator=ddp_spawn` with num_workers=0.'
-                               ' For much faster performance, switch to `accelerator=ddp`'
-                               ' and set `num_workers>0`')
+                rank_zero_warn(
+                    'You are using `accelerator=ddp_spawn` with num_workers=0.'
+                    ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`'
+                )

             elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
                 num_cpus = multiprocessing.cpu_count()
-                rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
-                               ' Consider increasing the value of the `num_workers` argument`'
-                               f' (try {num_cpus} which is the number of cpus on this machine)'
-                               ' in the `DataLoader` init to improve performance.')
+                rank_zero_warn(
+                    f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
+                    ' Consider increasing the value of the `num_workers` argument`'
+                    f' (try {num_cpus} which is the number of cpus on this machine)'
+                    f' in the `DataLoader` init to improve performance.'
+                )

     def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader:

@@ -99,7 +104,8 @@ class TrainerDataLoadingMixin(ABC):
                     'You seem to have configured a sampler in your DataLoader. This will be replaced '
                     ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
                     ' distributed training. Either remove the sampler from your DataLoader or set'
-                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')
+                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.'
+                )

             # replace with distributed sampler
             sampler = self._get_distributed_sampler(dataloader, shuffle)
@@ -110,9 +116,7 @@ class TrainerDataLoadingMixin(ABC):
     def replace_sampler(self, dataloader, sampler):
         skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']

-        dl_args = {
-            k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
-        }
+        dl_args = {k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys}

         dl_args['sampler'] = sampler
         dl_args['shuffle'] = False
@@ -138,17 +142,21 @@ class TrainerDataLoadingMixin(ABC):

         if (self.overfit_batches > 0):
             if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler):
-                rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.'
-                               ' We are turning it off for you.')
+                rank_zero_warn(
+                    'You requested to overfit but enabled training dataloader shuffling.'
+                    ' We are turning it off for you.'
+                )
                 self.train_dataloader = self.replace_sampler(
-                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset))
+                    self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)
+                )

         # debugging
         self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader])

         # automatically add samplers
         self.train_dataloader = apply_to_collection(
-            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True)
+            self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True
+        )

         # check the workers recursively
         apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader')
@@ -166,7 +174,8 @@ class TrainerDataLoadingMixin(ABC):
                 raise MisconfigurationException(
                     'When using an IterableDataset for `limit_train_batches`,'
                     ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                    ' `num_training_batches` to use.')
+                    ' `num_training_batches` to use.'
+                )

         # determine when to check validation
         # if int passed in, val checks that often
@@ -177,7 +186,8 @@ class TrainerDataLoadingMixin(ABC):
                 raise ValueError(
                     f'`val_check_interval` ({self.val_check_interval}) must be less than or equal '
                     f'to the number of the training batches ({self.num_training_batches}). '
-                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.')
+                    'If you want to disable validation set `limit_val_batches` to 0.0 instead.'
+                )
         else:
             if not has_len(self.train_dataloader):
                 if self.val_check_interval == 1.0:
@@ -186,15 +196,16 @@ class TrainerDataLoadingMixin(ABC):
                     raise MisconfigurationException(
                         'When using an IterableDataset for `train_dataloader`,'
                         ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
-                        ' checking validation every k training batches.')
+                        ' checking validation every k training batches.'
+                    )
             else:
                 self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
                 self.val_check_batch = max(1, self.val_check_batch)

     def _reset_eval_dataloader(
-            self,
-            model: LightningModule,
-            mode: str
+        self,
+        model: LightningModule,
+        mode: str,
     ) -> Tuple[List[Union[int, float]], List[DataLoader]]:
         """Generic method to reset a dataloader for evaluation.

@@ -229,13 +240,17 @@ class TrainerDataLoadingMixin(ABC):

                 # when overfitting, the dataloader should not have sampler
                 if self.overfit_batches > 0:
-                    rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.'
-                                   ' We are turning it off for you.')
+                    rank_zero_warn(
+                        'You requested to overfit but enabled test/val dataloader shuffling.'
+                        ' We are turning it off for you.'
+                    )
                     dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset))

                 else:
-                    rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
-                                   ' this off for validation and test dataloaders.')
+                    rank_zero_warn(
+                        f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn'
+                        ' this off for validation and test dataloaders.'
+                    )

         if any([dl is None for dl in dataloaders]):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -264,7 +279,8 @@ class TrainerDataLoadingMixin(ABC):
                     raise MisconfigurationException(
                         'When using an IterableDataset for `limit_{mode}_batches`,'
                         f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
-                        f' `num_{mode}_batches` to use.')
+                        f' `num_{mode}_batches` to use.'
+                    )

             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
@@ -100,16 +100,12 @@ class DeprecatedDistDeviceAttributes:

     @property
     def use_horovod(self) -> bool:
-        rank_zero_warn(
-            "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
-        )
+        rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
         return self._distrib_type == DistributedType.HOROVOD

     @use_horovod.setter
     def use_horovod(self, val: bool) -> None:
-        rank_zero_warn(
-            "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
-        )
+        rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning)
         if val:
             self._distrib_type = DistributedType.HOROVOD

@@ -119,14 +115,16 @@ class DeprecatedDistDeviceAttributes:
             "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning
         )
         # todo, limiting to exclude DDP2 is not clear but it comes from connectors...
-        return (self._device_type and self._device_type == DeviceType.GPU
-                and self.num_gpus == 1
-                and self._distrib_type not in (DistributedType.DDP2, ))
+        return (
+            self._device_type and self._device_type == DeviceType.GPU and self.num_gpus == 1
+            and self._distrib_type != DistributedType.DDP2
+        )

     @use_single_gpu.setter
     def use_single_gpu(self, val: bool) -> None:
         rank_zero_warn(
-            "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning,
+            "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.",
+            DeprecationWarning,
         )
         if val:
             self._device_type = DeviceType.GPU
@@ -21,6 +21,7 @@ from pytorch_lightning.utilities.warnings import WarningCache


 class EvaluationLoop(object):
+
     def __init__(self, trainer):
         self.trainer = trainer
         self.testing = False
@@ -303,7 +304,8 @@ class EvaluationLoop(object):
     def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx):
         # set dataloader_idx to model and track batch_size
         self.trainer.logger_connector.on_evaluation_batch_start(
-            self.testing, batch, dataloader_idx, self.num_dataloaders)
+            self.testing, batch, dataloader_idx, self.num_dataloaders
+        )

         if self.testing:
             self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
@@ -17,9 +17,11 @@ import warnings

 def ignore_scalar_return_in_dp():
     # Users get confused by this warning so we silence it
-    warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all'
-                                              ' input tensors were scalars; will instead unsqueeze'
-                                              ' and return a vector.')
+    warnings.filterwarnings(
+        'ignore',
+        message='Was asked to gather along dimension 0, but all input tensors were scalars;'
+        ' will instead unsqueeze and return a vector.'
+    )


 ignore_scalar_return_in_dp()

@@ -31,14 +31,14 @@ class TrainerLoggingMixin(ABC):
     current_epoch: int
     _device_type: DeviceType
     _distrib_type: DistributedType
-    log_gpu_memory: ...
+    log_gpu_memory:...
     logger: Union[LightningLoggerBase, bool]
     global_step: int
     global_rank: int
     default_root_dir: str
     slurm_job_id: int
     num_gpus: int
-    logged_metrics: ...
+    logged_metrics:...

     def metrics_to_scalars(self, metrics):
         new_metrics = {}
@@ -66,13 +66,12 @@ class TrainerLoggingMixin(ABC):
         for k, v in output.items():
             if k in ['log', 'progress_bar']:
                 m = inspect.cleandoc(
-                    f"""The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0
-                    Please use self.log(...) inside the lightningModule instead.
-
-                    # log on a step or aggregate epoch metric to the logger and/or progress bar
-                    # (inside LightningModule)
-                    self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
-                    """)
+                    f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n"
+                    " Please use self.log(...) inside the lightningModule instead.\n"
+                    " # log on a step or aggregate epoch metric to the logger and/or progress bar"
+                    " (inside LightningModule)\n"
+                    " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)"
+                )
                 rank_zero_warn(m)

         # --------------------------
@@ -19,6 +19,7 @@ from pytorch_lightning.core.lightning import LightningModule


 class TrainerModelHooksMixin(ABC):
+
     def is_function_implemented(self, f_name, model=None):
         if model is None:
             model = self.get_model()
@@ -26,6 +26,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class TrainerOptimizersMixin(ABC):
+
     def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]:
         optim_conf = model.configure_optimizers()
         if optim_conf is None:
@@ -82,6 +83,7 @@ class TrainerOptimizersMixin(ABC):
         return optimizers, lr_schedulers, optimizer_frequencies

     def convert_to_lightning_optimizers(self):
+
         def _convert_to_lightning_optimizer(trainer, optimizer):
             if not isinstance(optimizer, LightningOptimizer):
                 optimizer = LightningOptimizer(optimizer)
@@ -132,9 +134,11 @@ class TrainerOptimizersMixin(ABC):
                         ' For example:'
                         ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                     )
-                lr_schedulers.append(
-                    {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor}
-                )
+                lr_schedulers.append({
+                    **default_config, 'scheduler': scheduler,
+                    'reduce_on_plateau': True,
+                    'monitor': monitor
+                })
             elif isinstance(scheduler, optim.lr_scheduler._LRScheduler):
                 lr_schedulers.append({**default_config, 'scheduler': scheduler})
             else:
@@ -59,6 +59,7 @@ def trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[
     """
+
     def wrapper(fn) -> Callable:

         @wraps(fn)
         def wrapped_fn(self, *args, **kwargs):
             if not isinstance(self, pytorch_lightning.Trainer):
@@ -101,10 +101,11 @@ class TensorRunningAccum(object):
         if self.rotated:
             return getattr(self.memory, how)()
         else:
-            return getattr(self.memory[: self.current_idx], how)()
+            return getattr(self.memory[:self.current_idx], how)()


 class Accumulator(object):
+
     def __init__(self):
         self.num_values = 0
         self.total = 0
@@ -119,6 +120,7 @@ class Accumulator(object):


 class PredictionCollection(object):
+
     def __init__(self, global_rank: int, world_size: int):
         self.global_rank = global_rank
         self.world_size = world_size
@@ -131,9 +133,7 @@ class PredictionCollection(object):
         elif name not in self.predictions[filename]:
             self.predictions[filename][name] = values
         elif isinstance(values, Tensor):
-            self.predictions[filename][name] = torch.cat(
-                (self.predictions[filename][name], values)
-            )
+            self.predictions[filename][name] = torch.cat((self.predictions[filename][name], values))
         elif isinstance(values, list):
             self.predictions[filename][name].extend(values)

@@ -161,10 +161,7 @@ class PredictionCollection(object):
         fs.mkdirs(dirpath, exist_ok=True)

         # Convert any tensor values to list
-        predictions = {
-            k: v if not isinstance(v, Tensor) else v.tolist()
-            for k, v in predictions.items()
-        }
+        predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()}

         # Check if all features for this file add up to same length
         feature_lens = {k: len(v) for k, v in predictions.items()}
@@ -186,6 +183,7 @@ class CycleIterator(object):
     """
     Iterator for restarting a dataloader if it runs out of samples
     """
+
     def __init__(self, loader: Any, length: Optional[int] = None):
         """

@@ -296,8 +294,9 @@ class CombinedDataset(object):
             raise MisconfigurationException(f"Invalid Mode: {mode}")

         # extract the lengths
-        all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len,
-                                          wrong_dtype=(Sequence, Mapping))
+        all_lengths = apply_to_collection(
+            datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping)
+        )

         compute_func = CombinedDataset.COMPUTE_FUNCS[mode]

@@ -351,8 +350,9 @@ class CombinedLoader(object):
         """
         self.loaders = loaders

-        datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None,
-                                       wrong_dtype=(Sequence, Mapping))
+        datasets = apply_to_collection(
+            self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping)
+        )
         # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader
         self.dataset = CombinedDataset(datasets, mode)

@@ -367,8 +367,7 @@ class CombinedLoader(object):
     @property
     def sampler(self) -> Union[Iterable, Sequence, Mapping]:
         """Return a collections of samplers extracting from loaders."""
-        return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None,
-                                   wrong_dtype=(Sequence, Mapping))
+        return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping))

     def _wrap_loaders_max_size_cycle(self) -> Any:
         """
@@ -378,8 +377,7 @@ class CombinedLoader(object):
             the wrapped loaders

         """
-        all_lengths = apply_to_collection(self.loaders, Iterable, get_len,
-                                          wrong_dtype=(Sequence, Mapping))
+        all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

         if isinstance(all_lengths, (int, float)):
             length = all_lengths
@@ -391,13 +389,10 @@ class CombinedLoader(object):
             length = max(all_lengths)

         if isinstance(self.loaders, Mapping):
-            self.loaders = type(self.loaders)({k: CycleIterator(v, length=length)
-                                               for k, v in self.loaders.items()})
+            self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()})

         elif isinstance(self.loaders, Sequence):
-            self.loaders = type(self.loaders)([
-                CycleIterator(v, length=length) for v in self.loaders
-            ])
+            self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders])

         # dataloaders are iterable but not sequence
         elif isinstance(self.loaders, Iterable):
@@ -424,8 +419,7 @@ class CombinedLoader(object):
             length: the minimum length of loaders

         """
-        all_lengths = apply_to_collection(loaders, Iterable, get_len,
-                                          wrong_dtype=(Sequence, Mapping))
+        all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping))

         if isinstance(all_lengths, (int, float)):
             return all_lengths
@@ -441,6 +435,7 @@ class CombinedLoaderIterator(object):
     """
     Custom Iterator returning data from multple loaders, and allows sampling in parallel
     """
+
     def __init__(self, loaders: Any):
         """

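Not part of the diff: a minimal usage sketch of CombinedLoader as wrapped above (import path and the 'max_size_cycle' mode name assumed from this version of supporters.py):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from pytorch_lightning.trainer.supporters import CombinedLoader

    loaders = {
        'a': DataLoader(TensorDataset(torch.rand(8, 2)), batch_size=2),
        'b': DataLoader(TensorDataset(torch.rand(4, 2)), batch_size=2),
    }
    # 'max_size_cycle' restarts the shorter loader (via CycleIterator) until the longest one is exhausted
    combined = CombinedLoader(loaders, 'max_size_cycle')
    for batch in combined:
        ...  # batch is a dict with keys 'a' and 'b'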
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Trainer to automate the training."""

 import os
@@ -66,7 +65,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden

 # warnings to ignore in trainer
 warnings.filterwarnings(
-    'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
+    'ignore', message='torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead'
 )


@@ -80,6 +79,7 @@ class Trainer(
     TrainerDataLoadingMixin,
     DeprecatedDistDeviceAttributes,
 ):
+
     @overwrite_by_env_vars
     def __init__(
         self,
@@ -896,9 +896,7 @@ class Trainer(
         # --------------------
         # If you supply a datamodule you can't supply dataloaders
         if dataloaders and datamodule:
-            raise MisconfigurationException(
-                'You cannot pass dataloaders to trainer.predict if you supply a datamodule'
-            )
+            raise MisconfigurationException('You cannot pass dataloaders to trainer.predict if you supply a datamodule')

         if model is None:
             raise MisconfigurationException('You need to pass a model to `trainer.predict`. ')
@@ -35,6 +35,7 @@ from pytorch_lightning.utilities.warnings import WarningCache


 class TrainLoop:
+
     def __init__(self, trainer, multiple_trainloader_mode):
         self.trainer = trainer
         self.early_stopping_accumulator = None
@@ -154,9 +155,10 @@ class TrainLoop:
         self.trainer.model_connector.copy_trainer_model_properties(ref_model)

         # init amp. Must be done here instead of __init__ to allow ddp to work
-        if (self.trainer.amp_backend == AMPType.NATIVE
-                and self.trainer.precision == 16
-                and self.trainer._device_type != DeviceType.TPU):
+        if (
+            self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16
+            and self.trainer._device_type != DeviceType.TPU
+        ):
             self.trainer.scaler = self.trainer.precision_connector.backend.scaler

         # log hyper-parameters
@@ -498,7 +500,8 @@ class TrainLoop:
         if using_native_amp and is_lbfgs:
             raise MisconfigurationException(
                 'native PyTorch amp and lbfgs are not compatible.'
-                ' To request, please file a Github issue in PyTorch and tag @mcarilli')
+                ' To request, please file a Github issue in PyTorch and tag @mcarilli'
+            )

         # wraps into LightingOptimizer only for running step
         optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx)
@@ -641,10 +644,7 @@ class TrainLoop:

         # log epoch metrics
         self.trainer.logger_connector.log_train_epoch_end_metrics(
-            epoch_output,
-            self.checkpoint_accumulator,
-            self.early_stopping_accumulator,
-            self.num_optimizers
+            epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers
         )

         # when no val loop is present or fast-dev-run still need to call checkpoints
@@ -699,11 +699,8 @@ class TrainLoop:
                     # automatic_optimization=False: don't block synchronization here
                     with self.block_ddp_sync_behaviour():
                         self.training_step_and_backward(
-                            split_batch,
-                            batch_idx,
-                            opt_idx,
-                            optimizer,
-                            self.trainer.hiddens)
+                            split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
+                        )

                     batch_outputs = self._process_closure_result(
                         batch_outputs=batch_outputs,
@@ -720,11 +717,7 @@ class TrainLoop:

                     def train_step_and_backward_closure():
                         result = self.training_step_and_backward(
-                            split_batch,
-                            batch_idx,
-                            opt_idx,
-                            optimizer,
-                            self.trainer.hiddens
+                            split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
                         )
                         return None if result is None else result.loss

@@ -733,10 +726,7 @@ class TrainLoop:

        else:
            self._curr_step_result = self.training_step(
-                split_batch,
-                batch_idx,
-                opt_idx,
-                self.trainer.hiddens
+                split_batch, batch_idx, opt_idx, self.trainer.hiddens
            )

            if self._curr_step_result is None:
@@ -783,9 +773,7 @@ class TrainLoop:
         else:
             yield None

-    def _process_closure_result(
-            self, batch_outputs: list, opt_idx: int
-    ) -> list:
+    def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list:
         opt_closure_result = self._curr_step_result

         if opt_closure_result is not None:
@@ -29,7 +29,7 @@ class TrainerTrainingTricksMixin(ABC):
     # this is just a summary on variables used in this abstract class,
     # the proper values/initialisation should be done in child class
     default_root_dir: str
-    progress_bar_callback: ...
+    progress_bar_callback:...
     on_gpu: bool

     @abstractmethod
@@ -47,9 +47,7 @@ class TrainerTrainingTricksMixin(ABC):

         # check if loss is nan
         if not torch.isfinite(loss).all():
-            raise ValueError(
-                'The loss returned in `training_step` is nan or inf.'
-            )
+            raise ValueError('The loss returned in `training_step` is nan or inf.')
         # check if a network weight is nan
         for name, param in model.named_parameters():
             if not torch.isfinite(param).all():