Structured results (train loop only. val loop separate PR) (PR 2/5) (#2615)

* r

* patched optimizer closure with sr

* added train step structured result

* added autoreduce for train step

* added auto reduce on train

* added hooks

* finished tests for structured results on train epoch

* cache

* finished tests for structured results on train epoch

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

* Update pytorch_lightning/core/step_result.py

* finished tests for structured results on train epoch

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* simple

* finished tests for structured results on train epoch

* simple

* revert

* finished tests for structured results on train epoch

* Update tests/base/deterministic_model.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* finished tests for structured results on train epoch

* docstring typos

* finished tests for structured results on train epoch

* Update pytorch_lightning/core/step_result.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update pytorch_lightning/overrides/data_parallel.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
William Falcon, 2020-07-20 19:00:20 -04:00, committed by GitHub
commit 6d10ac2ac8 (parent 816d8cff06)
23 changed files with 1400 additions and 89 deletions
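To make the change concrete, here is a minimal sketch of the usage this PR enables, based on the TrainResult API added in pytorch_lightning/core/step_result.py below; the model, layer sizes, and metric names are hypothetical:

import torch
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule


class LitModel(LightningModule):
    # hypothetical model: only the parts relevant to structured results are shown
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self.layer(x), y)
        # 'minimize' is what gets backpropagated; 'checkpoint_on' defaults to its detached value
        result = TrainResult(minimize=loss, early_stop_on=loss)
        # routed to the logger every step and auto-averaged for the epoch
        result.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return result

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)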


@@ -82,9 +82,9 @@ jobs:
uses: actions/cache@v1
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
- name: Install dependencies
run: |


@@ -55,6 +55,7 @@ else:
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import metrics
from pytorch_lightning.core.step_result import TrainResult, EvalResult
__all__ = [
'Trainer',
@@ -62,7 +63,9 @@ else:
'Callback',
'data_loader',
'seed_everything',
'metrics'
'metrics',
'EvalResult',
'TrainResult'
]
# necessary for regular bolts imports. Skip exception since bolts is not always installed


@@ -46,6 +46,30 @@ class Callback(abc.ABC):
"""Called when the validation sanity check ends."""
pass
def on_train_epoch_start(self, trainer, pl_module):
"""Called when the train epoch begins."""
pass
def on_train_epoch_end(self, trainer, pl_module):
"""Called when the train epoch ends."""
pass
def on_validation_epoch_start(self, trainer, pl_module):
"""Called when the val epoch begins."""
pass
def on_validation_epoch_end(self, trainer, pl_module):
"""Called when the val epoch ends."""
pass
def on_test_epoch_start(self, trainer, pl_module):
"""Called when the test epoch begins."""
pass
def on_test_epoch_end(self, trainer, pl_module):
"""Called when the test epoch ends."""
pass
def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins."""
pass
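A short sketch of a user callback written against the new epoch-level hooks; the EpochTimer name and its behavior are hypothetical:

import time

from pytorch_lightning.callbacks.base import Callback


class EpochTimer(Callback):
    # hypothetical callback: times each training epoch via the new hooks
    def on_train_epoch_start(self, trainer, pl_module):
        self._t0 = time.monotonic()

    def on_train_epoch_end(self, trainer, pl_module):
        print(f'epoch {trainer.current_epoch} took {time.monotonic() - self._t0:.2f}s')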


@@ -7,6 +7,7 @@ Monitor a validation metric and stop training when it stops improving.
"""
from copy import deepcopy
import os
import numpy as np
import torch
import torch.distributed as dist
@@ -140,12 +141,33 @@ class EarlyStopping(Callback):
def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module):
# early stopping can also work in the train loop when there is no val loop and when using structured results
should_check_early_stop = False
train_es_key = 'early_stop_on'
if trainer.callback_metrics.get(train_es_key, None) is not None:
self.monitor = train_es_key
should_check_early_stop = True
val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key, None) is not None:
self.monitor = val_es_key
should_check_early_stop = True
if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)
def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics
if not self._validate_condition_metric(logs):
return # short circuit if metric not present
current = logs.get(self.monitor)
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(current)
if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device)
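What feeds this path, as a minimal sketch: a TrainResult carrying early_stop_on lands in trainer.callback_metrics, which on_train_epoch_end now checks even when there is no val loop. The stand-in loss below is arbitrary:

import torch
from pytorch_lightning import TrainResult

# stand-in training loss with a computational graph, as training_step would produce
loss = (torch.ones(2, requires_grad=True) * 3).sum()

result = TrainResult(minimize=loss, early_stop_on=loss)
# reserved keys are detached on assignment; the trainer copies them into
# callback_metrics under 'early_stop_on', where the hook above finds them
print(result.early_stop_on)  # tensor(6.)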


@@ -159,7 +159,11 @@ class ModelCheckpoint(Callback):
if os.path.isfile(filepath):
os.remove(filepath)
def _save_model(self, filepath):
def _save_model(self, filepath, trainer, pl_module):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)
# make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True)
@@ -270,6 +274,11 @@
metrics = trainer.callback_metrics
epoch = trainer.current_epoch
# support structured results
if metrics.get('checkpoint_on') is not None:
self.monitor = 'checkpoint_on'
if self.save_top_k == 0:
# no models are saved
return
@@ -281,7 +290,7 @@
if self.save_last:
filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)
filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0
@@ -306,7 +315,7 @@
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
)
elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch)
self._do_check_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
@@ -315,9 +324,9 @@
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)
def _do_check_save(self, filepath, current, epoch):
def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
# remove kth
del_list = []
@@ -343,7 +352,7 @@
f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}')
self._save_model(filepath)
self._save_model(filepath, trainer, pl_module)
for cur_path in del_list:
if cur_path != filepath:
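The monitor switch above keys off 'checkpoint_on', which TrainResult populates by default; a small sketch of both behaviors, with an arbitrary stand-in loss:

import torch
from pytorch_lightning import TrainResult

loss = (torch.ones(2, requires_grad=True) * 3).sum()  # stand-in loss with a graph

# checkpoint_on falls back to minimize.detach() when not given ...
print('checkpoint_on' in TrainResult(minimize=loss))                       # True
# ... and passing checkpoint_on=False opts out of checkpointing on the loss
print('checkpoint_on' in TrainResult(minimize=loss, checkpoint_on=False))  # False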


@@ -115,6 +115,42 @@ class ModelHooks(Module):
"""
# do something when the epoch ends
def on_train_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_train_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_validation_epoch_start(self) -> None:
"""
Called in the validation loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_test_epoch_start(self) -> None:
"""
Called in the test loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_pre_performance_check(self) -> None:
"""
Called at the very beginning of the validation loop.


@@ -0,0 +1,336 @@
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
from torch import Tensor
import torch
from copy import copy
class Result(Dict):
def __init__(
self,
minimize: Optional[Tensor] = None,
early_stop_on: Optional[Tensor] = None,
checkpoint_on: Union[Tensor, bool, None] = None,
hiddens: Optional[Tensor] = None,
):
super().__init__()
if early_stop_on is not None:
self.early_stop_on = early_stop_on
if checkpoint_on is not None and checkpoint_on:
self.checkpoint_on = checkpoint_on
if hiddens is not None:
self.hiddens = hiddens
if minimize is not None:
err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
self._assert_grad_tensor_metric('minimize', minimize, err)
self.minimize = minimize
if minimize is not None and checkpoint_on is None:
self.checkpoint_on = minimize.detach()
self['meta'] = {
'_internal': {
'_reduce_on_epoch': False
}
}
def __getattr__(self, key: str) -> Any:
try:
if key == 'callback_metrics':
return self.get_callback_metrics()
elif key == 'batch_log_metrics':
return self.get_batch_log_metrics()
elif key == 'batch_pbar_metrics':
return self.get_batch_pbar_metrics()
elif key == 'epoch_log_metrics':
return self.get_epoch_log_metrics()
elif key == 'epoch_pbar_metrics':
return self.get_epoch_pbar_metrics()
else:
return self[key]
except KeyError:
return None
def __setattr__(self, key: str, val: Union[Tensor, Any]):
# ensure reserve keys are tensors and detached
if key in {'hiddens', 'checkpoint_on', 'early_stop_on'}:
self._assert_tensor_metric(key, val)
if val is not None and isinstance(val, torch.Tensor):
val = val.detach()
# ensure anything else that is a tensor is detached
elif isinstance(val, torch.Tensor) and key != 'minimize':
val = val.detach()
self[key] = val
def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
if potential_metric is not None and not isinstance(potential_metric, bool):
assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'
def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
if x is not None:
assert isinstance(x, Tensor), f'{name} must be a torch.Tensor'
m = f'{name} must have a computational graph.'
if additional_err:
m += f' {additional_err}'
assert x.grad_fn is not None, m
def log(
self,
name: str,
value: Any,
prog_bar: bool = False,
logger: bool = True,
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
):
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()
if 'meta' not in self:
self.__setitem__('meta', {})
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
# set the value
self.__setitem__(name, value)
def __set_meta(
self,
name: str,
value: Any,
prog_bar: bool,
logger: bool,
on_step: bool,
on_epoch: bool,
reduce_fx: Callable,
):
# set the meta for the item
meta_value = value
meta = dict(
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
value=meta_value
)
self['meta'][name] = meta
# track whether any input requires reduction on epoch end
_internal = self['meta']['_internal']
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
def get_callback_metrics(self) -> dict:
result = {
'early_stop_on': self.early_stop_on,
'checkpoint_on': self.checkpoint_on
}
return result
def get_batch_log_metrics(self) -> dict:
"""
Gets the metrics to log at the end of the batch step
"""
result = {}
meta = self['meta']
for k, options in meta.items():
if k == '_internal':
continue
if options['logger'] and options['on_step']:
result[k] = self[k]
return result
def get_epoch_log_metrics(self) -> dict:
"""
Gets the metrics to log at the end of the epoch
"""
result = {}
meta = self['meta']
for k, options in meta.items():
if k == '_internal':
continue
if options['logger'] and options['on_epoch']:
result[k] = self[k]
return result
def get_epoch_pbar_metrics(self):
"""
Gets the metrics to include in the progress bar at the end of the epoch
"""
result = {}
meta = self['meta']
for k, options in meta.items():
if k == '_internal':
continue
if options['prog_bar'] and options['on_epoch']:
result[k] = self[k]
return result
def get_batch_pbar_metrics(self):
"""
Gets the metrics to include in the progress bar at the end of the batch step
"""
result = {}
meta = self['meta']
for k, options in meta.items():
if k == '_internal':
continue
if options['prog_bar'] and options['on_step']:
result[k] = self[k]
return result
def detach(self):
for k, v in self.items():
if isinstance(v, torch.Tensor):
self.__setitem__(k, v.detach())
def __repr__(self):
self_copy = self.copy()
if 'meta' in self_copy:
del self_copy['meta']
return str(self_copy)
def __str__(self):
copy = self.copy()
del copy['meta']
return str(copy)
def __copy__(self):
newone = type(self)()
for k, v in self.items():
newone[k] = copy(v)
return newone
@classmethod
def gather(cls, outputs):
meta = outputs[0]['meta']
result = cls()
result = recursive_gather(outputs, result)
recursive_stack(result)
result['meta'] = meta
return result
@classmethod
def reduce_on_epoch_end(cls, outputs):
meta = outputs[0]['meta']
result = cls()
result = recursive_gather(outputs, result)
recursive_stack(result)
for k, option in meta.items():
if k == '_internal':
continue
if option['on_epoch']:
fx = option['reduce_fx']
result[k] = fx(result[k])
result['meta'] = meta
return result
@property
def should_reduce_on_epoch_end(self) -> bool:
return self['meta']['_internal']['_reduce_on_epoch']
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
for out in outputs:
if 'meta' in out:
del out['meta']
for k, v in out.items():
if isinstance(v, dict):
v = recursive_gather([v], result)
if k not in result:
result[k] = []
result[k].append(v)
return result
def recursive_stack(result: MutableMapping):
for k, v in result.items():
if isinstance(v, dict):
recursive_stack(v)
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
v = torch.stack(v)
result[k] = v
class TrainResult(Result):
def __init__(
self,
minimize: Optional[Tensor] = None,
early_stop_on: Tensor = None,
checkpoint_on: Union[Tensor, bool] = None,
hiddens: Optional[Tensor] = None,
):
super().__init__(minimize, early_stop_on, checkpoint_on, hiddens)
def log(
self,
name,
value,
prog_bar: bool = False,
logger: bool = True,
on_step: bool = True,
on_epoch: bool = False,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
):
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
class EvalResult(Result):
def __init__(
self,
early_stop_on: Optional[Tensor] = None,
checkpoint_on: Optional[Tensor] = None,
hiddens: Optional[Tensor] = None,
):
super().__init__(None, early_stop_on, checkpoint_on, hiddens)
def log(
self,
name,
value,
prog_bar: bool = False,
logger: bool = True,
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
):
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
# if __name__ == '__main__':
# import torch
# result = TrainResult()
# result.hiddens = torch.tensor(1)
# result.log('some', 123)
# print(result)
# result.minimize = torch.tensor(1)
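A minimal sketch of how the class above routes logged values per step and auto-reduces them at epoch end; the metric name and values are toy data for two simulated steps:

import torch
from pytorch_lightning.core.step_result import TrainResult

steps = []
for step_val in (1.0, 3.0):
    loss = (torch.ones(1, requires_grad=True) * step_val).sum()
    res = TrainResult(minimize=loss)
    res.log('train_acc', torch.tensor(step_val), prog_bar=True, on_step=True, on_epoch=True)
    steps.append(res)

# per-step routing: the logger and the progress bar each get their own view
print(steps[0].batch_log_metrics)   # {'train_acc': tensor(1.)}
print(steps[0].batch_pbar_metrics)  # {'train_acc': tensor(1.)}

# epoch-end auto-reduction with the default reduce_fx (torch.mean);
# note that gather/reduce delete the per-step 'meta' entries in place
epoch = TrainResult.reduce_on_epoch_end(steps)
print(epoch.epoch_log_metrics)      # {'train_acc': tensor(2.)}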


@@ -6,6 +6,7 @@ import torch
from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.step_result import Result
def _find_tensors(obj): # pragma: no-cover
@@ -63,7 +64,34 @@ class LightningDataParallel(DataParallel):
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)
if isinstance(outputs[0], Result):
outputs = self.__gather_structured_result(outputs)
else:
outputs = self.gather(outputs, self.output_device)
return outputs
def __gather_structured_result(self, outputs):
prototype_output = outputs[0]
original_class = prototype_output.__class__
outputs = [dict(x) for x in outputs]
# remove all the meta info
meta = outputs[0]['meta']
for i, output in enumerate(outputs):
del output['meta']
outputs = self.gather(outputs, self.output_device)
# pass minimize to constructor for TrainResult
if 'minimize' in outputs:
result = original_class(outputs['minimize'])
else:
result = original_class()
result.update(outputs)
result['meta'] = meta
return result
def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
@@ -160,6 +188,8 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
if not isinstance(input, (list, tuple)):
input = (input,)
module = module.to(device)
# ---------------
# CHANGE
if module.training:


@@ -51,6 +51,36 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks:
callback.on_sanity_check_end(self, self.get_model())
def on_train_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.get_model())
def on_train_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.get_model())
def on_validation_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.get_model())
def on_validation_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.get_model())
def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.get_model())
def on_test_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_test_epoch_end(self, self.get_model())
def on_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:


@@ -1,3 +1,4 @@
import os
from abc import ABC
from typing import Union, Iterable
@@ -73,6 +74,8 @@ class TrainerLoggingMixin(ABC):
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save()
self.dev_debugger.track_logged_metrics_history(scalar_metrics)
def add_progress_bar_metrics(self, metrics):
for k, v in metrics.items():
if isinstance(v, torch.Tensor):
@@ -80,6 +83,8 @@ class TrainerLoggingMixin(ABC):
self.progress_bar_metrics[k] = v
self.dev_debugger.track_pbar_metrics_history(metrics)
def metrics_to_scalars(self, metrics):
new_metrics = {}
for k, v in metrics.items():


@@ -76,3 +76,17 @@ class TensorRunningAccum(object):
return getattr(self.memory, how)()
else:
return getattr(self.memory[:self.current_idx], how)()
class Accumulator(object):
def __init__(self):
self.num_values = 0
self.total = 0
def accumulate(self, x):
with torch.no_grad():
self.total += x
self.num_values += 1
def mean(self):
return self.total / self.num_values
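A tiny usage sketch: the train loop uses this to average the early_stop_on / checkpoint_on values over an epoch. Toy tensors:

import torch

from pytorch_lightning.trainer.supporters import Accumulator

acc = Accumulator()
for v in (torch.tensor(2.0), torch.tensor(4.0)):
    acc.accumulate(v)
print(acc.mean())  # tensor(3.)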


@@ -33,6 +33,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.debugging import InternalDebugger
import warnings
# warnings to ignore in trainer
@@ -616,6 +617,9 @@ class Trainer(
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
# Callback system
self.on_init_end()


@@ -143,7 +143,7 @@ in your model.
trainer = Trainer(terminate_on_nan=True)
"""
import os
import subprocess
from abc import ABC, abstractmethod
from typing import Callable
@@ -153,17 +153,19 @@ import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.distributed as torch_distrib
from copy import copy
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result
try:
from apex import amp
@@ -251,6 +253,8 @@ class TrainerTrainLoopMixin(ABC):
on_epoch_end: Callable
on_validation_end: Callable
on_keyboard_interrupt: Callable
on_train_epoch_start: Callable
on_train_epoch_end: Callable
@abstractmethod
def get_model(self) -> LightningModule:
@@ -420,6 +424,15 @@ class TrainerTrainLoopMixin(ABC):
if self.is_function_implemented('on_epoch_start'):
model.on_epoch_start()
# Epoch start events
with self.profiler.profile('on_train_epoch_start'):
# callbacks
self.on_train_epoch_start()
# model hooks
if self.is_function_implemented('on_train_epoch_start'):
model.on_train_epoch_start()
def run_training_epoch(self):
# get model
@@ -435,6 +448,10 @@
epoch_output = []
should_check_val = False
# structured result accumulators for callbacks
early_stopping_accumulator = Accumulator()
checkpoint_accumulator = Accumulator()
# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
@@ -453,7 +470,15 @@
# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
if self.is_overridden('training_epoch_end', model=self.get_model()):
step_out = batch_output.training_step_output_for_epoch_end
should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end
if isinstance(step_out, dict) and 'early_stop_on' in step_out:
early_stopping_accumulator.accumulate(step_out['early_stop_on'])
if isinstance(step_out, dict) and 'checkpoint_on' in step_out:
checkpoint_accumulator.accumulate(step_out['checkpoint_on'])
if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result:
epoch_output.append(batch_output.training_step_output_for_epoch_end)
# update LR schedulers
@@ -496,7 +521,7 @@
self.sync_horovod()
# process epoch outputs
self.run_training_epoch_end(epoch_output)
self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator)
# checkpoint callback
self.check_checkpoint_callback(should_check_val)
@@ -525,23 +550,74 @@
if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end()
def run_training_epoch_end(self, epoch_output):
with self.profiler.profile('on_train_epoch_end'):
# callbacks
self.on_train_epoch_end()
# model hooks
if self.is_function_implemented('on_train_epoch_end'):
model.on_train_epoch_end()
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator):
model = self.get_model()
is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result)
epoch_log_metrics = {}
epoch_callback_metrics = {}
epoch_progress_bar_metrics = {}
# -----------------------
# Calculate epoch callback values if given
# -----------------------
if checkpoint_accumulator.num_values > 0:
epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
if early_stopping_accumulator.num_values > 0:
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
# --------------------------
# EPOCH END STEP IF DEFINED
# --------------------------
if self.is_overridden('training_epoch_end', model=model):
self.global_step += 1
epoch_output = model.training_epoch_end(epoch_output)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
# remove the protected keys so the user doesn't have to deal with them
if is_result_obj:
epoch_output = epoch_output[0].__class__.gather(epoch_output)
# run training_epoch_end
epoch_output = model.training_epoch_end(epoch_output)
if isinstance(epoch_output, Result):
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
else:
_processed_outputs = self.process_output(epoch_output)
epoch_progress_bar_metrics = _processed_outputs[1]
epoch_log_metrics = _processed_outputs[2]
epoch_callback_metrics = _processed_outputs[3]
# --------------------------
# Structured Result (auto epoch end)
# --------------------------
elif is_result_obj:
epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output)
epoch_output.minimize = epoch_output.minimize.mean()
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
# --------------------------
# track results
# --------------------------
# add the metrics to the loggers
self.log_metrics(log_epoch_metrics, {})
if epoch_log_metrics and len(epoch_log_metrics) > 0:
self.log_metrics(epoch_log_metrics, {})
# add metrics to callbacks
self.callback_metrics.update(callback_epoch_metrics)
self.callback_metrics.update(epoch_callback_metrics)
# add metrics to progress_bar
self.add_progress_bar_metrics(_processed_outputs[1])
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
def sync_horovod(self):
if self.use_horovod:
@@ -558,7 +634,10 @@ class TrainerTrainLoopMixin(ABC):
should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
metrics = batch_output.batch_log_metrics
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic)
def save_loggers_in_training_loop(self, batch_idx):
# when loggers should save to disk
@@ -588,6 +667,8 @@
# track metrics to log
batch_log_metrics = []
using_results_obj = False
if batch is None:
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@@ -622,7 +703,7 @@
param.requires_grad = True
# -------------------
# calculate loss
# calculate loss (train step + train step end)
# -------------------
opt_closure_result = self.optimizer_closure(
split_batch,
@@ -631,14 +712,26 @@
optimizer,
self.hiddens
)
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
# ------------------------------
# POST forward bookkeeping
# ------------------------------
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics)
self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end)
# add metrics to loggers
if using_results_obj:
metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
batch_log_metrics.append(metrics_to_log)
# add metrics to progress bar
if using_results_obj:
metrics_for_pbar = opt_closure_result.training_step_output.batch_pbar_metrics
else:
metrics_for_pbar = opt_closure_result.training_step_output.pbar_on_batch_end
self.add_progress_bar_metrics(metrics_for_pbar)
# track hiddens
self.hiddens = opt_closure_result.hiddens
@@ -677,6 +770,7 @@ class TrainerTrainLoopMixin(ABC):
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
# track all metrics for callbacks
if not using_results_obj:
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
result = AttributeDict(
@@ -764,7 +858,7 @@
wrap the forward step in a closure so second order methods work
"""
# ---------------------------
# FORWARD
# FORWARD (TRAINING STEP + TRAIN STEP END)
# ---------------------------
with self.profiler.profile('model_forward'):
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
@@ -780,9 +874,17 @@
# ----------------------------
# format and reduce outputs accordingly
training_step_output_for_epoch_end = training_step_output
is_result_obj = isinstance(training_step_output, Result)
# don't allow EvalResult in the training_step
if isinstance(training_step_output, EvalResult):
raise MisconfigurationException('training_step cannot return EvalResult, '
'use a dict or TrainResult instead')
# handle regular dicts
if not is_result_obj:
training_step_output = self.process_output(training_step_output, train=True)
# TODO: temporary part of structured results PR
training_step_output = AttributeDict(
batch_loss=training_step_output[0],
pbar_on_batch_end=training_step_output[1],
@@ -794,12 +896,16 @@
# if the user decides to finally reduce things in epoch_end, save raw output without graphs
if isinstance(training_step_output_for_epoch_end, torch.Tensor):
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
elif is_result_obj:
training_step_output_for_epoch_end = copy(training_step_output)
training_step_output_for_epoch_end.detach()
else:
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
# accumulate loss
# (if accumulate_grad_batches = 1 no effect)
closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches
closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
closure_loss = closure_loss / self.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it
untouched_loss = closure_loss.detach().clone()
@@ -829,6 +935,10 @@
# once backward has been applied, release graph
closure_loss = closure_loss.detach()
if is_result_obj:
training_step_output.detach()
else:
training_step_output.batch_loss = training_step_output.batch_loss.detach()
if self.use_horovod:
@@ -841,6 +951,9 @@
with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward()
# when in dev debugging track the losses
self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
result = AttributeDict(
loss=untouched_loss,
training_step_output=training_step_output,
@@ -963,6 +1076,7 @@
if self.is_overridden('training_step_end'):
model_ref = self.get_model()
with self.profiler.profile('training_step_end'):
# TODO: modify when using result obj
output = model_ref.training_step_end(output)
# allow any mode to define training_end


@@ -0,0 +1,54 @@
import os
class InternalDebugger(object):
def __init__(self, trainer):
self.enabled = 'PL_DEV_DEBUG' in os.environ
self.trainer = trainer
self.logged_metrics = []
self.pbar_added_metrics = []
self.saved_losses = []
self.early_stopping_history = []
self.checkpoint_callback_history = []
def track_logged_metrics_history(self, scalar_metrics):
if self.enabled:
scalar_metrics['global_step'] = self.trainer.global_step
self.logged_metrics.append(scalar_metrics)
def track_train_loss_history(self, batch_idx, loss):
if self.enabled:
loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()}
self.saved_losses.append(loss_dict)
def track_pbar_metrics_history(self, metrics):
if self.enabled:
metrics['debug_epoch'] = self.trainer.current_epoch
self.pbar_added_metrics.append(metrics)
def track_early_stopping_history(self, current):
if self.enabled:
es = self.trainer.early_stop_callback
debug_dict = {
'epoch': self.trainer.current_epoch,
'global_step': self.trainer.global_step,
'rank': self.trainer.global_rank,
'current': current,
'best': es.best_score,
'patience': es.wait_count
}
self.early_stopping_history.append(debug_dict)
def track_checkpointing_history(self, filepath):
if self.enabled:
cb = self.trainer.checkpoint_callback
debug_dict = {
'epoch': self.trainer.current_epoch,
'global_step': self.trainer.global_step,
'monitor': cb.monitor,
'rank': self.trainer.global_rank,
'filepath': filepath
}
self.checkpoint_callback_history.append(debug_dict)
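A usage sketch mirroring the tests below; it assumes this repo's tests package is importable, and the model and trainer settings are arbitrary:

import os

os.environ['PL_DEV_DEBUG'] = '1'  # must be set before the Trainer is created

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate

model = EvalModelTemplate()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(model)

print(len(trainer.dev_debugger.logged_metrics))  # one entry per log_metrics call
print(trainer.dev_debugger.saved_losses[0])      # {'batch_idx': 0, 'epoch': 0, 'loss': ...}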


@@ -1,5 +1,6 @@
import inspect
from argparse import Namespace
from typing import Dict
def str_to_bool(val):
@@ -93,7 +94,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
return path_args
class AttributeDict(dict):
class AttributeDict(Dict):
"""Extended dictionary accesisable with dot notation.
>>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})


@@ -2,6 +2,7 @@ import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule
@@ -19,6 +20,8 @@ class DeterministicModel(LightningModule):
self.validation_step_end_called = False
self.validation_epoch_end_called = False
self.assert_backward = True
self.l1 = nn.Linear(2, 3, bias=False)
if weights is None:
weights = torch.tensor([
@@ -33,13 +36,15 @@
def step(self, batch, batch_idx):
x = batch
y_hat = self(x)
bs = x.size(0)
y_hat = self.l1(x)
print(x.device, self.device, self.l1.weight.device)
test_hat = y_hat.cpu().detach()
assert torch.all(test_hat[:, 0] == 15.0)
assert torch.all(test_hat[:, 1] == 42.0)
out = y_hat.sum()
assert out == (42.0 * 3) + (15.0 * 3)
assert out == (42.0 * bs) + (15.0 * bs)
return out
@@ -97,6 +102,105 @@
prototype_loss = outputs[0]
return prototype_loss
def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx):
"""
Early stop and checkpoint only on these values
"""
acc = self.step(batch, batch_idx)
result = TrainResult(minimize=acc)
assert 'early_stop_on' not in result
assert 'checkpoint_on' in result
return result
def training_step_no_callbacks_result_obj(self, batch, batch_idx):
"""
Early stop and checkpoint only on these values
"""
acc = self.step(batch, batch_idx)
result = TrainResult(minimize=acc, checkpoint_on=False)
assert 'early_stop_on' not in result
assert 'checkpoint_on' not in result
return result
def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx):
"""
Early stop and checkpoint only on these values
"""
acc = self.step(batch, batch_idx)
self.assert_backward = False
losses = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20]
idx = self.current_epoch
loss = acc + losses[idx]
result = TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss)
return result
def training_step_result_log_step_only(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
result = TrainResult(minimize=acc)
# step only metrics
result.log(f'step_log_and_pbar_acc1_b{batch_idx}', torch.tensor(11).type_as(acc), prog_bar=True)
result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc))
result.log(f'step_pbar_acc3_b{batch_idx}', torch.tensor(13).type_as(acc), logger=False, prog_bar=True)
self.training_step_called = True
return result
def training_step_result_log_epoch_only(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
result = TrainResult(minimize=acc)
result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}', torch.tensor(14).type_as(acc),
on_epoch=True, prog_bar=True, on_step=False)
result.log(f'epoch_log_acc2_e{self.current_epoch}', torch.tensor(15).type_as(acc),
on_epoch=True, on_step=False)
result.log(f'epoch_pbar_acc3_e{self.current_epoch}', torch.tensor(16).type_as(acc),
on_epoch=True, logger=False, prog_bar=True, on_step=False)
self.training_step_called = True
return result
def training_step_result_log_epoch_and_step(self, batch, batch_idx):
acc = self.step(batch, batch_idx)
result = TrainResult(minimize=acc)
val_1 = (5 + batch_idx) * (self.current_epoch + 1)
val_2 = (6 + batch_idx) * (self.current_epoch + 1)
val_3 = (7 + batch_idx) * (self.current_epoch + 1)
result.log(f'step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(acc),
on_epoch=True, prog_bar=True)
result.log(f'step_epoch_log_acc2', torch.tensor(val_2).type_as(acc),
on_epoch=True)
result.log(f'step_epoch_pbar_acc3', torch.tensor(val_3).type_as(acc),
on_epoch=True, logger=False, prog_bar=True)
self.training_step_called = True
return result
def training_epoch_end_return_for_log_epoch_and_step(self, result):
"""
There should be an array of scalars without graphs that are all 171 (4 of them)
"""
self.training_epoch_end_called = True
if self.use_dp or self.use_ddp2:
pass
else:
# only saw 4 batches
assert isinstance(result, TrainResult)
result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
result.step_epoch_log_acc2 = result.step_epoch_log_acc2.prod()
result.step_epoch_pbar_acc3 = result.step_epoch_pbar_acc3.prod()
result.log('epoch_end_log_acc', torch.tensor(1212).type_as(result.step_epoch_log_acc2),
logger=True, on_epoch=True)
result.log('epoch_end_pbar_acc', torch.tensor(1213).type_as(result.step_epoch_log_acc2),
logger=False, prog_bar=True, on_epoch=True)
result.log('epoch_end_log_pbar_acc', torch.tensor(1214).type_as(result.step_epoch_log_acc2),
logger=True, prog_bar=True, on_epoch=True)
return result
# --------------------------
# dictionary returns
# --------------------------
@@ -231,10 +335,12 @@
return torch.optim.Adam(self.parameters(), lr=0)
def backward(self, trainer, loss, optimizer, optimizer_idx):
if self.assert_backward:
if self.trainer.precision == 16:
assert loss > 171 * 1000
else:
assert loss == 171.0
loss.backward()


@@ -63,6 +63,9 @@ class EvalModelTemplate(
self.hidden_dim = hidden_dim
self.b1 = b1
self.b2 = b2
self.training_step_called = False
self.training_step_end_called = False
self.training_epoch_end_called = False
# if you specify an example input, the summary will show input/output for each layer
# TODO: to be fixed in #1773


@@ -1,6 +1,7 @@
import math
from abc import ABC
from collections import OrderedDict
from pytorch_lightning import TrainResult
import torch
@@ -38,3 +39,35 @@ class TrainingStepVariations(ABC):
else:
output /= 0
return output
def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
"""
Full loop flow train step (result obj + dp)
"""
x, y = batch
x = x.view(x.size(0), -1)
y_hat = self(x.to(self.device))
loss_val = y_hat.sum()
result = TrainResult(minimize=loss_val)
result.log('train_step_metric', loss_val + 1)
self.training_step_called = True
return result
def training_step_end_full_loop_result_obj_dp(self, result):
"""
Full loop flow train step (result obj + dp)
"""
result.minimize = result.minimize.mean()
result.checkpoint_on = result.checkpoint_on.mean()
result.train_step_metric = result.train_step_metric.mean()
result.log('train_step_end_metric', 1)
self.training_step_end_called = True
return result
def training_epoch_end_full_loop_result_obj_dp(self, result):
"""
Full loop flow train step (result obj + dp)
"""
result.log('train_epoch_end_metric', 1, on_epoch=True)
self.training_epoch_end_called = True
return result


@@ -35,7 +35,6 @@ class ValidationEpochEndVariations(ABC):
Args:
outputs: list of individual outputs of each validation step
"""
# if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want)
def _mean(res, key):


@@ -78,11 +78,11 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
self.count = 0
self.expected_count = expected_count
def _save_model(self, filepath):
def _save_model(self, filepath, trainer, pl_module):
# make sure we don't save twice
assert not os.path.isfile(filepath)
self.count += 1
super()._save_model(filepath)
super()._save_model(filepath, trainer, pl_module)
def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)


@@ -1,43 +1,12 @@
import numpy as np
import pytest
import os
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from tests.base import EvalModelTemplate
from tests.base.develop_utils import reset_seed
class OnlyMetricsListLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
self.metrics = []
@rank_zero_only
def log_metrics(self, metrics, step):
self.metrics.append(metrics)
@property
def experiment(self):
return 'test'
@rank_zero_only
def log_hyperparams(self, params):
pass
@rank_zero_only
def finalize(self, status):
pass
@property
def name(self):
return 'name'
@property
def version(self):
return '1'
class ModelWithManualGradTracker(EvalModelTemplate):
def __init__(self, norm_type, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -75,28 +44,29 @@ class ModelWithManualGradTracker(EvalModelTemplate):
@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
# rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above
os.environ['PL_DEV_DEBUG'] = '1'
# rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
reset_seed()
# use a custom grad tracking module and a list logger
model = ModelWithManualGradTracker(norm_type)
logger = OnlyMetricsListLogger()
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=3,
logger=logger,
track_grad_norm=norm_type,
row_log_interval=1, # request grad_norms every batch
)
result = trainer.fit(model)
assert result == 1, "Training failed"
assert len(logger.metrics) == len(model.stored_grad_norms)
logged_metrics = trainer.dev_debugger.logged_metrics
assert len(logged_metrics) == len(model.stored_grad_norms)
# compare the logged metrics against tracked norms on `.backward`
for mod, log in zip(model.stored_grad_norms, logger.metrics):
for mod, log in zip(model.stored_grad_norms, logged_metrics):
common = mod.keys() & log.keys()
log, mod = [log[k] for k in common], [mod[k] for k in common]


@@ -589,7 +589,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
with pytest.raises(FileNotFoundError):
trainer.test(ckpt_path='random.ckpt')
else:
ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0].absolute())
ckpt_path = str(list((Path(tmpdir) / f'lightning_logs/version_{trainer.logger.version}/checkpoints').iterdir())[0].absolute())
trainer.test(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path == ckpt_path


@@ -0,0 +1,518 @@
"""
Tests to ensure that the training loop works with a dict
"""
import os
import torch
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
from tests.base import EvalModelTemplate
import pytest
# test with train_step_end
# add logging + row interval tests
def test_training_step_result_log_step_only(tmpdir):
"""
Tests that only training_step can be used with TrainResult
Makes sure that things are routed to pbar, loggers and loss accordingly
Makes sure pbar and logs happen on step only when requested
"""
# enable internal debugging actions
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_result_log_step_only
model.training_step_end = None
model.training_epoch_end = None
model.val_dataloader = None
batches = 3
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
row_log_interval=1,
max_epochs=1,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert not model.training_epoch_end_called
# make sure correct metrics are logged (one per batch step as requested)
assert len(trainer.dev_debugger.logged_metrics) == batches
for batch_idx, logged_metrics in enumerate(trainer.dev_debugger.logged_metrics):
assert logged_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
assert logged_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
assert f'step_pbar_acc3_b{batch_idx}' not in logged_metrics
assert len(logged_metrics) == 4
# make sure we are using the correct metrics for callbacks
assert trainer.callback_metrics['checkpoint_on'] == 171
# make sure pbar metrics are correct and log metrics did not leak
for batch_idx in range(batches):
assert trainer.progress_bar_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11
assert trainer.progress_bar_metrics[f'step_pbar_acc3_b{batch_idx}'] == 13
assert f'step_log_acc2_b{batch_idx}' not in trainer.progress_bar_metrics
# make sure training outputs what is expected
for batch_idx, batch in enumerate(model.train_dataloader()):
break
out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
assert f'step_log_and_pbar_acc1_b{batch_idx}' in train_step_out
assert f'step_log_acc2_b{batch_idx}' in train_step_out
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_training_step_result_log_epoch_only(tmpdir):
"""
Tests that only training_step can be used with TrainResult
Makes sure that things are routed to pbar, loggers and loss accordingly
Makes sure pbar and logs happen on epoch only when requested
"""
# enable internal debugging actions
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_result_log_epoch_only
model.training_step_end = None
model.training_epoch_end = None
model.val_dataloader = None
epochs = 3
batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
row_log_interval=1,
max_epochs=epochs,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert not model.training_epoch_end_called
# make sure correct metrics are logged (once per epoch as requested)
assert len(trainer.dev_debugger.logged_metrics) == epochs
epoch_metrics = trainer.dev_debugger.logged_metrics
assert len(epoch_metrics) == epochs
for batch_idx, logged_metrics in enumerate(epoch_metrics):
assert logged_metrics[f'epoch_log_and_pbar_acc1_e{batch_idx}'] == 14.0
assert logged_metrics[f'epoch_log_acc2_e{batch_idx}'] == 15.0
assert f'epoch_pbar_acc3_e{batch_idx}' not in logged_metrics
assert len(logged_metrics) == 4
# make sure we are using the correct metrics for callbacks
assert trainer.callback_metrics['checkpoint_on'] == 171
# make sure pbar metrics are correct and log metrics did not leak
for epoch_idx in range(epochs):
assert trainer.progress_bar_metrics[f'epoch_log_and_pbar_acc1_e{epoch_idx}'] == 14
assert trainer.progress_bar_metrics[f'epoch_pbar_acc3_e{epoch_idx}'] == 16
assert f'epoch_log_acc2_e{epoch_idx}' not in trainer.progress_bar_metrics
# make sure training outputs what is expected
for batch_idx, batch in enumerate(model.train_dataloader()):
break
out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0
train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
assert f'epoch_log_and_pbar_acc1_e{trainer.current_epoch}' in train_step_out
assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_training_step_result_log_step_and_epoch(tmpdir):
"""
Tests that only training_step can be used with TrainResult
Makes sure that things are routed to pbar, loggers and loss accordingly
Makes sure pbar and logs happen on step and epoch when requested
"""
# enable internal debugging actions
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_result_log_epoch_and_step
model.training_step_end = None
model.training_epoch_end = None
model.val_dataloader = None
epochs = 3
batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=batches,
limit_val_batches=batches,
row_log_interval=1,
max_epochs=epochs,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert not model.training_epoch_end_called
# make sure the correct metrics were logged (one entry per step plus one per epoch, as requested)
assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
epoch_metrics = trainer.dev_debugger.logged_metrics
for epoch_idx, i_start in enumerate(range(0, len(epoch_metrics), batches + 1)):
epoch_outputs = epoch_metrics[i_start: i_start + batches + 1]
mean_vals = {
'step_epoch_log_and_pbar_acc1': [],
'step_epoch_log_acc2': []
}
# make sure each batch logged the expected value
for batch_idx in range(len(epoch_outputs) - 1):
logged_metrics = epoch_outputs[batch_idx]
expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
expected_val_2 = (6 + batch_idx) * (epoch_idx + 1)
mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
mean_vals['step_epoch_log_acc2'].append(torch.tensor(expected_val_2).float())
assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_epoch_log_acc2'] == expected_val_2
assert 'step_epoch_pbar_acc3' not in logged_metrics
assert len(logged_metrics) == 4
# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
eval_2 = torch.stack(mean_vals['step_epoch_log_acc2']).mean()
assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['step_epoch_log_acc2'] == eval_2
assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
assert len(logged_metrics) == 4
# make sure we are using the correct metrics for callbacks
assert trainer.callback_metrics['checkpoint_on'] == 171
# -------------------------------
# VERIFY PBAR METRICS
# -------------------------------
# make sure pbar metrics are correct and log metrics did not leak
all_pbar_metrics = trainer.dev_debugger.pbar_added_metrics
assert len(all_pbar_metrics) == (epochs * batches) + epochs
for epoch_idx, i_start in enumerate(range(0, len(all_pbar_metrics), batches + 1)):
epoch_outputs = all_pbar_metrics[i_start: i_start + batches + 1]
mean_vals = {
'step_epoch_log_and_pbar_acc1': [],
'step_epoch_pbar_acc3': []
}
# make sure each batch logged the expected value
for batch_idx in range(len(epoch_outputs) - 1):
logged_metrics = epoch_outputs[batch_idx]
expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
expected_val_2 = (7 + batch_idx) * (epoch_idx + 1)
mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
mean_vals['step_epoch_pbar_acc3'].append(torch.tensor(expected_val_2).float())
assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_epoch_pbar_acc3'] == expected_val_2
assert 'step_epoch_log_acc2' not in logged_metrics
assert len(logged_metrics) == 3
# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
eval_2 = torch.stack(mean_vals['step_epoch_pbar_acc3']).mean()
assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['step_epoch_pbar_acc3'] == eval_2
assert 'step_epoch_log_acc2' not in epoch_end_metrics
assert len(logged_metrics) == 3
# -----------------------------------------
# make sure training outputs what is expected
# -----------------------------------------
# grab the first batch (and its index) from the train dataloader
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
assert 'step_epoch_log_and_pbar_acc1' in train_step_out
assert 'step_epoch_log_acc2' in train_step_out
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
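
# For reference, logging on both step and epoch is assumed to look roughly like this;
# with on_step=True and on_epoch=True the epoch entry is the mean (the default
# reduce_fx) of the per-step values, which is what the loops above verify:
def _sketch_training_step_result_log_epoch_and_step(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    result = TrainResult(minimize=acc)
    # deterministic value so the per-step logs and the epoch mean can be asserted exactly
    val = (5 + batch_idx) * (self.current_epoch + 1)
    result.log('step_epoch_log_and_pbar_acc1', torch.tensor(val).float(),
               on_step=True, on_epoch=True, prog_bar=True)
    return result
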
def test_training_step_epoch_end_result(tmpdir):
"""
Makes sure training_step and training_epoch_end can be used with Results (without training_step_end)
"""
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_result_log_epoch_and_step
model.training_epoch_end = model.training_epoch_end_return_for_log_epoch_and_step
model.val_dataloader = None
batches = 3
epochs = 1
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
row_log_interval=1,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
# make sure correct steps were called
assert model.training_step_called
assert not model.training_step_end_called
assert model.training_epoch_end_called
# make sure correct metrics were logged
logged_metrics = trainer.dev_debugger.logged_metrics
assert len(logged_metrics) == (epochs * batches) + epochs
last_logged = logged_metrics[-1]
assert last_logged['step_epoch_log_and_pbar_acc1'] == 210.0
assert last_logged['step_epoch_log_acc2'] == 336.0
assert last_logged['epoch_end_log_acc'] == 1212.0
assert last_logged['epoch_end_log_pbar_acc'] == 1214.0
assert 'epoch_end_pbar_acc' not in last_logged
# make sure pbar metrics are correct
logged_pbar = trainer.dev_debugger.pbar_added_metrics
assert len(logged_pbar) == (epochs * batches) + epochs
assert trainer.progress_bar_metrics['step_epoch_log_and_pbar_acc1'] == 210.0
assert trainer.progress_bar_metrics['step_epoch_pbar_acc3'] == 504.0
assert trainer.progress_bar_metrics['epoch_end_pbar_acc'] == 1213.0
assert trainer.progress_bar_metrics['epoch_end_log_pbar_acc'] == 1214.0
assert 'epoch_end_log_acc' not in trainer.progress_bar_metrics
assert 'log_acc2' not in trainer.progress_bar_metrics
# make sure callback metrics didn't change
assert trainer.callback_metrics['checkpoint_on'] == 171
# -----------------------------------------
# make sure training outputs what is expected
# -----------------------------------------
# grab the first batch (and its index) from the train dataloader
batch_idx, batch = 0, next(iter(model.train_dataloader()))
out = trainer.run_training_batch(batch, batch_idx)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
train_step_out = out.training_step_output_for_epoch_end
assert isinstance(train_step_out, TrainResult)
assert 'minimize' in train_step_out
assert 'step_epoch_log_and_pbar_acc1' in train_step_out
assert 'step_epoch_log_acc2' in train_step_out
# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
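
# For reference, an epoch-end hook layered on top of the auto-reduced TrainResult is
# assumed to look roughly like this (illustrative only; the real hook lives on
# DeterministicModel and produces the epoch_end_* values asserted above):
def _sketch_training_epoch_end(self, result):
    # `result` already holds the mean-reduced step metrics; epoch-level values are logged on top
    result.log('epoch_end_log_acc', torch.tensor(1212).float(), on_epoch=True)
    result.log('epoch_end_pbar_acc', torch.tensor(1213).float(), on_epoch=True, logger=False, prog_bar=True)
    result.log('epoch_end_log_pbar_acc', torch.tensor(1214).float(), on_epoch=True, prog_bar=True)
    return result
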
def test_no_auto_callbacks_with_train_loop_only(tmpdir):
"""
Make sure the checkpoint callback picks up 'checkpoint_on' and early stopping is only added when explicitly requested
"""
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_no_default_callbacks_for_train_loop
model.training_epoch_end = None
model.val_dataloader = None
batches = 3
epochs = 3
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
row_log_interval=1,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
all_losses = trainer.dev_debugger.saved_losses
assert len(all_losses) == batches * epochs
assert trainer.checkpoint_callback.monitor == 'checkpoint_on'
assert trainer.early_stop_callback is None
trainer = Trainer(
default_root_dir=tmpdir,
early_stop_callback=True,
max_epochs=epochs,
row_log_interval=1,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
assert trainer.early_stop_callback.monitor == 'val_loss'
def test_no_callbacks_with_train_loop_only(tmpdir):
"""
Make sure no early stopping or checkpointing fires when the TrainResult sets no callback metrics
"""
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_no_callbacks_result_obj
model.training_epoch_end = None
model.val_dataloader = None
batches = 3
epochs = 3
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
row_log_interval=1,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
all_losses = trainer.dev_debugger.saved_losses
assert len(all_losses) == batches * epochs
assert trainer.early_stop_callback is None
assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
assert len(trainer.dev_debugger.early_stopping_history) == 0
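
# For reference, the no-callbacks behavior above is assumed to come from a TrainResult
# that sets no callback metrics, roughly (illustrative only):
def _sketch_training_step_no_callbacks(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    # no early_stop_on / checkpoint_on -> the auto callbacks have nothing to monitor
    return TrainResult(minimize=acc)
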
def test_use_callbacks_with_train_loop_only(tmpdir):
"""
Make sure early stopping and checkpointing fire with only a train loop when the TrainResult sets early_stop_on / checkpoint_on
"""
os.environ['PL_DEV_DEBUG'] = '1'
model = DeterministicModel()
model.training_step = model.training_step_result_log_epoch_and_step_for_callbacks
model.training_epoch_end = None
model.val_dataloader = None
batches = 3
epochs = 300
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=epochs,
early_stop_callback=True,
row_log_interval=1,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
# early stopping should halt training well before max_epochs
num_expected_epochs = 10
# ----------------------------------
# VERIFY EARLY STOPPING BEHAVIOR
# ----------------------------------
# with only a train loop, the early-stopping check runs at the end of every epoch
early_stop_vals = trainer.dev_debugger.early_stopping_history
assert len(early_stop_vals) == num_expected_epochs
min_val = min(x['best'] for x in early_stop_vals)
assert min_val == 171 + 9
all_losses = trainer.dev_debugger.saved_losses
# every batch index should appear exactly once per completed epoch
from collections import Counter
batch_idx_counts = Counter(x['batch_idx'] for x in all_losses)
for batch_idx, count in batch_idx_counts.items():
assert count == num_expected_epochs
assert batch_idx in [0, 1, 2]
# ----------------------------------
# VERIFY CHECKPOINTING BEHAVIOR
# ----------------------------------
ckpt_vals = trainer.dev_debugger.checkpoint_callback_history
assert len(ckpt_vals) == 5, '5 ckpts should have been saved'
for ckpt_val, expected_epoch in zip(ckpt_vals, [0, 1, 2, 3, 6]):
assert ckpt_val['epoch'] == expected_epoch
assert ckpt_val['monitor'] == 'checkpoint_on'
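
# For reference, the callback wiring exercised above is assumed to hinge on the
# TrainResult constructor args, roughly (illustrative only):
def _sketch_training_step_with_callbacks(self, batch, batch_idx):
    acc = self.step(batch, batch_idx)
    # early_stop_on / checkpoint_on feed the early-stopping and checkpoint callbacks
    return TrainResult(minimize=acc, early_stop_on=acc, checkpoint_on=acc)
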
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_full_train_loop_with_results_obj_dp(tmpdir):
"""
Make sure structured results work end to end (training_step, training_step_end, training_epoch_end) with the full train loop under DP
"""
os.environ['PL_DEV_DEBUG'] = '1'
batches = 10
epochs = 3
model = EvalModelTemplate()
model.validation_step = None
model.test_step = None
model.training_step = model.training_step_full_loop_result_obj_dp
model.training_step_end = model.training_step_end_full_loop_result_obj_dp
model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp
model.val_dataloader = None
model.test_dataloader = None
trainer = Trainer(
default_root_dir=tmpdir,
distributed_backend='dp',
gpus=[0, 1],
max_epochs=epochs,
early_stop_callback=True,
row_log_interval=2,
limit_train_batches=batches,
weights_summary=None,
)
trainer.fit(model)
# make sure we saw all the correct keys
seen_keys = set()
for metric in trainer.dev_debugger.logged_metrics:
seen_keys.update(metric.keys())
assert 'train_step_metric' in seen_keys
assert 'train_step_end_metric' in seen_keys
assert 'train_epoch_end_metric' in seen_keys
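
# For reference, under DP each GPU returns a partial Result; a training_step_end hook
# that adds a reduced metric on top is assumed to look roughly like this (illustrative
# only; the real hooks live on EvalModelTemplate):
def _sketch_training_step_end_dp(self, output):
    # `output` gathers the per-GPU results; log an additional step_end-level metric
    output.log('train_step_end_metric', torch.tensor(1).float())
    return output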