diff --git a/.yapfignore b/.yapfignore index 9b42f233da..221d7db8d0 100644 --- a/.yapfignore +++ b/.yapfignore @@ -28,10 +28,6 @@ pytorch_lightning/plugins/legacy/* pytorch_lightning/profiler/* -# TODO -pytorch_lightning/trainer/* - - # TODO pytorch_lightning/tuner/* diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index 584bdd7772..a7e13de8ed 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -95,10 +95,6 @@ class ConfigValidator(object): has_step = is_overridden(step_name, model) if has_loader and not has_step: - rank_zero_warn( - f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop' - ) + rank_zero_warn(f'you passed in a {loader_name} but have no {step_name}. Skipping {eval_loop_name} loop') if has_step and not has_loader: - rank_zero_warn( - f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop' - ) + rank_zero_warn(f'you defined a {step_name} but have no {loader_name}. Skipping {eval_loop_name} loop') diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 61e59d2fb9..5140e5fc78 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -25,14 +25,14 @@ class CallbackConnector: self.trainer = trainer def on_trainer_init( - self, - callbacks, - checkpoint_callback, - progress_bar_refresh_rate, - process_position, - default_root_dir, - weights_save_path, - resume_from_checkpoint + self, + callbacks, + checkpoint_callback, + progress_bar_refresh_rate, + process_position, + default_root_dir, + weights_save_path, + resume_from_checkpoint, ): self.trainer.resume_from_checkpoint = resume_from_checkpoint @@ -51,9 +51,7 @@ class CallbackConnector: self.configure_checkpoint_callbacks(checkpoint_callback) # init progress bar - self.trainer._progress_bar_callback = self.configure_progress_bar( - progress_bar_refresh_rate, process_position - ) + self.trainer._progress_bar_callback = self.configure_progress_bar(progress_bar_refresh_rate, process_position) def configure_checkpoint_callbacks(self, checkpoint_callback: Union[ModelCheckpoint, bool]): if isinstance(checkpoint_callback, ModelCheckpoint): @@ -61,8 +59,7 @@ class CallbackConnector: rank_zero_warn( "Passing a ModelCheckpoint instance to Trainer(checkpoint_callbacks=...)" " is deprecated since v1.1 and will no longer be supported in v1.3." - " Use `callbacks` argument instead.", - DeprecationWarning + " Use `callbacks` argument instead.", DeprecationWarning ) self.trainer.callbacks.append(checkpoint_callback) diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 001b0b9ed3..6fe98bac13 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -230,7 +230,8 @@ class CheckpointConnector: if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] rank_zero_warn( - 'warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}' + 'warning, `hyper_parameters` dropped from checkpoint.' 
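# callback_connector above deprecates passing a ModelCheckpoint instance through
# `checkpoint_callback`; a minimal sketch of the preferred wiring (the monitored
# metric name is illustrative only, not taken from this patch):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_cb = ModelCheckpoint(monitor="val_loss")
trainer = Trainer(callbacks=[checkpoint_cb])            # preferred since v1.1
# trainer = Trainer(checkpoint_callback=checkpoint_cb)  # deprecated, unsupported from v1.3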
+ f' An attribute is not picklable {err}' ) atomic_save(checkpoint, filepath) @@ -297,9 +298,10 @@ class CheckpointConnector: checkpoint['lr_schedulers'] = lr_schedulers # dump amp scaling - if (self.trainer.amp_backend == AMPType.NATIVE - and self.trainer._device_type != DeviceType.TPU - and self.trainer.scaler is not None): + if ( + self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU + and self.trainer.scaler is not None + ): checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict() elif self.trainer.amp_backend == AMPType.APEX: checkpoint['amp_scaling_state'] = amp.state_dict() @@ -409,6 +411,7 @@ class CheckpointConnector: if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint: del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] rank_zero_warn( - 'Warning, `hyper_parameters` dropped from checkpoint.' f' An attribute is not picklable {err}' + 'Warning, `hyper_parameters` dropped from checkpoint.' + f' An attribute is not picklable {err}' ) atomic_save(checkpoint, filepath) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 802f8f0941..9161f3e875 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -36,8 +36,7 @@ class DataConnector(object): def get_profiled_train_dataloader(self, train_dataloader): profiled_dl = self.trainer.profiler.profile_iterable( - enumerate(self._with_is_last(train_dataloader)), - "get_train_batch" + enumerate(self._with_is_last(train_dataloader)), "get_train_batch" ) return profiled_dl diff --git a/pytorch_lightning/trainer/connectors/debugging_connector.py b/pytorch_lightning/trainer/connectors/debugging_connector.py index 3a5447dd94..caab61f113 100644 --- a/pytorch_lightning/trainer/connectors/debugging_connector.py +++ b/pytorch_lightning/trainer/connectors/debugging_connector.py @@ -25,13 +25,13 @@ class DebuggingConnector: self.trainer = trainer def on_init_start( - self, - limit_train_batches, - limit_val_batches, - limit_test_batches, - val_check_interval, - overfit_batches, - fast_dev_run + self, + limit_train_batches, + limit_val_batches, + limit_test_batches, + val_check_interval, + overfit_batches, + fast_dev_run, ): if not isinstance(fast_dev_run, (bool, int)): raise MisconfigurationException( diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index e4d5670b5f..2e788c256a 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -24,6 +24,7 @@ def overwrite_by_env_vars(fn: Callable) -> Callable: input arguments should be moved automatically to the correct device. 
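# env_vars_connector: a rough sketch of what a kwargs-from-environment decorator like
# `overwrite_by_env_vars` can look like. The PL_TRAINER_* prefix below is an assumption
# for illustration only; the real lookup is not part of this patch.
import os
from functools import wraps


def overwrite_by_env_vars(fn):

    @wraps(fn)
    def wrapped(self, *args, **kwargs):
        for key, value in os.environ.items():
            if key.startswith("PL_TRAINER_"):
                # only fill in arguments the caller did not pass explicitly
                kwargs.setdefault(key[len("PL_TRAINER_"):].lower(), value)
        return fn(self, *args, **kwargs)

    return wrapped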
""" + @wraps(fn) def overwrite_by_env_vars(self, *args, **kwargs): # get the class diff --git a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py index cf90877372..534dad5199 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py @@ -18,8 +18,9 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException class CallbackHookNameValidator: @staticmethod - def check_logging_in_callbacks(current_hook_fx_name: str = None, on_step: bool = None, - on_epoch: bool = None) -> None: + def check_logging_in_callbacks( + current_hook_fx_name: str = None, on_step: bool = None, on_epoch: bool = None + ) -> None: if current_hook_fx_name is None: return diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index f6700187c3..1d54a96254 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -31,6 +31,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden class LoggerConnector: + def __init__(self, trainer): self.trainer = trainer self._callback_metrics = MetricsHolder() @@ -76,14 +77,14 @@ class LoggerConnector: @property def cached_results(self) -> Union[EpochResultStore, None]: - return self._cached_results.get(self.trainer._running_stage) # type: ignore + return self._cached_results.get(self.trainer._running_stage) # type: ignore def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) model_ref = self.trainer.get_model() metrics_holder.convert( self.trainer._device_type == DeviceType.TPU, - model_ref.device if model_ref is not None else model_ref + model_ref.device if model_ref is not None else model_ref, ) return metrics_holder.metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index d2e2c9b787..394e4285d3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -21,7 +21,6 @@ from pytorch_lightning.utilities import _TPU_AVAILABLE class MetricsHolder: - """ This class acts as a dictonary holder. It holds metrics and implements conversion functions. diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 673e8765ed..6a303b9822 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """ Root module for all distributed operations in Lightning. Currently supports training on CPU, GPU (dp, ddp, ddp2, horovod) and TPU. 
@@ -21,6 +20,7 @@ from weakref import proxy class ModelConnector: + def __init__(self, trainer): self.trainer = trainer diff --git a/pytorch_lightning/trainer/connectors/optimizer_connector.py b/pytorch_lightning/trainer/connectors/optimizer_connector.py index 8b23203e42..5fb7b698b1 100644 --- a/pytorch_lightning/trainer/connectors/optimizer_connector.py +++ b/pytorch_lightning/trainer/connectors/optimizer_connector.py @@ -16,6 +16,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException class OptimizerConnector: + def __init__(self, trainer): self.trainer = trainer @@ -50,9 +51,8 @@ class OptimizerConnector: if lr_scheduler['reduce_on_plateau']: monitor_key = lr_scheduler['monitor'] monitor_val = ( - monitor_metrics.get(monitor_key) - if monitor_metrics is not None - else self.trainer.logger_connector.callback_metrics.get(monitor_key) + monitor_metrics.get(monitor_key) if monitor_metrics is not None else + self.trainer.logger_connector.callback_metrics.get(monitor_key) ) if monitor_val is None: if lr_scheduler.get('strict', True): diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index 551e855cdd..fdb469effa 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -49,9 +49,11 @@ class PrecisionConnector: assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' if amp_type == 'native': if not _NATIVE_AMP_AVAILABLE: - rank_zero_warn('You have asked for native AMP but your PyTorch version does not support it.' - ' Consider upgrading with `pip install torch>=1.6`.' - ' We will attempt to use NVIDIA Apex for this session.') + rank_zero_warn( + 'You have asked for native AMP but your PyTorch version does not support it.' + ' Consider upgrading with `pip install torch>=1.6`.' + ' We will attempt to use NVIDIA Apex for this session.' + ) amp_type = 'apex' else: self.trainer.amp_backend = AMPType.NATIVE @@ -60,8 +62,10 @@ class PrecisionConnector: if amp_type == 'apex': if not _APEX_AVAILABLE: - rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' - ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') + rank_zero_warn( + 'You have asked for Apex AMP but you have not installed it yet.' 
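# optimizer_connector above looks up a `monitor` value for ReduceLROnPlateau schedulers;
# the configure_optimizers return value it consumes is sketched below (model and metric
# name are illustrative only):
import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        # 'monitor' names the logged metric whose value is fed to scheduler.step(...)
        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "val_loss"}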
+ ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux' + ) else: log.info('Using APEX 16bit precision.') self.trainer.amp_backend = AMPType.APEX diff --git a/pytorch_lightning/trainer/connectors/profiler_connector.py b/pytorch_lightning/trainer/connectors/profiler_connector.py index 2e66d2370e..225c6b9c98 100644 --- a/pytorch_lightning/trainer/connectors/profiler_connector.py +++ b/pytorch_lightning/trainer/connectors/profiler_connector.py @@ -24,11 +24,7 @@ from pytorch_lightning.profiler import ( from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException -PROFILERS = { - "simple": SimpleProfiler, - "advanced": AdvancedProfiler, - "pytorch": PyTorchProfiler -} +PROFILERS = {"simple": SimpleProfiler, "advanced": AdvancedProfiler, "pytorch": PyTorchProfiler} class ProfilerConnector: @@ -40,14 +36,17 @@ class ProfilerConnector: if profiler and not isinstance(profiler, (bool, str, BaseProfiler)): # TODO: Update exception on removal of bool - raise MisconfigurationException("Only None, bool, str and subclasses of `BaseProfiler`" - " are valid values for `Trainer`'s `profiler` parameter." - f" Received {profiler} which is of type {type(profiler)}.") + raise MisconfigurationException( + "Only None, bool, str and subclasses of `BaseProfiler`" + " are valid values for `Trainer`'s `profiler` parameter." + f" Received {profiler} which is of type {type(profiler)}." + ) if isinstance(profiler, bool): - rank_zero_warn("Passing a bool value as a `profiler` argument to `Trainer` is deprecated" - " and will be removed in v1.3. Use str ('simple' or 'advanced') instead.", - DeprecationWarning) + rank_zero_warn( + "Passing a bool value as a `profiler` argument to `Trainer` is deprecated" + " and will be removed in v1.3. 
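# profiler_connector: the accepted `profiler` values spelled out in the messages above,
# as a short usage sketch:
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import AdvancedProfiler

Trainer(profiler="simple")            # string, looked up in PROFILERS
Trainer(profiler=AdvancedProfiler())  # any BaseProfiler subclass instance
# Trainer(profiler=True)              # bool form is deprecated and scheduled for removal in v1.3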
Use str ('simple' or 'advanced') instead.", DeprecationWarning + ) if profiler: profiler = SimpleProfiler() elif isinstance(profiler, str): @@ -55,8 +54,10 @@ class ProfilerConnector: profiler_class = PROFILERS[profiler.lower()] profiler = profiler_class() else: - raise ValueError("When passing string value for the `profiler` parameter of" - " `Trainer`, it can only be 'simple' or 'advanced'") + raise ValueError( + "When passing string value for the `profiler` parameter of" + " `Trainer`, it can only be 'simple' or 'advanced'" + ) self.trainer.profiler = profiler or PassThroughProfiler() def on_train_start(self, trainer): diff --git a/pytorch_lightning/trainer/connectors/slurm_connector.py b/pytorch_lightning/trainer/connectors/slurm_connector.py index ad860c0b15..d53f4dfed4 100644 --- a/pytorch_lightning/trainer/connectors/slurm_connector.py +++ b/pytorch_lightning/trainer/connectors/slurm_connector.py @@ -151,9 +151,5 @@ class SLURMConnector: torch_backend = "nccl" if self.trainer._device_type == DeviceType.GPU else "gloo" if not torch.distributed.is_initialized(): - log.info( - f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}" - ) - torch_distrib.init_process_group( - torch_backend, rank=global_rank, world_size=world_size - ) + log.info(f"initializing ddp (SLURM): GLOBAL_RANK: {global_rank}, MEMBER: {global_rank + 1}/{world_size}") + torch_distrib.init_process_group(torch_backend, rank=global_rank, world_size=world_size) diff --git a/pytorch_lightning/trainer/connectors/training_trick_connector.py b/pytorch_lightning/trainer/connectors/training_trick_connector.py index b5d1d45461..dd7aad8cd6 100644 --- a/pytorch_lightning/trainer/connectors/training_trick_connector.py +++ b/pytorch_lightning/trainer/connectors/training_trick_connector.py @@ -21,12 +21,12 @@ class TrainingTricksConnector: self.trainer = trainer def on_trainer_init( - self, - gradient_clip_val, - track_grad_norm, - accumulate_grad_batches, - truncated_bptt_steps, - terminate_on_nan + self, + gradient_clip_val, + track_grad_norm, + accumulate_grad_batches, + truncated_bptt_steps, + terminate_on_nan, ): self.trainer.terminate_on_nan = terminate_on_nan diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 5031357b41..2476d2f4d4 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -37,12 +37,12 @@ class TrainerDataLoadingMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class global_rank: int - shown_warnings: ... + shown_warnings:... val_check_interval: float tpu_local_core_rank: int train_dataloader: DataLoader num_training_batches: Union[int, float] - val_check_batch: ... + val_check_batch:... val_dataloaders: List[DataLoader] num_val_batches: List[Union[int, float]] test_dataloaders: List[DataLoader] @@ -65,22 +65,27 @@ class TrainerDataLoadingMixin(ABC): using_spawn = self.distributed_backend == "ddp_spawn" if is_dataloader and not on_windows: if dataloader.num_workers > 0 and using_spawn: - rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!' - ' Your performance might suffer dramatically.' - ' Please consider setting accelerator=ddp to use num_workers > 0' - ' (this is a bottleneck of Python .spawn() and PyTorch') + rank_zero_warn( + 'Dataloader(num_workers>0) and ddp_spawn do not mix well!' + ' Your performance might suffer dramatically.' 
+ ' Please consider setting accelerator=ddp to use num_workers > 0' + ' (this is a bottleneck of Python .spawn() and PyTorch' + ) elif dataloader.num_workers == 0 and using_spawn: - rank_zero_warn('You are using `accelerator=ddp_spawn` with num_workers=0.' - ' For much faster performance, switch to `accelerator=ddp`' - ' and set `num_workers>0`') + rank_zero_warn( + 'You are using `accelerator=ddp_spawn` with num_workers=0.' + ' For much faster performance, switch to `accelerator=ddp` and set `num_workers>0`' + ) elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn: num_cpus = multiprocessing.cpu_count() - rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.' - ' Consider increasing the value of the `num_workers` argument`' - f' (try {num_cpus} which is the number of cpus on this machine)' - ' in the `DataLoader` init to improve performance.') + rank_zero_warn( + f'The dataloader, {name}, does not have many workers which may be a bottleneck.' + ' Consider increasing the value of the `num_workers` argument`' + f' (try {num_cpus} which is the number of cpus on this machine)' + f' in the `DataLoader` init to improve performance.' + ) def auto_add_sampler(self, dataloader: DataLoader, shuffle: bool) -> DataLoader: @@ -99,7 +104,8 @@ class TrainerDataLoadingMixin(ABC): 'You seem to have configured a sampler in your DataLoader. This will be replaced ' ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using' ' distributed training. Either remove the sampler from your DataLoader or set' - ' `replace_sampler_ddp`=False if you want to use your custom sampler.') + ' `replace_sampler_ddp`=False if you want to use your custom sampler.' + ) # replace with distributed sampler sampler = self._get_distributed_sampler(dataloader, shuffle) @@ -110,9 +116,7 @@ class TrainerDataLoadingMixin(ABC): def replace_sampler(self, dataloader, sampler): skip_keys = ['sampler', 'batch_sampler', 'dataset_kind'] - dl_args = { - k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys - } + dl_args = {k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys} dl_args['sampler'] = sampler dl_args['shuffle'] = False @@ -138,17 +142,21 @@ class TrainerDataLoadingMixin(ABC): if (self.overfit_batches > 0): if hasattr(self.train_dataloader, 'sampler') and isinstance(self.train_dataloader.sampler, RandomSampler): - rank_zero_warn('You requested to overfit but enabled training dataloader shuffling.' - ' We are turning it off for you.') + rank_zero_warn( + 'You requested to overfit but enabled training dataloader shuffling.' + ' We are turning it off for you.' 
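# auto_add_sampler above swaps the loader's sampler for a DistributedSampler; in plain
# torch terms the replacement amounts to this standalone sketch (num_replicas and rank
# would come from the trainer's world size and global rank):
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100, dtype=torch.float32))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler, shuffle=False)  # shuffling now lives in the sampler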
+ ) self.train_dataloader = self.replace_sampler( - self.train_dataloader, SequentialSampler(self.train_dataloader.dataset)) + self.train_dataloader, SequentialSampler(self.train_dataloader.dataset) + ) # debugging self.dev_debugger.track_load_dataloader_call('train_dataloader', dataloaders=[self.train_dataloader]) # automatically add samplers self.train_dataloader = apply_to_collection( - self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True) + self.train_dataloader, DataLoader, self.auto_add_sampler, shuffle=True + ) # check the workers recursively apply_to_collection(self.train_dataloader, DataLoader, self._worker_check, 'train dataloader') @@ -166,7 +174,8 @@ class TrainerDataLoadingMixin(ABC): raise MisconfigurationException( 'When using an IterableDataset for `limit_train_batches`,' ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies' - ' `num_training_batches` to use.') + ' `num_training_batches` to use.' + ) # determine when to check validation # if int passed in, val checks that often @@ -177,7 +186,8 @@ class TrainerDataLoadingMixin(ABC): raise ValueError( f'`val_check_interval` ({self.val_check_interval}) must be less than or equal ' f'to the number of the training batches ({self.num_training_batches}). ' - 'If you want to disable validation set `limit_val_batches` to 0.0 instead.') + 'If you want to disable validation set `limit_val_batches` to 0.0 instead.' + ) else: if not has_len(self.train_dataloader): if self.val_check_interval == 1.0: @@ -186,15 +196,16 @@ class TrainerDataLoadingMixin(ABC): raise MisconfigurationException( 'When using an IterableDataset for `train_dataloader`,' ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies' - ' checking validation every k training batches.') + ' checking validation every k training batches.' + ) else: self.val_check_batch = int(self.num_training_batches * self.val_check_interval) self.val_check_batch = max(1, self.val_check_batch) def _reset_eval_dataloader( - self, - model: LightningModule, - mode: str + self, + model: LightningModule, + mode: str, ) -> Tuple[List[Union[int, float]], List[DataLoader]]: """Generic method to reset a dataloader for evaluation. @@ -229,13 +240,17 @@ class TrainerDataLoadingMixin(ABC): # when overfitting, the dataloader should not have sampler if self.overfit_batches > 0: - rank_zero_warn('You requested to overfit but enabled test/val dataloader shuffling.' - ' We are turning it off for you.') + rank_zero_warn( + 'You requested to overfit but enabled test/val dataloader shuffling.' + ' We are turning it off for you.' + ) dataloaders[loader_i] = self.replace_sampler(loader, SequentialSampler(loader.dataset)) else: - rank_zero_warn(f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn' - ' this off for validation and test dataloaders.') + rank_zero_warn( + f'Your {mode}_dataloader has `shuffle=True`, it is best practice to turn' + ' this off for validation and test dataloaders.' + ) if any([dl is None for dl in dataloaders]): rank_zero_warn("One of given dataloaders is None and it will be skipped.") @@ -264,7 +279,8 @@ class TrainerDataLoadingMixin(ABC): raise MisconfigurationException( 'When using an IterableDataset for `limit_{mode}_batches`,' f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies' - f' `num_{mode}_batches` to use.') + f' `num_{mode}_batches` to use.' 
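# the IterableDataset rules above, restated as usage (sketch): with no __len__ available,
# only 0.0, 1.0 or an int are valid for these arguments.
from pytorch_lightning import Trainer

Trainer(limit_train_batches=100, val_check_interval=50)  # ints are fine with an IterableDataset
# limit_train_batches=0.25 with an IterableDataset train loader raises MisconfigurationException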
+ ) if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float): min_pct = 1.0 / len(dataloader) diff --git a/pytorch_lightning/trainer/deprecated_api.py b/pytorch_lightning/trainer/deprecated_api.py index e9407379cb..e0c79c20cf 100644 --- a/pytorch_lightning/trainer/deprecated_api.py +++ b/pytorch_lightning/trainer/deprecated_api.py @@ -100,16 +100,12 @@ class DeprecatedDistDeviceAttributes: @property def use_horovod(self) -> bool: - rank_zero_warn( - "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning - ) + rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) return self._distrib_type == DistributedType.HOROVOD @use_horovod.setter def use_horovod(self, val: bool) -> None: - rank_zero_warn( - "Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning - ) + rank_zero_warn("Internal: `use_horovod` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning) if val: self._distrib_type = DistributedType.HOROVOD @@ -119,14 +115,16 @@ class DeprecatedDistDeviceAttributes: "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning ) # todo, limiting to exclude DDP2 is not clear but it comes from connectors... - return (self._device_type and self._device_type == DeviceType.GPU - and self.num_gpus == 1 - and self._distrib_type not in (DistributedType.DDP2, )) + return ( + self._device_type and self._device_type == DeviceType.GPU and self.num_gpus == 1 + and self._distrib_type != DistributedType.DDP2 + ) @use_single_gpu.setter def use_single_gpu(self, val: bool) -> None: rank_zero_warn( - "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", DeprecationWarning, + "Internal: `use_single_gpu` is deprecated in v1.2 and will be removed in v1.4.", + DeprecationWarning, ) if val: self._device_type = DeviceType.GPU diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 2aa6f86dc0..972446bd38 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -21,6 +21,7 @@ from pytorch_lightning.utilities.warnings import WarningCache class EvaluationLoop(object): + def __init__(self, trainer): self.trainer = trainer self.testing = False @@ -303,7 +304,8 @@ class EvaluationLoop(object): def on_evaluation_batch_start(self, batch, batch_idx, dataloader_idx): # set dataloader_idx to model and track batch_size self.trainer.logger_connector.on_evaluation_batch_start( - self.testing, batch, dataloader_idx, self.num_dataloaders) + self.testing, batch, dataloader_idx, self.num_dataloaders + ) if self.testing: self.trainer.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx) diff --git a/pytorch_lightning/trainer/ignored_warnings.py b/pytorch_lightning/trainer/ignored_warnings.py index 6aa185c527..894416d607 100644 --- a/pytorch_lightning/trainer/ignored_warnings.py +++ b/pytorch_lightning/trainer/ignored_warnings.py @@ -17,9 +17,11 @@ import warnings def ignore_scalar_return_in_dp(): # Users get confused by this warning so we silence it - warnings.filterwarnings('ignore', message='Was asked to gather along dimension 0, but all' - ' input tensors were scalars; will instead unsqueeze' - ' and return a vector.') + warnings.filterwarnings( + 'ignore', + message='Was asked to gather along dimension 0, but all input tensors were scalars;' + ' will instead 
unsqueeze and return a vector.' + ) ignore_scalar_return_in_dp() diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index f4ed647477..16060f8638 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -31,14 +31,14 @@ class TrainerLoggingMixin(ABC): current_epoch: int _device_type: DeviceType _distrib_type: DistributedType - log_gpu_memory: ... + log_gpu_memory:... logger: Union[LightningLoggerBase, bool] global_step: int global_rank: int default_root_dir: str slurm_job_id: int num_gpus: int - logged_metrics: ... + logged_metrics:... def metrics_to_scalars(self, metrics): new_metrics = {} @@ -66,13 +66,12 @@ class TrainerLoggingMixin(ABC): for k, v in output.items(): if k in ['log', 'progress_bar']: m = inspect.cleandoc( - f"""The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0 - Please use self.log(...) inside the lightningModule instead. - - # log on a step or aggregate epoch metric to the logger and/or progress bar - # (inside LightningModule) - self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True) - """) + f"The {{{k}:dict keyword}} was deprecated in 0.9.1 and will be removed in 1.0.0\n" + " Please use self.log(...) inside the lightningModule instead.\n" + " # log on a step or aggregate epoch metric to the logger and/or progress bar" + " (inside LightningModule)\n" + " self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)" + ) rank_zero_warn(m) # -------------------------- diff --git a/pytorch_lightning/trainer/model_hooks.py b/pytorch_lightning/trainer/model_hooks.py index 12225dd955..420911bb2b 100644 --- a/pytorch_lightning/trainer/model_hooks.py +++ b/pytorch_lightning/trainer/model_hooks.py @@ -19,6 +19,7 @@ from pytorch_lightning.core.lightning import LightningModule class TrainerModelHooksMixin(ABC): + def is_function_implemented(self, f_name, model=None): if model is None: model = self.get_model() diff --git a/pytorch_lightning/trainer/optimizers.py b/pytorch_lightning/trainer/optimizers.py index 2aaed17e98..a00c8b5fbf 100644 --- a/pytorch_lightning/trainer/optimizers.py +++ b/pytorch_lightning/trainer/optimizers.py @@ -26,6 +26,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException class TrainerOptimizersMixin(ABC): + def init_optimizers(self, model: LightningModule) -> Tuple[List, List, List]: optim_conf = model.configure_optimizers() if optim_conf is None: @@ -82,6 +83,7 @@ class TrainerOptimizersMixin(ABC): return optimizers, lr_schedulers, optimizer_frequencies def convert_to_lightning_optimizers(self): + def _convert_to_lightning_optimizer(trainer, optimizer): if not isinstance(optimizer, LightningOptimizer): optimizer = LightningOptimizer(optimizer) @@ -132,9 +134,11 @@ class TrainerOptimizersMixin(ABC): ' For example:' ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}' ) - lr_schedulers.append( - {**default_config, 'scheduler': scheduler, 'reduce_on_plateau': True, 'monitor': monitor} - ) + lr_schedulers.append({ + **default_config, 'scheduler': scheduler, + 'reduce_on_plateau': True, + 'monitor': monitor + }) elif isinstance(scheduler, optim.lr_scheduler._LRScheduler): lr_schedulers.append({**default_config, 'scheduler': scheduler}) else: diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index a3ef08df1e..1758cb41ee 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -59,6 +59,7 @@ def 
trainer_state(*, entering: Optional[TrainerState] = None, exiting: Optional[ """ def wrapper(fn) -> Callable: + @wraps(fn) def wrapped_fn(self, *args, **kwargs): if not isinstance(self, pytorch_lightning.Trainer): diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index dc16062e3a..aff458d1b6 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -101,10 +101,11 @@ class TensorRunningAccum(object): if self.rotated: return getattr(self.memory, how)() else: - return getattr(self.memory[: self.current_idx], how)() + return getattr(self.memory[:self.current_idx], how)() class Accumulator(object): + def __init__(self): self.num_values = 0 self.total = 0 @@ -119,6 +120,7 @@ class Accumulator(object): class PredictionCollection(object): + def __init__(self, global_rank: int, world_size: int): self.global_rank = global_rank self.world_size = world_size @@ -131,9 +133,7 @@ class PredictionCollection(object): elif name not in self.predictions[filename]: self.predictions[filename][name] = values elif isinstance(values, Tensor): - self.predictions[filename][name] = torch.cat( - (self.predictions[filename][name], values) - ) + self.predictions[filename][name] = torch.cat((self.predictions[filename][name], values)) elif isinstance(values, list): self.predictions[filename][name].extend(values) @@ -161,10 +161,7 @@ class PredictionCollection(object): fs.mkdirs(dirpath, exist_ok=True) # Convert any tensor values to list - predictions = { - k: v if not isinstance(v, Tensor) else v.tolist() - for k, v in predictions.items() - } + predictions = {k: v if not isinstance(v, Tensor) else v.tolist() for k, v in predictions.items()} # Check if all features for this file add up to same length feature_lens = {k: len(v) for k, v in predictions.items()} @@ -186,6 +183,7 @@ class CycleIterator(object): """ Iterator for restarting a dataloader if it runs out of samples """ + def __init__(self, loader: Any, length: Optional[int] = None): """ @@ -296,8 +294,9 @@ class CombinedDataset(object): raise MisconfigurationException(f"Invalid Mode: {mode}") # extract the lengths - all_lengths = apply_to_collection(datasets, (Dataset, Iterable, type(None)), get_len, - wrong_dtype=(Sequence, Mapping)) + all_lengths = apply_to_collection( + datasets, (Dataset, Iterable, type(None)), get_len, wrong_dtype=(Sequence, Mapping) + ) compute_func = CombinedDataset.COMPUTE_FUNCS[mode] @@ -351,8 +350,9 @@ class CombinedLoader(object): """ self.loaders = loaders - datasets = apply_to_collection(self.loaders, Iterable, getattr, 'dataset', None, - wrong_dtype=(Sequence, Mapping)) + datasets = apply_to_collection( + self.loaders, Iterable, getattr, 'dataset', None, wrong_dtype=(Sequence, Mapping) + ) # could be multiple datasets, but use self.dataset to follow the name convention in DataLoader self.dataset = CombinedDataset(datasets, mode) @@ -367,8 +367,7 @@ class CombinedLoader(object): @property def sampler(self) -> Union[Iterable, Sequence, Mapping]: """Return a collections of samplers extracting from loaders.""" - return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, - wrong_dtype=(Sequence, Mapping)) + return apply_to_collection(self.loaders, Iterable, getattr, 'sampler', None, wrong_dtype=(Sequence, Mapping)) def _wrap_loaders_max_size_cycle(self) -> Any: """ @@ -378,8 +377,7 @@ class CombinedLoader(object): the wrapped loaders """ - all_lengths = apply_to_collection(self.loaders, Iterable, get_len, - wrong_dtype=(Sequence, 
Mapping)) + all_lengths = apply_to_collection(self.loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) if isinstance(all_lengths, (int, float)): length = all_lengths @@ -391,13 +389,10 @@ class CombinedLoader(object): length = max(all_lengths) if isinstance(self.loaders, Mapping): - self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) - for k, v in self.loaders.items()}) + self.loaders = type(self.loaders)({k: CycleIterator(v, length=length) for k, v in self.loaders.items()}) elif isinstance(self.loaders, Sequence): - self.loaders = type(self.loaders)([ - CycleIterator(v, length=length) for v in self.loaders - ]) + self.loaders = type(self.loaders)([CycleIterator(v, length=length) for v in self.loaders]) # dataloaders are iterable but not sequence elif isinstance(self.loaders, Iterable): @@ -424,8 +419,7 @@ class CombinedLoader(object): length: the minimum length of loaders """ - all_lengths = apply_to_collection(loaders, Iterable, get_len, - wrong_dtype=(Sequence, Mapping)) + all_lengths = apply_to_collection(loaders, Iterable, get_len, wrong_dtype=(Sequence, Mapping)) if isinstance(all_lengths, (int, float)): return all_lengths @@ -441,6 +435,7 @@ class CombinedLoaderIterator(object): """ Custom Iterator returning data from multple loaders, and allows sampling in parallel """ + def __init__(self, loaders: Any): """ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index ba34c49581..182c89ee85 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Trainer to automate the training.""" import os @@ -66,7 +65,7 @@ from pytorch_lightning.utilities.model_helpers import is_overridden # warnings to ignore in trainer warnings.filterwarnings( - 'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead' + 'ignore', message='torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead' ) @@ -80,6 +79,7 @@ class Trainer( TrainerDataLoadingMixin, DeprecatedDistDeviceAttributes, ): + @overwrite_by_env_vars def __init__( self, @@ -896,9 +896,7 @@ class Trainer( # -------------------- # If you supply a datamodule you can't supply dataloaders if dataloaders and datamodule: - raise MisconfigurationException( - 'You cannot pass dataloaders to trainer.predict if you supply a datamodule' - ) + raise MisconfigurationException('You cannot pass dataloaders to trainer.predict if you supply a datamodule') if model is None: raise MisconfigurationException('You need to pass a model to `trainer.predict`. ') diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 85c5758ec2..f370e1d5b8 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -35,6 +35,7 @@ from pytorch_lightning.utilities.warnings import WarningCache class TrainLoop: + def __init__(self, trainer, multiple_trainloader_mode): self.trainer = trainer self.early_stopping_accumulator = None @@ -154,9 +155,10 @@ class TrainLoop: self.trainer.model_connector.copy_trainer_model_properties(ref_model) # init amp. 
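# the native-AMP state wired up here, in plain torch terms (sketch): the trainer's scaler
# is a torch GradScaler, and its state_dict() is what checkpoint_connector stores under
# 'native_amp_scaling_state'.
import torch

scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
checkpoint = {"native_amp_scaling_state": scaler.state_dict()}
scaler.load_state_dict(checkpoint["native_amp_scaling_state"])  # restore path when resuming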
Must be done here instead of __init__ to allow ddp to work - if (self.trainer.amp_backend == AMPType.NATIVE - and self.trainer.precision == 16 - and self.trainer._device_type != DeviceType.TPU): + if ( + self.trainer.amp_backend == AMPType.NATIVE and self.trainer.precision == 16 + and self.trainer._device_type != DeviceType.TPU + ): self.trainer.scaler = self.trainer.precision_connector.backend.scaler # log hyper-parameters @@ -498,7 +500,8 @@ class TrainLoop: if using_native_amp and is_lbfgs: raise MisconfigurationException( 'native PyTorch amp and lbfgs are not compatible.' - ' To request, please file a Github issue in PyTorch and tag @mcarilli') + ' To request, please file a Github issue in PyTorch and tag @mcarilli' + ) # wraps into LightingOptimizer only for running step optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer, opt_idx) @@ -641,10 +644,7 @@ class TrainLoop: # log epoch metrics self.trainer.logger_connector.log_train_epoch_end_metrics( - epoch_output, - self.checkpoint_accumulator, - self.early_stopping_accumulator, - self.num_optimizers + epoch_output, self.checkpoint_accumulator, self.early_stopping_accumulator, self.num_optimizers ) # when no val loop is present or fast-dev-run still need to call checkpoints @@ -699,11 +699,8 @@ class TrainLoop: # automatic_optimization=False: don't block synchronization here with self.block_ddp_sync_behaviour(): self.training_step_and_backward( - split_batch, - batch_idx, - opt_idx, - optimizer, - self.trainer.hiddens) + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens + ) batch_outputs = self._process_closure_result( batch_outputs=batch_outputs, @@ -720,11 +717,7 @@ class TrainLoop: def train_step_and_backward_closure(): result = self.training_step_and_backward( - split_batch, - batch_idx, - opt_idx, - optimizer, - self.trainer.hiddens + split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens ) return None if result is None else result.loss @@ -733,10 +726,7 @@ class TrainLoop: else: self._curr_step_result = self.training_step( - split_batch, - batch_idx, - opt_idx, - self.trainer.hiddens + split_batch, batch_idx, opt_idx, self.trainer.hiddens ) if self._curr_step_result is None: @@ -783,9 +773,7 @@ class TrainLoop: else: yield None - def _process_closure_result( - self, batch_outputs: list, opt_idx: int - ) -> list: + def _process_closure_result(self, batch_outputs: list, opt_idx: int) -> list: opt_closure_result = self._curr_step_result if opt_closure_result is not None: diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py index be9793e9e5..7665f96426 100644 --- a/pytorch_lightning/trainer/training_tricks.py +++ b/pytorch_lightning/trainer/training_tricks.py @@ -29,7 +29,7 @@ class TrainerTrainingTricksMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class default_root_dir: str - progress_bar_callback: ... + progress_bar_callback:... on_gpu: bool @abstractmethod @@ -47,9 +47,7 @@ class TrainerTrainingTricksMixin(ABC): # check if loss is nan if not torch.isfinite(loss).all(): - raise ValueError( - 'The loss returned in `training_step` is nan or inf.' - ) + raise ValueError('The loss returned in `training_step` is nan or inf.') # check if a network weight is nan for name, param in model.named_parameters(): if not torch.isfinite(param).all():
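# The patch as a whole is mechanical yapf formatting (trainer/* is no longer listed in
# .yapfignore); a sketch of reproducing the formatting programmatically, assuming yapf is
# installed and the repo's style config is picked up from the working directory:
from yapf.yapflib.yapf_api import FormatCode

source = "def add(a,b):\n    return a+b\n"
formatted, changed = FormatCode(source)
print(formatted, changed)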