diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index dea7396125..7c3095fc28 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -82,9 +82,9 @@ jobs: uses: actions/cache@v1 with: path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} + key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} restore-keys: | - ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip- + ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip- - name: Install dependencies run: | diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 1413d8d62c..09783e6d18 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -55,6 +55,7 @@ else: from pytorch_lightning.trainer import Trainer from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning import metrics + from pytorch_lightning.core.step_result import TrainResult, EvalResult __all__ = [ 'Trainer', @@ -62,7 +63,9 @@ else: 'Callback', 'data_loader', 'seed_everything', - 'metrics' + 'metrics', + 'EvalResult', + 'TrainResult' ] # necessary for regular bolts imports. Skip exception since bolts is not always installed diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index dac8ddc11c..7c1d055477 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -46,6 +46,30 @@ class Callback(abc.ABC): """Called when the validation sanity check ends.""" pass + def on_train_epoch_start(self, trainer, pl_module): + """Called when the train epoch begins.""" + pass + + def on_train_epoch_end(self, trainer, pl_module): + """Called when the train epoch ends.""" + pass + + def on_validation_epoch_start(self, trainer, pl_module): + """Called when the val epoch begins.""" + pass + + def on_validation_epoch_end(self, trainer, pl_module): + """Called when the val epoch ends.""" + pass + + def on_test_epoch_start(self, trainer, pl_module): + """Called when the test epoch begins.""" + pass + + def on_test_epoch_end(self, trainer, pl_module): + """Called when the test epoch ends.""" + pass + def on_epoch_start(self, trainer, pl_module): """Called when the epoch begins.""" pass diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 544854fa4e..4e22cba977 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -7,6 +7,7 @@ Monitor a validation metric and stop training when it stops improving. """ from copy import deepcopy +import os import numpy as np import torch import torch.distributed as dist @@ -140,12 +141,33 @@ class EarlyStopping(Callback): def on_validation_end(self, trainer, pl_module): self._run_early_stopping_check(trainer, pl_module) + def on_train_epoch_end(self, trainer, pl_module): + # early stopping can also work in the train loop when there is no val loop and when using structured results + should_check_early_stop = False + train_es_key = 'early_stop_on' + if trainer.callback_metrics.get(train_es_key, None) is not None: + self.monitor = train_es_key + should_check_early_stop = True + + val_es_key = 'val_early_stop_on' + if trainer.callback_metrics.get(val_es_key, None) is not None: + self.monitor = val_es_key + should_check_early_stop = True + + if should_check_early_stop: + self._run_early_stopping_check(trainer, pl_module) + def _run_early_stopping_check(self, trainer, pl_module): logs = trainer.callback_metrics + if not self._validate_condition_metric(logs): return # short circuit if metric not present current = logs.get(self.monitor) + + # when in dev debugging + trainer.dev_debugger.track_early_stopping_history(current) + if not isinstance(current, torch.Tensor): current = torch.tensor(current, device=pl_module.device) diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index f70d8d8d0a..370a30b75d 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -159,7 +159,11 @@ class ModelCheckpoint(Callback): if os.path.isfile(filepath): os.remove(filepath) - def _save_model(self, filepath): + def _save_model(self, filepath, trainer, pl_module): + + # in debugging, track when we save checkpoints + trainer.dev_debugger.track_checkpointing_history(filepath) + # make paths os.makedirs(os.path.dirname(filepath), exist_ok=True) @@ -270,6 +274,11 @@ class ModelCheckpoint(Callback): metrics = trainer.callback_metrics epoch = trainer.current_epoch + + # support structured results + if metrics.get('checkpoint_on') is not None: + self.monitor = 'checkpoint_on' + if self.save_top_k == 0: # no models are saved return @@ -281,7 +290,7 @@ class ModelCheckpoint(Callback): if self.save_last: filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt') - self._save_model(filepath) + self._save_model(filepath, trainer, pl_module) filepath = self.format_checkpoint_name(epoch, metrics) version_cnt = 0 @@ -306,7 +315,7 @@ class ModelCheckpoint(Callback): f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning ) elif self.check_monitor_top_k(current): - self._do_check_save(filepath, current, epoch) + self._do_check_save(filepath, current, epoch, trainer, pl_module) elif self.verbose > 0: log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}') @@ -315,9 +324,9 @@ class ModelCheckpoint(Callback): log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' - self._save_model(filepath) + self._save_model(filepath, trainer, pl_module) - def _do_check_save(self, filepath, current, epoch): + def _do_check_save(self, filepath, current, epoch, trainer, pl_module): # remove kth del_list = [] @@ -343,7 +352,7 @@ class ModelCheckpoint(Callback): f'\nEpoch {epoch:05d}: {self.monitor} reached' f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to' f' {filepath} as top {self.save_top_k}') - self._save_model(filepath) + self._save_model(filepath, trainer, pl_module) for cur_path in del_list: if cur_path != filepath: diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py index d8c2181251..aa4e274298 100644 --- a/pytorch_lightning/core/hooks.py +++ b/pytorch_lightning/core/hooks.py @@ -115,6 +115,42 @@ class ModelHooks(Module): """ # do something when the epoch ends + def on_train_epoch_start(self) -> None: + """ + Called in the training loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_train_epoch_end(self) -> None: + """ + Called in the training loop at the very end of the epoch. + """ + # do something when the epoch ends + + def on_validation_epoch_start(self) -> None: + """ + Called in the validation loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_validation_epoch_end(self) -> None: + """ + Called in the validation loop at the very end of the epoch. + """ + # do something when the epoch ends + + def on_test_epoch_start(self) -> None: + """ + Called in the test loop at the very beginning of the epoch. + """ + # do something when the epoch starts + + def on_test_epoch_end(self) -> None: + """ + Called in the test loop at the very end of the epoch. + """ + # do something when the epoch ends + def on_pre_performance_check(self) -> None: """ Called at the very beginning of the validation loop. diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py new file mode 100644 index 0000000000..1dc88db15c --- /dev/null +++ b/pytorch_lightning/core/step_result.py @@ -0,0 +1,336 @@ +from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any +from torch import Tensor +import torch +from copy import copy + + +class Result(Dict): + + def __init__( + self, + minimize: Optional[Tensor] = None, + early_stop_on: Optional[Tensor] = None, + checkpoint_on: Union[Tensor, bool, None] = None, + hiddens: Optional[Tensor] = None, + ): + + super().__init__() + + if early_stop_on is not None: + self.early_stop_on = early_stop_on + if checkpoint_on is not None and checkpoint_on: + self.checkpoint_on = checkpoint_on + if hiddens is not None: + self.hiddens = hiddens + if minimize is not None: + err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end' + self._assert_grad_tensor_metric('minimize', minimize, err) + self.minimize = minimize + + if minimize is not None and checkpoint_on is None: + self.checkpoint_on = minimize.detach() + + self['meta'] = { + '_internal': { + '_reduce_on_epoch': False + } + } + + def __getattr__(self, key: str) -> Any: + try: + if key == 'callback_metrics': + return self.get_callback_metrics() + elif key == 'batch_log_metrics': + return self.get_batch_log_metrics() + elif key == 'batch_pbar_metrics': + return self.get_batch_pbar_metrics() + elif key == 'epoch_log_metrics': + return self.get_epoch_log_metrics() + elif key == 'epoch_pbar_metrics': + return self.get_epoch_pbar_metrics() + else: + return self[key] + except KeyError: + return None + + def __setattr__(self, key: str, val: Union[Tensor, Any]): + # ensure reserve keys are tensors and detached + if key in {'hiddens', 'checkpoint_on', 'early_stop_on'}: + self._assert_tensor_metric(key, val) + if val is not None and isinstance(val, torch.Tensor): + val = val.detach() + + # ensure anything else that is a tensor is detached + elif isinstance(val, torch.Tensor) and key != 'minimize': + val = val.detach() + + self[key] = val + + def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]): + if potential_metric is not None and not isinstance(potential_metric, bool): + assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor' + + def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''): + if x is not None: + assert isinstance(x, Tensor), f'{name} must be a torch.Tensor' + m = f'{name} must have a computational graph.' + + if additional_err: + m += f' {additional_err}' + assert x.grad_fn is not None, m + + def log( + self, + name: str, + value: Any, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + enable_graph: bool = False, + ): + # no metrics should be logged with graphs + if not enable_graph and isinstance(value, torch.Tensor): + value = value.detach() + + if 'meta' not in self: + self.__setitem__('meta', {}) + + self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx) + + # set the value + self.__setitem__(name, value) + + def __set_meta( + self, + name: str, + value: Any, + prog_bar: bool, + logger: bool, + on_step: bool, + on_epoch: bool, + reduce_fx: Callable, + ): + # set the meta for the item + meta_value = value + meta = dict( + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + value=meta_value + ) + self['meta'][name] = meta + + # track whether any input requires reduction on epoch end + _internal = self['meta']['_internal'] + _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch) + + def get_callback_metrics(self) -> dict: + result = { + 'early_stop_on': self.early_stop_on, + 'checkpoint_on': self.checkpoint_on + } + + return result + + def get_batch_log_metrics(self) -> dict: + """ + Gets the metrics to log at the end of the batch step + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + if options['logger'] and options['on_step']: + result[k] = self[k] + return result + + def get_epoch_log_metrics(self) -> dict: + """ + Gets the metrics to log at the end of the batch step + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + if options['logger'] and options['on_epoch']: + result[k] = self[k] + return result + + def get_epoch_pbar_metrics(self): + """ + Gets the metrics to log at the end of the batch step + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + if options['prog_bar'] and options['on_epoch']: + result[k] = self[k] + return result + + def get_batch_pbar_metrics(self): + """ + Gets the metrics to log at the end of the batch step + """ + result = {} + + meta = self['meta'] + for k, options in meta.items(): + if k == '_internal': + continue + if options['prog_bar'] and options['on_step']: + result[k] = self[k] + return result + + def detach(self): + for k, v in self.items(): + if isinstance(v, torch.Tensor): + self.__setitem__(k, v.detach()) + + def __repr__(self): + self_copy = self.copy() + + if 'meta' in self_copy: + del self_copy['meta'] + + return str(self_copy) + + def __str__(self): + copy = self.copy() + del copy['meta'] + + return str(copy) + + def __copy__(self): + newone = type(self)() + for k, v in self.items(): + newone[k] = copy(v) + return newone + + @classmethod + def gather(cls, outputs): + meta = outputs[0]['meta'] + result = cls() + result = recursive_gather(outputs, result) + recursive_stack(result) + result['meta'] = meta + return result + + @classmethod + def reduce_on_epoch_end(cls, outputs): + meta = outputs[0]['meta'] + result = cls() + result = recursive_gather(outputs, result) + recursive_stack(result) + + for k, option in meta.items(): + if k == '_internal': + continue + + if option['on_epoch']: + fx = option['reduce_fx'] + result[k] = fx(result[k]) + + result['meta'] = meta + return result + + @property + def should_reduce_on_epoch_end(self) -> bool: + return self['meta']['_internal']['_reduce_on_epoch'] + + +def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]: + for out in outputs: + if 'meta' in out: + del out['meta'] + + for k, v in out.items(): + if isinstance(v, dict): + v = recursive_gather([v], result) + + if k not in result: + result[k] = [] + + result[k].append(v) + + return result + + +def recursive_stack(result: MutableMapping): + for k, v in result.items(): + if isinstance(v, dict): + recursive_stack(v) + + if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor): + v = torch.stack(v) + result[k] = v + + +class TrainResult(Result): + + def __init__( + self, + minimize: Optional[Tensor] = None, + early_stop_on: Tensor = None, + checkpoint_on: Union[Tensor, bool] = None, + hiddens: Optional[Tensor] = None, + ): + + super().__init__(minimize, early_stop_on, checkpoint_on, hiddens) + + def log( + self, + name, + value, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = True, + on_epoch: bool = False, + reduce_fx: Callable = torch.mean, + enable_graph: bool = False, + ): + super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph) + + +class EvalResult(Result): + + def __init__( + self, + early_stop_on: Optional[Tensor] = None, + checkpoint_on: Optional[Tensor] = None, + hiddens: Optional[Tensor] = None, + ): + + super().__init__(None, early_stop_on, checkpoint_on, hiddens) + + def log( + self, + name, + value, + prog_bar: bool = False, + logger: bool = True, + on_step: bool = False, + on_epoch: bool = True, + reduce_fx: Callable = torch.mean, + enable_graph: bool = False, + ): + super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph) + + +# if __name__ == '__main__': +# import torch +# result = TrainResult() +# result.hiddens = torch.tensor(1) +# result.log('some', 123) +# print(result) +# result.minimize = torch.tensor(1) diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index f2a23b188e..c9c793cc89 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -6,6 +6,7 @@ import torch from torch.cuda._utils import _get_device_index from torch.nn import DataParallel from torch.nn.parallel import DistributedDataParallel +from pytorch_lightning.core.step_result import Result def _find_tensors(obj): # pragma: no-cover @@ -63,7 +64,34 @@ class LightningDataParallel(DataParallel): replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, kwargs) - return self.gather(outputs, self.output_device) + + if isinstance(outputs[0], Result): + outputs = self.__gather_structured_result(outputs) + else: + outputs = self.gather(outputs, self.output_device) + return outputs + + def __gather_structured_result(self, outputs): + prototype_output = outputs[0] + original_class = prototype_output.__class__ + outputs = [dict(x) for x in outputs] + + # remove all the meta info + meta = outputs[0]['meta'] + for i, output in enumerate(outputs): + del output['meta'] + + outputs = self.gather(outputs, self.output_device) + + # pass minimize to constructor for TrainResult + if 'minimize' in outputs: + result = original_class(outputs['minimize']) + else: + result = original_class() + + result.update(outputs) + result['meta'] = meta + return result def parallel_apply(self, replicas, inputs, kwargs): return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) @@ -160,6 +188,8 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n if not isinstance(input, (list, tuple)): input = (input,) + module = module.to(device) + # --------------- # CHANGE if module.training: diff --git a/pytorch_lightning/trainer/callback_hook.py b/pytorch_lightning/trainer/callback_hook.py index 50ea8bb7ce..89b5e712c9 100644 --- a/pytorch_lightning/trainer/callback_hook.py +++ b/pytorch_lightning/trainer/callback_hook.py @@ -51,6 +51,36 @@ class TrainerCallbackHookMixin(ABC): for callback in self.callbacks: callback.on_sanity_check_end(self, self.get_model()) + def on_train_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_train_epoch_start(self, self.get_model()) + + def on_train_epoch_end(self): + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_train_epoch_end(self, self.get_model()) + + def on_validation_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_validation_epoch_start(self, self.get_model()) + + def on_validation_epoch_end(self): + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_validation_epoch_end(self, self.get_model()) + + def on_test_epoch_start(self): + """Called when the epoch begins.""" + for callback in self.callbacks: + callback.on_test_epoch_start(self, self.get_model()) + + def on_test_epoch_end(self): + """Called when the epoch ends.""" + for callback in self.callbacks: + callback.on_test_epoch_end(self, self.get_model()) + def on_epoch_start(self): """Called when the epoch begins.""" for callback in self.callbacks: diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py index 35f5d5d35b..3baed4ef9d 100644 --- a/pytorch_lightning/trainer/logging.py +++ b/pytorch_lightning/trainer/logging.py @@ -1,3 +1,4 @@ +import os from abc import ABC from typing import Union, Iterable @@ -73,6 +74,8 @@ class TrainerLoggingMixin(ABC): self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.save() + self.dev_debugger.track_logged_metrics_history(scalar_metrics) + def add_progress_bar_metrics(self, metrics): for k, v in metrics.items(): if isinstance(v, torch.Tensor): @@ -80,6 +83,8 @@ class TrainerLoggingMixin(ABC): self.progress_bar_metrics[k] = v + self.dev_debugger.track_pbar_metrics_history(metrics) + def metrics_to_scalars(self, metrics): new_metrics = {} for k, v in metrics.items(): diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index fcd21becfe..8853d7aaa0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -76,3 +76,17 @@ class TensorRunningAccum(object): return getattr(self.memory, how)() else: return getattr(self.memory[:self.current_idx], how)() + + +class Accumulator(object): + def __init__(self): + self.num_values = 0 + self.total = 0 + + def accumulate(self, x): + with torch.no_grad(): + self.total += x + self.num_values += 1 + + def mean(self): + return self.total / self.num_values diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 0ad5157662..2af10878b7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -33,6 +33,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only +from pytorch_lightning.utilities.debugging import InternalDebugger import warnings # warnings to ignore in trainer @@ -616,6 +617,9 @@ class Trainer( self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') + # tracks internal state for debugging + self.dev_debugger = InternalDebugger(self) + # Callback system self.on_init_end() diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index fa493f2e1b..0caf9f22b5 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -143,7 +143,7 @@ in your model. trainer = Trainer(terminate_on_nan=True) """ - +import os import subprocess from abc import ABC, abstractmethod from typing import Callable @@ -153,17 +153,19 @@ import numpy as np import torch from torch.utils.data import DataLoader import torch.distributed as torch_distrib +from copy import copy from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.trainer.supporters import TensorRunningAccum +from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.memory import recursive_detach +from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result try: from apex import amp @@ -251,6 +253,8 @@ class TrainerTrainLoopMixin(ABC): on_epoch_end: Callable on_validation_end: Callable on_keyboard_interrupt: Callable + on_train_epoch_start: Callable + on_train_epoch_end: Callable @abstractmethod def get_model(self) -> LightningModule: @@ -420,6 +424,15 @@ class TrainerTrainLoopMixin(ABC): if self.is_function_implemented('on_epoch_start'): model.on_epoch_start() + # Epoch start events + with self.profiler.profile('on_train_epoch_start'): + # callbacks + self.on_train_epoch_start() + + # model hooks + if self.is_function_implemented('on_train_epoch_start'): + model.on_train_epoch_start() + def run_training_epoch(self): # get model @@ -435,6 +448,10 @@ class TrainerTrainLoopMixin(ABC): epoch_output = [] should_check_val = False + # structured result accumulators for callbacks + early_stopping_accumulator = Accumulator() + checkpoint_accumulator = Accumulator() + # run epoch for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( enumerate(_with_is_last(train_dataloader)), "get_train_batch" @@ -453,7 +470,15 @@ class TrainerTrainLoopMixin(ABC): # only track outputs when user implements training_epoch_end # otherwise we will build up unnecessary memory - if self.is_overridden('training_epoch_end', model=self.get_model()): + step_out = batch_output.training_step_output_for_epoch_end + should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end + if isinstance(step_out, dict) and 'early_stop_on' in step_out: + early_stopping_accumulator.accumulate(step_out['early_stop_on']) + + if isinstance(step_out, dict) and 'checkpoint_on' in step_out: + checkpoint_accumulator.accumulate(step_out['checkpoint_on']) + + if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result: epoch_output.append(batch_output.training_step_output_for_epoch_end) # update LR schedulers @@ -496,7 +521,7 @@ class TrainerTrainLoopMixin(ABC): self.sync_horovod() # process epoch outputs - self.run_training_epoch_end(epoch_output) + self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator) # checkpoint callback self.check_checkpoint_callback(should_check_val) @@ -525,23 +550,74 @@ class TrainerTrainLoopMixin(ABC): if self.is_function_implemented('on_epoch_end'): model.on_epoch_end() - def run_training_epoch_end(self, epoch_output): + with self.profiler.profile('on_train_epoch_end'): + # callbacks + self.on_train_epoch_end() + + # model hooks + if self.is_function_implemented('on_train_epoch_end'): + model.on_train_epoch_end() + + def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator): model = self.get_model() + is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result) + + epoch_log_metrics = {} + epoch_callback_metrics = {} + epoch_progress_bar_metrics = {} + + # ----------------------- + # Calculate epoch callback values if given + # ----------------------- + if checkpoint_accumulator.num_values > 0: + epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean() + + if early_stopping_accumulator.num_values > 0: + epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean() + + # -------------------------- + # EPOCH END STEP IF DEFINED + # -------------------------- if self.is_overridden('training_epoch_end', model=model): self.global_step += 1 + + # remove the protected keys so the user doesn't have to deal with them + if is_result_obj: + epoch_output = epoch_output[0].__class__.gather(epoch_output) + + # run training_epoch_end epoch_output = model.training_epoch_end(epoch_output) - _processed_outputs = self.process_output(epoch_output) - log_epoch_metrics = _processed_outputs[2] - callback_epoch_metrics = _processed_outputs[3] - # add the metrics to the loggers - self.log_metrics(log_epoch_metrics, {}) + if isinstance(epoch_output, Result): + epoch_log_metrics = epoch_output.epoch_log_metrics + epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics + else: + _processed_outputs = self.process_output(epoch_output) + epoch_progress_bar_metrics = _processed_outputs[1] + epoch_log_metrics = _processed_outputs[2] + epoch_callback_metrics = _processed_outputs[3] - # add metrics to callbacks - self.callback_metrics.update(callback_epoch_metrics) + # -------------------------- + # Structured Result (auto epoch end) + # -------------------------- + elif is_result_obj: + epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output) + epoch_output.minimize = epoch_output.minimize.mean() + epoch_log_metrics = epoch_output.epoch_log_metrics + epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics - # add metrics to progress_bar - self.add_progress_bar_metrics(_processed_outputs[1]) + # -------------------------- + # track results + # -------------------------- + # add the metrics to the loggers + if epoch_log_metrics and len(epoch_log_metrics) > 0: + self.log_metrics(epoch_log_metrics, {}) + + # add metrics to callbacks + self.callback_metrics.update(epoch_callback_metrics) + + # add metrics to progress_bar + self.add_progress_bar_metrics(epoch_progress_bar_metrics) def sync_horovod(self): if self.use_horovod: @@ -558,7 +634,10 @@ class TrainerTrainLoopMixin(ABC): should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop if should_log_metrics or self.fast_dev_run: # logs user requested information to logger - self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic) + metrics = batch_output.batch_log_metrics + grad_norm_dic = batch_output.grad_norm_dic + if len(metrics) > 0 or len(grad_norm_dic) > 0: + self.log_metrics(metrics, grad_norm_dic) def save_loggers_in_training_loop(self, batch_idx): # when loggers should save to disk @@ -588,6 +667,8 @@ class TrainerTrainLoopMixin(ABC): # track metrics to log batch_log_metrics = [] + using_results_obj = False + if batch is None: return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) @@ -622,7 +703,7 @@ class TrainerTrainLoopMixin(ABC): param.requires_grad = True # ------------------- - # calculate loss + # calculate loss (train step + train step end) # ------------------- opt_closure_result = self.optimizer_closure( split_batch, @@ -631,14 +712,26 @@ class TrainerTrainLoopMixin(ABC): optimizer, self.hiddens ) + using_results_obj = isinstance(opt_closure_result.training_step_output, Result) # ------------------------------ # POST forward bookkeeping # ------------------------------ batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics) - batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics) - self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) + # add metrics to loggers + if using_results_obj: + metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics + else: + metrics_to_log = opt_closure_result.training_step_output.log_metrics + batch_log_metrics.append(metrics_to_log) + + # add metrics to progress bar + if using_results_obj: + metrics_for_pbar = opt_closure_result.training_step_output.batch_pbar_metrics + else: + metrics_for_pbar = opt_closure_result.training_step_output.pbar_on_batch_end + self.add_progress_bar_metrics(metrics_for_pbar) # track hiddens self.hiddens = opt_closure_result.hiddens @@ -677,7 +770,8 @@ class TrainerTrainLoopMixin(ABC): batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} # track all metrics for callbacks - self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) + if not using_results_obj: + self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) result = AttributeDict( signal=0, @@ -764,7 +858,7 @@ class TrainerTrainLoopMixin(ABC): wrap the forward step in a closure so second order methods work """ # --------------------------- - # FORWARD + # FORWARD (TRAINING STEP + TRAIN STEP END) # --------------------------- with self.profiler.profile('model_forward'): if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: @@ -780,26 +874,38 @@ class TrainerTrainLoopMixin(ABC): # ---------------------------- # format and reduce outputs accordingly training_step_output_for_epoch_end = training_step_output - training_step_output = self.process_output(training_step_output, train=True) + is_result_obj = isinstance(training_step_output, Result) - # TODO: temporary part of structured results PR - training_step_output = AttributeDict( - batch_loss=training_step_output[0], - pbar_on_batch_end=training_step_output[1], - log_metrics=training_step_output[2], - callback_metrics=training_step_output[3], - hiddens=training_step_output[4], - ) + # don't allow EvalResult in the training_step + if isinstance(training_step_output, EvalResult): + raise MisconfigurationException('training_step cannot return EvalResult, ' + 'use a dict or TrainResult instead') + + # handle regular dicts + if not is_result_obj: + training_step_output = self.process_output(training_step_output, train=True) + + training_step_output = AttributeDict( + batch_loss=training_step_output[0], + pbar_on_batch_end=training_step_output[1], + log_metrics=training_step_output[2], + callback_metrics=training_step_output[3], + hiddens=training_step_output[4], + ) # if the user decides to finally reduce things in epoch_end, save raw output without graphs if isinstance(training_step_output_for_epoch_end, torch.Tensor): training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() + elif is_result_obj: + training_step_output_for_epoch_end = copy(training_step_output) + training_step_output_for_epoch_end.detach() else: training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end) # accumulate loss # (if accumulate_grad_batches = 1 no effect) - closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches + closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss + closure_loss = closure_loss / self.accumulate_grad_batches # the loss will get scaled for amp. avoid any modifications to it untouched_loss = closure_loss.detach().clone() @@ -829,7 +935,11 @@ class TrainerTrainLoopMixin(ABC): # once backward has been applied, release graph closure_loss = closure_loss.detach() - training_step_output.batch_loss = training_step_output.batch_loss.detach() + + if is_result_obj: + training_step_output.detach() + else: + training_step_output.batch_loss = training_step_output.batch_loss.detach() if self.use_horovod: # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid @@ -841,6 +951,9 @@ class TrainerTrainLoopMixin(ABC): with self.profiler.profile('on_after_backward'): model_ref.on_after_backward() + # when in dev debugging track the losses + self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach()) + result = AttributeDict( loss=untouched_loss, training_step_output=training_step_output, @@ -963,6 +1076,7 @@ class TrainerTrainLoopMixin(ABC): if self.is_overridden('training_step_end'): model_ref = self.get_model() with self.profiler.profile('training_step_end'): + # TODO: modify when using result obj output = model_ref.training_step_end(output) # allow any mode to define training_end diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/debugging.py new file mode 100644 index 0000000000..490356938f --- /dev/null +++ b/pytorch_lightning/utilities/debugging.py @@ -0,0 +1,54 @@ +import os + + +class InternalDebugger(object): + + def __init__(self, trainer): + + self.enabled = 'PL_DEV_DEBUG' in os.environ + self.trainer = trainer + self.logged_metrics = [] + self.pbar_added_metrics = [] + self.saved_losses = [] + self.early_stopping_history = [] + self.checkpoint_callback_history = [] + + def track_logged_metrics_history(self, scalar_metrics): + if self.enabled: + scalar_metrics['global_step'] = self.trainer.global_step + self.logged_metrics.append(scalar_metrics) + + def track_train_loss_history(self, batch_idx, loss): + if self.enabled: + loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()} + self.saved_losses.append(loss_dict) + + def track_pbar_metrics_history(self, metrics): + if self.enabled: + metrics['debug_epoch'] = self.trainer.current_epoch + self.pbar_added_metrics.append(metrics) + + def track_early_stopping_history(self, current): + if self.enabled: + es = self.trainer.early_stop_callback + debug_dict = { + 'epoch': self.trainer.current_epoch, + 'global_step': self.trainer.global_step, + 'rank': self.trainer.global_rank, + 'current': current, + 'best': es.best_score, + 'patience': es.wait_count + } + self.early_stopping_history.append(debug_dict) + + def track_checkpointing_history(self, filepath): + if self.enabled: + cb = self.trainer.checkpoint_callback + debug_dict = { + 'epoch': self.trainer.current_epoch, + 'global_step': self.trainer.global_step, + 'monitor': cb.monitor, + 'rank': self.trainer.global_rank, + 'filepath': filepath + } + self.checkpoint_callback_history.append(debug_dict) diff --git a/pytorch_lightning/utilities/parsing.py b/pytorch_lightning/utilities/parsing.py index 7acaea4fd2..920e14bbef 100644 --- a/pytorch_lightning/utilities/parsing.py +++ b/pytorch_lightning/utilities/parsing.py @@ -1,5 +1,6 @@ import inspect from argparse import Namespace +from typing import Dict def str_to_bool(val): @@ -93,7 +94,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list: return path_args -class AttributeDict(dict): +class AttributeDict(Dict): """Extended dictionary accesisable with dot notation. >>> ad = AttributeDict({'key1': 1, 'key2': 'abc'}) diff --git a/tests/base/deterministic_model.py b/tests/base/deterministic_model.py index a4988673c6..2b892dc78e 100644 --- a/tests/base/deterministic_model.py +++ b/tests/base/deterministic_model.py @@ -2,6 +2,7 @@ import numpy as np import torch from torch import nn from torch.utils.data import Dataset, DataLoader +from pytorch_lightning import TrainResult from pytorch_lightning.core.lightning import LightningModule @@ -19,6 +20,8 @@ class DeterministicModel(LightningModule): self.validation_step_end_called = False self.validation_epoch_end_called = False + self.assert_backward = True + self.l1 = nn.Linear(2, 3, bias=False) if weights is None: weights = torch.tensor([ @@ -33,13 +36,15 @@ class DeterministicModel(LightningModule): def step(self, batch, batch_idx): x = batch - y_hat = self(x) + bs = x.size(0) + y_hat = self.l1(x) + print(x.device, self.device, self.l1.weight.device) test_hat = y_hat.cpu().detach() assert torch.all(test_hat[:, 0] == 15.0) assert torch.all(test_hat[:, 1] == 42.0) out = y_hat.sum() - assert out == (42.0 * 3) + (15.0 * 3) + assert out == (42.0 * bs) + (15.0 * bs) return out @@ -97,6 +102,105 @@ class DeterministicModel(LightningModule): prototype_loss = outputs[0] return prototype_loss + def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx): + """ + Early stop and checkpoint only on these values + """ + acc = self.step(batch, batch_idx) + result = TrainResult(minimize=acc) + assert 'early_step_on' not in result + assert 'checkpoint_on' in result + return result + + def training_step_no_callbacks_result_obj(self, batch, batch_idx): + """ + Early stop and checkpoint only on these values + """ + acc = self.step(batch, batch_idx) + result = TrainResult(minimize=acc, checkpoint_on=False) + assert 'early_step_on' not in result + assert 'checkpoint_on' not in result + return result + + def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx): + """ + Early stop and checkpoint only on these values + """ + acc = self.step(batch, batch_idx) + + self.assert_backward = False + losses = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20] + idx = self.current_epoch + loss = acc + losses[idx] + result = TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss) + return result + + def training_step_result_log_step_only(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + result = TrainResult(minimize=acc) + + # step only metrics + result.log(f'step_log_and_pbar_acc1_b{batch_idx}', torch.tensor(11).type_as(acc), prog_bar=True) + result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc)) + result.log(f'step_pbar_acc3_b{batch_idx}', torch.tensor(13).type_as(acc), logger=False, prog_bar=True) + + self.training_step_called = True + return result + + def training_step_result_log_epoch_only(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + result = TrainResult(minimize=acc) + + result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}', torch.tensor(14).type_as(acc), + on_epoch=True, prog_bar=True, on_step=False) + result.log(f'epoch_log_acc2_e{self.current_epoch}', torch.tensor(15).type_as(acc), + on_epoch=True, on_step=False) + result.log(f'epoch_pbar_acc3_e{self.current_epoch}', torch.tensor(16).type_as(acc), + on_epoch=True, logger=False, prog_bar=True, on_step=False) + + self.training_step_called = True + return result + + def training_step_result_log_epoch_and_step(self, batch, batch_idx): + acc = self.step(batch, batch_idx) + result = TrainResult(minimize=acc) + + val_1 = (5 + batch_idx) * (self.current_epoch + 1) + val_2 = (6 + batch_idx) * (self.current_epoch + 1) + val_3 = (7 + batch_idx) * (self.current_epoch + 1) + result.log(f'step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(acc), + on_epoch=True, prog_bar=True) + result.log(f'step_epoch_log_acc2', torch.tensor(val_2).type_as(acc), + on_epoch=True) + result.log(f'step_epoch_pbar_acc3', torch.tensor(val_3).type_as(acc), + on_epoch=True, logger=False, prog_bar=True) + + self.training_step_called = True + return result + + def training_epoch_end_return_for_log_epoch_and_step(self, result): + """ + There should be an array of scalars without graphs that are all 171 (4 of them) + """ + self.training_epoch_end_called = True + + if self.use_dp or self.use_ddp2: + pass + else: + # only saw 4 batches + assert isinstance(result, TrainResult) + + result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod() + result.step_epoch_log_acc2 = result.step_epoch_log_acc2.prod() + result.step_epoch_pbar_acc3 = result.step_epoch_pbar_acc3.prod() + result.log('epoch_end_log_acc', torch.tensor(1212).type_as(result.step_epoch_log_acc2), + logger=True, on_epoch=True) + result.log('epoch_end_pbar_acc', torch.tensor(1213).type_as(result.step_epoch_log_acc2), + logger=False, prog_bar=True, on_epoch=True) + result.log('epoch_end_log_pbar_acc', torch.tensor(1214).type_as(result.step_epoch_log_acc2), + logger=True, prog_bar=True, on_epoch=True) + return result + # -------------------------- # dictionary returns # -------------------------- @@ -231,10 +335,12 @@ class DeterministicModel(LightningModule): return torch.optim.Adam(self.parameters(), lr=0) def backward(self, trainer, loss, optimizer, optimizer_idx): - if self.trainer.precision == 16: - assert loss > 171 * 1000 - else: - assert loss == 171.0 + if self.assert_backward: + if self.trainer.precision == 16: + assert loss > 171 * 1000 + else: + assert loss == 171.0 + loss.backward() diff --git a/tests/base/model_template.py b/tests/base/model_template.py index 48851cdb08..a89769e6f4 100644 --- a/tests/base/model_template.py +++ b/tests/base/model_template.py @@ -63,6 +63,9 @@ class EvalModelTemplate( self.hidden_dim = hidden_dim self.b1 = b1 self.b2 = b2 + self.training_step_called = False + self.training_step_end_called = False + self.training_epoch_end_called = False # if you specify an example input, the summary will show input/output for each layer # TODO: to be fixed in #1773 diff --git a/tests/base/model_train_steps.py b/tests/base/model_train_steps.py index fcd020d852..6022b86478 100644 --- a/tests/base/model_train_steps.py +++ b/tests/base/model_train_steps.py @@ -1,6 +1,7 @@ import math from abc import ABC from collections import OrderedDict +from pytorch_lightning import TrainResult import torch @@ -38,3 +39,35 @@ class TrainingStepVariations(ABC): else: output /= 0 return output + + def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None): + """ + Full loop flow train step (result obj + dp) + """ + x, y = batch + x = x.view(x.size(0), -1) + y_hat = self(x.to(self.device)) + loss_val = y_hat.sum() + result = TrainResult(minimize=loss_val) + result.log('train_step_metric', loss_val + 1) + self.training_step_called = True + return result + + def training_step_end_full_loop_result_obj_dp(self, result): + """ + Full loop flow train step (result obj + dp) + """ + result.minimize = result.minimize.mean() + result.checkpoint_on = result.checkpoint_on.mean() + result.train_step_metric = result.train_step_metric.mean() + result.log('train_step_end_metric', 1) + self.training_step_end_called = True + return result + + def training_epoch_end_full_loop_result_obj_dp(self, result): + """ + Full loop flow train step (result obj + dp) + """ + result.log('train_epoch_end_metric', 1, on_epoch=True) + self.training_epoch_end_called = True + return result diff --git a/tests/base/model_valid_epoch_ends.py b/tests/base/model_valid_epoch_ends.py index 5170527397..a7295aa9ca 100644 --- a/tests/base/model_valid_epoch_ends.py +++ b/tests/base/model_valid_epoch_ends.py @@ -35,7 +35,6 @@ class ValidationEpochEndVariations(ABC): Args: outputs: list of individual outputs of each validation step """ - # if returned a scalar from validation_step, outputs is a list of tensor scalars # we return just the average in this case (if we want) def _mean(res, key): diff --git a/tests/callbacks/test_model_checkpoint.py b/tests/callbacks/test_model_checkpoint.py index 1091a4cf3a..7257dc3874 100644 --- a/tests/callbacks/test_model_checkpoint.py +++ b/tests/callbacks/test_model_checkpoint.py @@ -78,11 +78,11 @@ class ModelCheckpointTestInvocations(ModelCheckpoint): self.count = 0 self.expected_count = expected_count - def _save_model(self, filepath): + def _save_model(self, filepath, trainer, pl_module): # make sure we don't save twice assert not os.path.isfile(filepath) self.count += 1 - super()._save_model(filepath) + super()._save_model(filepath, trainer, pl_module) def on_train_end(self, trainer, pl_module): super().on_train_end(trainer, pl_module) diff --git a/tests/models/test_grad_norm.py b/tests/models/test_grad_norm.py index ff627c5088..d7978965a3 100644 --- a/tests/models/test_grad_norm.py +++ b/tests/models/test_grad_norm.py @@ -1,43 +1,12 @@ import numpy as np import pytest +import os from pytorch_lightning import Trainer -from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities import rank_zero_only from tests.base import EvalModelTemplate from tests.base.develop_utils import reset_seed -class OnlyMetricsListLogger(LightningLoggerBase): - def __init__(self): - super().__init__() - self.metrics = [] - - @rank_zero_only - def log_metrics(self, metrics, step): - self.metrics.append(metrics) - - @property - def experiment(self): - return 'test' - - @rank_zero_only - def log_hyperparams(self, params): - pass - - @rank_zero_only - def finalize(self, status): - pass - - @property - def name(self): - return 'name' - - @property - def version(self): - return '1' - - class ModelWithManualGradTracker(EvalModelTemplate): def __init__(self, norm_type, *args, **kwargs): super().__init__(*args, **kwargs) @@ -75,28 +44,29 @@ class ModelWithManualGradTracker(EvalModelTemplate): @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf']) def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): - # rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above + os.environ['PL_DEV_DEBUG'] = '1' + + # rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above reset_seed() # use a custom grad tracking module and a list logger model = ModelWithManualGradTracker(norm_type) - logger = OnlyMetricsListLogger() trainer = Trainer( default_root_dir=tmpdir, max_epochs=3, - logger=logger, track_grad_norm=norm_type, row_log_interval=1, # request grad_norms every batch ) result = trainer.fit(model) assert result == 1, "Training failed" - assert len(logger.metrics) == len(model.stored_grad_norms) + logged_metrics = trainer.dev_debugger.logged_metrics + assert len(logged_metrics) == len(model.stored_grad_norms) # compare the logged metrics against tracked norms on `.backward` - for mod, log in zip(model.stored_grad_norms, logger.metrics): + for mod, log in zip(model.stored_grad_norms, logged_metrics): common = mod.keys() & log.keys() log, mod = [log[k] for k in common], [mod[k] for k in common] diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 0c4212b66f..e45040f8ce 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -589,7 +589,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k): with pytest.raises(FileNotFoundError): trainer.test(ckpt_path='random.ckpt') else: - ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0].absolute()) + ckpt_path = str(list((Path(tmpdir) / f'lightning_logs/version_{trainer.logger.version}/checkpoints').iterdir())[0].absolute()) trainer.test(ckpt_path=ckpt_path) assert trainer.tested_ckpt_path == ckpt_path diff --git a/tests/trainer/test_trainer_steps_result_return.py b/tests/trainer/test_trainer_steps_result_return.py new file mode 100644 index 0000000000..16353bb8b2 --- /dev/null +++ b/tests/trainer/test_trainer_steps_result_return.py @@ -0,0 +1,518 @@ +""" +Tests to ensure that the training loop works with a dict +""" +import os +import torch +from pytorch_lightning import Trainer +from tests.base.deterministic_model import DeterministicModel +from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult +from tests.base import EvalModelTemplate +import pytest + + +# test with train_step_end +# add logging + row interval tests + +def test_training_step_result_log_step_only(tmpdir): + """ + Tests that only training_step can be used with TrainResult + Makes sure that things are routed to pbar, loggers and loss accordingly + + Makes sure pbar and logs happen on step only when requested + """ + # enable internal debugging actions + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_result_log_step_only + model.training_step_end = None + model.training_epoch_end = None + model.val_dataloader = None + + batches = 3 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + row_log_interval=1, + max_epochs=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure correct metrics are logged (one per batch step as requested) + assert len(trainer.dev_debugger.logged_metrics) == batches + for batch_idx, logged_metrics in enumerate(trainer.dev_debugger.logged_metrics): + assert logged_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0 + assert logged_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0 + assert f'step_pbar_acc3_b{batch_idx}' not in logged_metrics + assert len(logged_metrics) == 4 + + # make sure we are using the correct metrics for callbacks + assert trainer.callback_metrics['checkpoint_on'] == 171 + + # make sure pbar metrics are correct ang log metrics did not leak + for batch_idx in range(batches): + assert trainer.progress_bar_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11 + assert trainer.progress_bar_metrics[f'step_pbar_acc3_b{batch_idx}'] == 13 + assert f'step_log_acc2_b{batch_idx}' not in trainer.progress_bar_metrics + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + assert out.signal == 0 + assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0 + assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0 + + train_step_out = out.training_step_output_for_epoch_end + assert isinstance(train_step_out, TrainResult) + + assert 'minimize' in train_step_out + assert f'step_log_and_pbar_acc1_b{batch_idx}' in train_step_out + assert f'step_log_acc2_b{batch_idx}' in train_step_out + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens) + assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3) + + +def test_training_step_result_log_epoch_only(tmpdir): + """ + Tests that only training_step can be used with TrainResult + Makes sure that things are routed to pbar, loggers and loss accordingly + + Makes sure pbar and logs happen on epoch only when requested + """ + # enable internal debugging actions + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_result_log_epoch_only + model.training_step_end = None + model.training_epoch_end = None + model.val_dataloader = None + + epochs = 3 + batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + row_log_interval=1, + max_epochs=epochs, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure correct metrics are logged (one per batch step as requested) + assert len(trainer.dev_debugger.logged_metrics) == epochs + epoch_metrics = trainer.dev_debugger.logged_metrics + assert len(epoch_metrics) == epochs + for batch_idx, logged_metrics in enumerate(epoch_metrics): + assert logged_metrics[f'epoch_log_and_pbar_acc1_e{batch_idx}'] == 14.0 + assert logged_metrics[f'epoch_log_acc2_e{batch_idx}'] == 15.0 + assert f'epoch_pbar_acc3_e{batch_idx}' not in logged_metrics + assert len(logged_metrics) == 4 + + # make sure we are using the correct metrics for callbacks + assert trainer.callback_metrics['checkpoint_on'] == 171 + + # make sure pbar metrics are correct ang log metrics did not leak + for epoch_idx in range(epochs): + assert trainer.progress_bar_metrics[f'epoch_log_and_pbar_acc1_e{epoch_idx}'] == 14 + assert trainer.progress_bar_metrics[f'epoch_pbar_acc3_e{epoch_idx}'] == 16 + assert f'epoch_log_acc2_e{epoch_idx}' not in trainer.progress_bar_metrics + + # make sure training outputs what is expected + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + assert out.signal == 0 + assert len(out.batch_log_metrics) == 0 + + train_step_out = out.training_step_output_for_epoch_end + assert isinstance(train_step_out, TrainResult) + + assert 'minimize' in train_step_out + assert f'epoch_log_and_pbar_acc1_e{trainer.current_epoch}' in train_step_out + assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens) + assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3) + + +def test_training_step_result_log_step_and_epoch(tmpdir): + """ + Tests that only training_step can be used with TrainResult + Makes sure that things are routed to pbar, loggers and loss accordingly + + Makes sure pbar and logs happen on epoch only when requested + """ + # enable internal debugging actions + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_result_log_epoch_and_step + model.training_step_end = None + model.training_epoch_end = None + model.val_dataloader = None + + epochs = 3 + batches = 2 + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=batches, + limit_val_batches=batches, + row_log_interval=1, + max_epochs=epochs, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert not model.training_epoch_end_called + + # make sure correct metrics are logged (one per batch step as requested) + assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs + epoch_metrics = trainer.dev_debugger.logged_metrics + epoch_idx = -1 + for i_start in range(0, len(epoch_metrics), batches + 1): + epoch_idx += 1 + epoch_outputs = epoch_metrics[i_start: i_start + batches + 1] + mean_vals = { + 'step_epoch_log_and_pbar_acc1': [], + 'step_epoch_log_acc2': [] + } + + # make sure each batch logged the expected value + for batch_idx in range(len(epoch_outputs) - 1): + logged_metrics = epoch_outputs[batch_idx] + + expected_val_1 = (5 + batch_idx) * (epoch_idx + 1) + expected_val_2 = (6 + batch_idx) * (epoch_idx + 1) + mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float()) + mean_vals['step_epoch_log_acc2'].append(torch.tensor(expected_val_2).float()) + assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1 + assert logged_metrics['step_epoch_log_acc2'] == expected_val_2 + assert 'step_epoch_pbar_acc3' not in logged_metrics + assert len(logged_metrics) == 4 + + # make sure the metrics for the epoch end are actual means (the default reduce fx) or all the batches + epoch_end_metrics = epoch_outputs[-1] + eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean() + eval_2 = torch.stack(mean_vals['step_epoch_log_acc2']).mean() + assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1 + assert epoch_end_metrics['step_epoch_log_acc2'] == eval_2 + assert 'step_epoch_pbar_acc3' not in epoch_end_metrics + assert len(logged_metrics) == 4 + + # make sure we are using the correct metrics for callbacks + assert trainer.callback_metrics['checkpoint_on'] == 171 + + # ------------------------------- + # VERIFY PBAR METRICS + # ------------------------------- + # make sure pbar metrics are correct ang log metrics did not leak + all_pbar_metrics = trainer.dev_debugger.pbar_added_metrics + assert len(all_pbar_metrics) == (epochs * batches) + epochs + + epoch_idx = -1 + for i_start in range(0, len(all_pbar_metrics), batches + 1): + epoch_idx += 1 + epoch_outputs = all_pbar_metrics[i_start: i_start + batches + 1] + mean_vals = { + 'step_epoch_log_and_pbar_acc1': [], + 'step_epoch_pbar_acc3': [] + } + + # make sure each batch logged the expected value + for batch_idx in range(len(epoch_outputs) - 1): + logged_metrics = epoch_outputs[batch_idx] + + expected_val_1 = (5 + batch_idx) * (epoch_idx + 1) + expected_val_2 = (7 + batch_idx) * (epoch_idx + 1) + mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float()) + mean_vals['step_epoch_pbar_acc3'].append(torch.tensor(expected_val_2).float()) + assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1 + assert logged_metrics['step_epoch_pbar_acc3'] == expected_val_2 + assert 'step_epoch_log_acc2' not in logged_metrics + assert len(logged_metrics) == 3 + + # make sure the metrics for the epoch end are actual means (the default reduce fx) or all the batches + epoch_end_metrics = epoch_outputs[-1] + eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean() + eval_2 = torch.stack(mean_vals['step_epoch_pbar_acc3']).mean() + assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1 + assert epoch_end_metrics['step_epoch_pbar_acc3'] == eval_2 + assert 'step_epoch_log_acc2' not in epoch_end_metrics + assert len(logged_metrics) == 3 + + # ----------------------------------------- + # make sure training outputs what is expected + # ----------------------------------------- + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + assert out.signal == 0 + assert len(out.batch_log_metrics) == 2 + + train_step_out = out.training_step_output_for_epoch_end + assert isinstance(train_step_out, TrainResult) + + assert 'minimize' in train_step_out + assert 'step_epoch_log_and_pbar_acc1' in train_step_out + assert 'step_epoch_log_acc2' in train_step_out + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens) + assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3) + + +def test_training_step_epoch_end_result(tmpdir): + """ + Makes sure training_step and epoch_end can be used with Results (without batch_end) + """ + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_result_log_epoch_and_step + model.training_epoch_end = model.training_epoch_end_return_for_log_epoch_and_step + model.val_dataloader = None + + batches = 3 + epochs = 1 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=epochs, + row_log_interval=1, + limit_train_batches=batches, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert not model.training_step_end_called + assert model.training_epoch_end_called + + # make sure correct metrics were logged + logged_metrics = trainer.dev_debugger.logged_metrics + assert len(logged_metrics) == (epochs * batches) + epochs + last_logged = logged_metrics[-1] + + assert last_logged['step_epoch_log_and_pbar_acc1'] == 210.0 + assert last_logged['step_epoch_log_acc2'] == 336.0 + assert last_logged['epoch_end_log_acc'] == 1212.0 + assert last_logged['epoch_end_log_pbar_acc'] == 1214.0 + assert 'epoch_end_pbar_acc' not in last_logged + + # make sure pbar metrics are correct + logged_pbar = trainer.dev_debugger.pbar_added_metrics + assert len(logged_pbar) == (epochs * batches) + epochs + + assert trainer.progress_bar_metrics['step_epoch_log_and_pbar_acc1'] == 210.0 + assert trainer.progress_bar_metrics['step_epoch_pbar_acc3'] == 504.0 + assert trainer.progress_bar_metrics['epoch_end_pbar_acc'] == 1213.0 + assert trainer.progress_bar_metrics['epoch_end_log_pbar_acc'] == 1214.0 + assert 'epoch_end_log_acc' not in trainer.progress_bar_metrics + assert 'log_acc2' not in trainer.progress_bar_metrics + + # make sure callback metrics didn't change + assert trainer.callback_metrics['checkpoint_on'] == 171 + + # ----------------------------------------- + # make sure training outputs what is expected + # ----------------------------------------- + for batch_idx, batch in enumerate(model.train_dataloader()): + break + + out = trainer.run_training_batch(batch, batch_idx) + assert out.signal == 0 + assert len(out.batch_log_metrics) == 2 + + train_step_out = out.training_step_output_for_epoch_end + assert isinstance(train_step_out, TrainResult) + + assert 'minimize' in train_step_out + assert 'step_epoch_log_and_pbar_acc1' in train_step_out + assert 'step_epoch_log_acc2' in train_step_out + + # make sure the optimizer closure returns the correct things + opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens) + assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3) + + +def test_no_auto_callbacks_with_train_loop_only(tmpdir): + """ + Make sure early stop + checkpoint work with only a train loop + """ + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_no_default_callbacks_for_train_loop + model.training_epoch_end = None + model.val_dataloader = None + + batches = 3 + epochs = 3 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=epochs, + row_log_interval=1, + limit_train_batches=batches, + weights_summary=None, + ) + trainer.fit(model) + + all_losses = trainer.dev_debugger.saved_losses + assert len(all_losses) == batches * epochs + + assert trainer.checkpoint_callback.monitor == 'checkpoint_on' + assert trainer.early_stop_callback is None + + trainer = Trainer( + default_root_dir=tmpdir, + early_stop_callback=True, + max_epochs=epochs, + row_log_interval=1, + limit_train_batches=batches, + weights_summary=None, + ) + trainer.fit(model) + + assert trainer.early_stop_callback.monitor == 'val_loss' + + +def test_no_callbacks_with_train_loop_only(tmpdir): + """ + Make sure early stop + checkpoint work with only a train loop + """ + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_no_callbacks_result_obj + model.training_epoch_end = None + model.val_dataloader = None + + batches = 3 + epochs = 3 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=epochs, + row_log_interval=1, + limit_train_batches=batches, + weights_summary=None, + ) + trainer.fit(model) + + all_losses = trainer.dev_debugger.saved_losses + assert len(all_losses) == batches * epochs + + assert trainer.early_stop_callback is None + + assert len(trainer.dev_debugger.checkpoint_callback_history) == 0 + assert len(trainer.dev_debugger.early_stopping_history) == 0 + + +def test_use_callbacks_with_train_loop_only(tmpdir): + os.environ['PL_DEV_DEBUG'] = '1' + + model = DeterministicModel() + model.training_step = model.training_step_result_log_epoch_and_step_for_callbacks + model.training_epoch_end = None + model.val_dataloader = None + + batches = 3 + epochs = 300 + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=epochs, + early_stop_callback=True, + row_log_interval=1, + limit_train_batches=batches, + weights_summary=None, + ) + trainer.fit(model) + + num_expected_epochs = 10 + + # ---------------------------------- + # VERIFY EARLY STOPPING BEHAVIOR + # ---------------------------------- + # with train loop only it happens on every epoch + early_stop_vals = trainer.dev_debugger.early_stopping_history + assert len(early_stop_vals) == num_expected_epochs + min_val = min([x['best'] for x in early_stop_vals]) + assert min_val == 171 + 9 + all_losses = trainer.dev_debugger.saved_losses + + from collections import Counter + batch_idxs = Counter([x['batch_idx'] for x in all_losses]) + for i, val in batch_idxs.items(): + assert val == num_expected_epochs + assert i in [0, 1, 2] + + # ---------------------------------- + # VERIFY CHECKPOINTING BEHAVIOR + # ---------------------------------- + ckpt_vals = trainer.dev_debugger.checkpoint_callback_history + assert len(ckpt_vals) == 5, '5 ckpts should have been saved' + for ckpt_val, expected_epoch in zip(ckpt_vals, [0, 1, 2, 3, 6]): + assert ckpt_val['epoch'] == expected_epoch + assert ckpt_val['monitor'] == 'checkpoint_on' + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +def test_full_train_loop_with_results_obj_dp(tmpdir): + os.environ['PL_DEV_DEBUG'] = '1' + + batches = 10 + epochs = 3 + + model = EvalModelTemplate() + model.validation_step = None + model.test_step = None + model.training_step = model.training_step_full_loop_result_obj_dp + model.training_step_end = model.training_step_end_full_loop_result_obj_dp + model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp + model.val_dataloader = None + model.test_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + distributed_backend='dp', + gpus=[0, 1], + max_epochs=epochs, + early_stop_callback=True, + row_log_interval=2, + limit_train_batches=batches, + weights_summary=None, + ) + + trainer.fit(model) + + # make sure we saw all the correct keys + seen_keys = set() + for metric in trainer.dev_debugger.logged_metrics: + seen_keys.update(metric.keys()) + + assert 'train_step_metric' in seen_keys + assert 'train_step_end_metric' in seen_keys + assert 'train_epoch_end_metric' in seen_keys