Structured results (train loop only. val loop separate PR) (PR 2/5) (#2615)

* r

* r

* r

* patched optimizer closure with sr

* patched optimizer closure with sr

* patched optimizer closure with sr

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added train step structured result

* added autoreduce for train step

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added auto reduce on train

* added hooks

* added hooks

* added hooks

* added hooks

* added hooks

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* cache

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/early_stopping.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* Update pytorch_lightning/callbacks/model_checkpoint.py

* Update pytorch_lightning/core/step_result.py

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

* simple

* finished tests for structured results on train epoch

* simple

* simple

* revert

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update tests/base/deterministic_model.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

* finished tests for structured results on train epoch

* docstring typos

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* finished tests for structured results on train epoch

* Update pytorch_lightning/core/step_result.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

* Update pytorch_lightning/overrides/data_parallel.py

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
William Falcon 2020-07-20 19:00:20 -04:00 committed by GitHub
parent 816d8cff06
commit 6d10ac2ac8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1400 additions and 89 deletions

View File

@ -82,9 +82,9 @@ jobs:
uses: actions/cache@v1 uses: actions/cache@v1
with: with:
path: ${{ steps.pip-cache.outputs.dir }} path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }} key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
restore-keys: | restore-keys: |
${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip- ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
- name: Install dependencies - name: Install dependencies
run: | run: |

View File

@ -55,6 +55,7 @@ else:
from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer import Trainer
from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning import metrics from pytorch_lightning import metrics
from pytorch_lightning.core.step_result import TrainResult, EvalResult
__all__ = [ __all__ = [
'Trainer', 'Trainer',
@ -62,7 +63,9 @@ else:
'Callback', 'Callback',
'data_loader', 'data_loader',
'seed_everything', 'seed_everything',
'metrics' 'metrics',
'EvalResult',
'TrainResult'
] ]
# necessary for regular bolts imports. Skip exception since bolts is not always installed # necessary for regular bolts imports. Skip exception since bolts is not always installed

View File

@ -46,6 +46,30 @@ class Callback(abc.ABC):
"""Called when the validation sanity check ends.""" """Called when the validation sanity check ends."""
pass pass
def on_train_epoch_start(self, trainer, pl_module):
"""Called when the train epoch begins."""
pass
def on_train_epoch_end(self, trainer, pl_module):
"""Called when the train epoch ends."""
pass
def on_validation_epoch_start(self, trainer, pl_module):
"""Called when the val epoch begins."""
pass
def on_validation_epoch_end(self, trainer, pl_module):
"""Called when the val epoch ends."""
pass
def on_test_epoch_start(self, trainer, pl_module):
"""Called when the test epoch begins."""
pass
def on_test_epoch_end(self, trainer, pl_module):
"""Called when the test epoch ends."""
pass
def on_epoch_start(self, trainer, pl_module): def on_epoch_start(self, trainer, pl_module):
"""Called when the epoch begins.""" """Called when the epoch begins."""
pass pass

View File

@ -7,6 +7,7 @@ Monitor a validation metric and stop training when it stops improving.
""" """
from copy import deepcopy from copy import deepcopy
import os
import numpy as np import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@ -140,12 +141,33 @@ class EarlyStopping(Callback):
def on_validation_end(self, trainer, pl_module): def on_validation_end(self, trainer, pl_module):
self._run_early_stopping_check(trainer, pl_module) self._run_early_stopping_check(trainer, pl_module)
def on_train_epoch_end(self, trainer, pl_module):
# early stopping can also work in the train loop when there is no val loop and when using structured results
should_check_early_stop = False
train_es_key = 'early_stop_on'
if trainer.callback_metrics.get(train_es_key, None) is not None:
self.monitor = train_es_key
should_check_early_stop = True
val_es_key = 'val_early_stop_on'
if trainer.callback_metrics.get(val_es_key, None) is not None:
self.monitor = val_es_key
should_check_early_stop = True
if should_check_early_stop:
self._run_early_stopping_check(trainer, pl_module)
def _run_early_stopping_check(self, trainer, pl_module): def _run_early_stopping_check(self, trainer, pl_module):
logs = trainer.callback_metrics logs = trainer.callback_metrics
if not self._validate_condition_metric(logs): if not self._validate_condition_metric(logs):
return # short circuit if metric not present return # short circuit if metric not present
current = logs.get(self.monitor) current = logs.get(self.monitor)
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(current)
if not isinstance(current, torch.Tensor): if not isinstance(current, torch.Tensor):
current = torch.tensor(current, device=pl_module.device) current = torch.tensor(current, device=pl_module.device)

View File

@ -159,7 +159,11 @@ class ModelCheckpoint(Callback):
if os.path.isfile(filepath): if os.path.isfile(filepath):
os.remove(filepath) os.remove(filepath)
def _save_model(self, filepath): def _save_model(self, filepath, trainer, pl_module):
# in debugging, track when we save checkpoints
trainer.dev_debugger.track_checkpointing_history(filepath)
# make paths # make paths
os.makedirs(os.path.dirname(filepath), exist_ok=True) os.makedirs(os.path.dirname(filepath), exist_ok=True)
@ -270,6 +274,11 @@ class ModelCheckpoint(Callback):
metrics = trainer.callback_metrics metrics = trainer.callback_metrics
epoch = trainer.current_epoch epoch = trainer.current_epoch
# support structured results
if metrics.get('checkpoint_on') is not None:
self.monitor = 'checkpoint_on'
if self.save_top_k == 0: if self.save_top_k == 0:
# no models are saved # no models are saved
return return
@ -281,7 +290,7 @@ class ModelCheckpoint(Callback):
if self.save_last: if self.save_last:
filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt') filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
self._save_model(filepath) self._save_model(filepath, trainer, pl_module)
filepath = self.format_checkpoint_name(epoch, metrics) filepath = self.format_checkpoint_name(epoch, metrics)
version_cnt = 0 version_cnt = 0
@ -306,7 +315,7 @@ class ModelCheckpoint(Callback):
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
) )
elif self.check_monitor_top_k(current): elif self.check_monitor_top_k(current):
self._do_check_save(filepath, current, epoch) self._do_check_save(filepath, current, epoch, trainer, pl_module)
elif self.verbose > 0: elif self.verbose > 0:
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}') log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
@ -315,9 +324,9 @@ class ModelCheckpoint(Callback):
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}') log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0' assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
self._save_model(filepath) self._save_model(filepath, trainer, pl_module)
def _do_check_save(self, filepath, current, epoch): def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
# remove kth # remove kth
del_list = [] del_list = []
@ -343,7 +352,7 @@ class ModelCheckpoint(Callback):
f'\nEpoch {epoch:05d}: {self.monitor} reached' f'\nEpoch {epoch:05d}: {self.monitor} reached'
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to' f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
f' {filepath} as top {self.save_top_k}') f' {filepath} as top {self.save_top_k}')
self._save_model(filepath) self._save_model(filepath, trainer, pl_module)
for cur_path in del_list: for cur_path in del_list:
if cur_path != filepath: if cur_path != filepath:

View File

@ -115,6 +115,42 @@ class ModelHooks(Module):
""" """
# do something when the epoch ends # do something when the epoch ends
def on_train_epoch_start(self) -> None:
"""
Called in the training loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_train_epoch_end(self) -> None:
"""
Called in the training loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_validation_epoch_start(self) -> None:
"""
Called in the validation loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_validation_epoch_end(self) -> None:
"""
Called in the validation loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_test_epoch_start(self) -> None:
"""
Called in the test loop at the very beginning of the epoch.
"""
# do something when the epoch starts
def on_test_epoch_end(self) -> None:
"""
Called in the test loop at the very end of the epoch.
"""
# do something when the epoch ends
def on_pre_performance_check(self) -> None: def on_pre_performance_check(self) -> None:
""" """
Called at the very beginning of the validation loop. Called at the very beginning of the validation loop.

View File

@ -0,0 +1,336 @@
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
from torch import Tensor
import torch
from copy import copy
class Result(Dict):
    """Dict-like container returned from a Lightning step.

    Reserved keys (``minimize``, ``early_stop_on``, ``checkpoint_on``,
    ``hiddens``) drive the training loop, early stopping and checkpointing.
    Everything logged via :meth:`log` is stored as a plain item, with per-key
    display/reduction options recorded under ``self['meta']``.
    """

    def __init__(
        self,
        minimize: Optional[Tensor] = None,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Union[Tensor, bool, None] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """
        Args:
            minimize: Loss to minimize; must carry a ``grad_fn`` (asserted below).
            early_stop_on: Metric monitored for early stopping.
            checkpoint_on: Metric monitored for checkpointing; when omitted,
                falls back to ``minimize.detach()``.
            hiddens: Hidden state (e.g. carried across truncated-BPTT splits
                — presumably; TODO confirm against the train loop).
        """
        super().__init__()

        if early_stop_on is not None:
            self.early_stop_on = early_stop_on
        # NOTE(review): a falsy tensor (e.g. tensor(0.)) is silently dropped
        # by this truthiness check — confirm that is intended
        if checkpoint_on is not None and checkpoint_on:
            self.checkpoint_on = checkpoint_on
        if hiddens is not None:
            self.hiddens = hiddens
        if minimize is not None:
            err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
            self._assert_grad_tensor_metric('minimize', minimize, err)
            self.minimize = minimize

        # default the checkpoint metric to the (detached) loss
        if minimize is not None and checkpoint_on is None:
            self.checkpoint_on = minimize.detach()

        # '_internal' is bookkeeping: tracks whether any logged key asked
        # for reduction at epoch end (see __set_meta / should_reduce_on_epoch_end)
        self['meta'] = {
            '_internal': {
                '_reduce_on_epoch': False
            }
        }

    def __getattr__(self, key: str) -> Any:
        """Map attribute access onto dict lookup.

        A few virtual attributes (``callback_metrics``, ``batch_log_metrics``,
        ``batch_pbar_metrics``, ``epoch_log_metrics``, ``epoch_pbar_metrics``)
        are computed from the logged metadata; any missing key resolves to
        ``None`` instead of raising ``AttributeError``.
        """
        try:
            if key == 'callback_metrics':
                return self.get_callback_metrics()
            elif key == 'batch_log_metrics':
                return self.get_batch_log_metrics()
            elif key == 'batch_pbar_metrics':
                return self.get_batch_pbar_metrics()
            elif key == 'epoch_log_metrics':
                return self.get_epoch_log_metrics()
            elif key == 'epoch_pbar_metrics':
                return self.get_epoch_pbar_metrics()
            else:
                return self[key]
        except KeyError:
            # unknown attributes resolve to None rather than raising
            return None

    def __setattr__(self, key: str, val: Union[Tensor, Any]):
        """Store attributes as dict items, detaching tensors from autograd.

        Only ``minimize`` keeps its computational graph; every other tensor
        (including the reserved keys) is detached before storage.
        """
        # ensure reserved keys are tensors and detached
        if key in {'hiddens', 'checkpoint_on', 'early_stop_on'}:
            self._assert_tensor_metric(key, val)
            if val is not None and isinstance(val, torch.Tensor):
                val = val.detach()

        # ensure anything else that is a tensor is detached
        elif isinstance(val, torch.Tensor) and key != 'minimize':
            val = val.detach()

        self[key] = val

    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
        """Assert that *potential_metric* is a Tensor (None/bool are allowed through)."""
        # NOTE(review): assert-based validation is stripped under `python -O`
        if potential_metric is not None and not isinstance(potential_metric, bool):
            assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'

    def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
        """Assert that *x* is a Tensor attached to a computational graph."""
        if x is not None:
            assert isinstance(x, Tensor), f'{name} must be a torch.Tensor'
            m = f'{name} must have a computational graph.'

            if additional_err:
                m += f' {additional_err}'
            assert x.grad_fn is not None, m

    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """Log *value* under *name*, recording how/when it should surface.

        Args:
            name: Metric key.
            value: Metric value; detached unless ``enable_graph`` is True.
            prog_bar: Show in the progress bar.
            logger: Send to the logger.
            on_step: Surface at batch granularity.
            on_epoch: Surface (after ``reduce_fx`` reduction) at epoch end.
            reduce_fx: Reduction applied across the epoch (default: mean).
            enable_graph: Keep the autograd graph on tensor values.
        """
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        if 'meta' not in self:
            self.__setitem__('meta', {})

        self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

        # set the value
        self.__setitem__(name, value)

    def __set_meta(
        self,
        name: str,
        value: Any,
        prog_bar: bool,
        logger: bool,
        on_step: bool,
        on_epoch: bool,
        reduce_fx: Callable,
    ):
        """Record the display/reduction options for a logged key."""
        # set the meta for the item
        meta_value = value
        meta = dict(
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            value=meta_value
        )
        self['meta'][name] = meta

        # track whether any input requires reduction on epoch end
        _internal = self['meta']['_internal']
        _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

    def get_callback_metrics(self) -> dict:
        """Return the metrics callbacks monitor.

        Missing keys come back as ``None`` via ``__getattr__``'s fallback.
        """
        result = {
            'early_stop_on': self.early_stop_on,
            'checkpoint_on': self.checkpoint_on
        }

        return result

    def get_batch_log_metrics(self) -> dict:
        """Return the metrics to send to the logger at the end of the batch step."""
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['logger'] and options['on_step']:
                result[k] = self[k]
        return result

    def get_epoch_log_metrics(self) -> dict:
        """Return the metrics to send to the logger at the end of the epoch."""
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['logger'] and options['on_epoch']:
                result[k] = self[k]
        return result

    def get_epoch_pbar_metrics(self):
        """Return the metrics to show in the progress bar at the end of the epoch."""
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['prog_bar'] and options['on_epoch']:
                result[k] = self[k]
        return result

    def get_batch_pbar_metrics(self):
        """Return the metrics to show in the progress bar at the end of the batch step."""
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['prog_bar'] and options['on_step']:
                result[k] = self[k]
        return result

    def detach(self):
        """Detach, in place, every tensor stored in this result."""
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__setitem__(k, v.detach())

    def __repr__(self):
        # hide the 'meta' bookkeeping from the printed form
        self_copy = self.copy()

        if 'meta' in self_copy:
            del self_copy['meta']

        return str(self_copy)

    def __str__(self):
        # NOTE(review): unlike __repr__, this raises KeyError if 'meta'
        # is absent — confirm 'meta' is always present by construction
        copy = self.copy()
        del copy['meta']

        return str(copy)

    def __copy__(self):
        # shallow-per-item copy preserving the concrete subclass
        newone = type(self)()
        for k, v in self.items():
            newone[k] = copy(v)
        return newone

    @classmethod
    def gather(cls, outputs):
        """Merge a list of step results (e.g. from DP replicas) into one.

        Per-key values are collected and stacked; the first output's 'meta'
        is taken as the meta for the merged result.
        """
        meta = outputs[0]['meta']
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)
        result['meta'] = meta
        return result

    @classmethod
    def reduce_on_epoch_end(cls, outputs):
        """Gather per-batch results and reduce each ``on_epoch`` key.

        Keys logged with ``on_epoch=True`` are reduced across the epoch with
        their stored ``reduce_fx``; other keys are left stacked.
        """
        meta = outputs[0]['meta']
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        for k, option in meta.items():
            if k == '_internal':
                continue

            if option['on_epoch']:
                fx = option['reduce_fx']
                result[k] = fx(result[k])

        result['meta'] = meta
        return result

    @property
    def should_reduce_on_epoch_end(self) -> bool:
        """True when any logged key requested ``on_epoch`` reduction."""
        return self['meta']['_internal']['_reduce_on_epoch']
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
    """Merge a sequence of dicts, collecting each key's values into a list.

    Fixes over the previous version:
      * ``result=None`` now creates a fresh dict instead of crashing on
        ``k not in None``.
      * The 'meta' bookkeeping key is skipped instead of being ``del``-eted
        from the caller's dicts (inputs are no longer mutated).
      * Nested dicts are gathered into a fresh accumulator; previously the
        parent accumulator itself was appended, creating a self-referential
        result.

    Args:
        outputs: Dicts to merge (e.g. per-batch step results).
        result: Accumulator mapping; created when None.

    Returns:
        The accumulator with ``result[k] == [out_1[k], out_2[k], ...]``.
    """
    if result is None:
        result = {}

    for out in outputs:
        for k, v in out.items():
            if k == 'meta':
                # internal bookkeeping; excluded from gathering
                continue
            if isinstance(v, dict):
                # recurse into a fresh accumulator so nested keys do not
                # pollute (or self-reference) the parent accumulator
                v = recursive_gather([v], {})
            result.setdefault(k, []).append(v)

    return result
def recursive_stack(result: MutableMapping):
    """Stack, in place, every list of tensors in *result* into one tensor.

    Nested mappings are processed recursively; non-tensor lists and scalar
    entries are written back unchanged.
    """
    for key, value in result.items():
        if isinstance(value, dict):
            recursive_stack(value)

        is_tensor_list = (
            isinstance(value, list)
            and len(value) > 0
            and isinstance(value[0], torch.Tensor)
        )
        if is_tensor_list:
            value = torch.stack(value)

        result[key] = value
class TrainResult(Result):
    """Structured result for ``training_step``.

    Identical to :class:`Result` except that :meth:`log` defaults to
    per-step logging (``on_step=True, on_epoch=False``), which is the
    usual granularity during training.
    """

    def __init__(
        self,
        minimize: Optional[Tensor] = None,
        early_stop_on: Tensor = None,
        checkpoint_on: Union[Tensor, bool] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """See :class:`Result` for the meaning of each argument."""
        super().__init__(
            minimize,
            early_stop_on,
            checkpoint_on,
            hiddens,
        )

    def log(
        self,
        name,
        value,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = True,
        on_epoch: bool = False,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """Log a metric; defaults flipped to batch-level (on_step) logging."""
        super().log(
            name,
            value,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
        )
class EvalResult(Result):
    """Structured result for validation/test steps.

    Identical to :class:`Result` except there is no ``minimize`` (nothing is
    optimized during evaluation) and :meth:`log` keeps the epoch-level
    defaults (``on_step=False, on_epoch=True``).
    """

    def __init__(
        self,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Optional[Tensor] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """See :class:`Result` for the meaning of each argument."""
        super().__init__(
            None,
            early_stop_on,
            checkpoint_on,
            hiddens,
        )

    def log(
        self,
        name,
        value,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """Log a metric; defaults to epoch-level reduction/logging."""
        super().log(
            name,
            value,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
        )
# if __name__ == '__main__':
# import torch
# result = TrainResult()
# result.hiddens = torch.tensor(1)
# result.log('some', 123)
# print(result)
# result.minimize = torch.tensor(1)

View File

@ -6,6 +6,7 @@ import torch
from torch.cuda._utils import _get_device_index from torch.cuda._utils import _get_device_index
from torch.nn import DataParallel from torch.nn import DataParallel
from torch.nn.parallel import DistributedDataParallel from torch.nn.parallel import DistributedDataParallel
from pytorch_lightning.core.step_result import Result
def _find_tensors(obj): # pragma: no-cover def _find_tensors(obj): # pragma: no-cover
@ -63,7 +64,34 @@ class LightningDataParallel(DataParallel):
replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs) outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)
if isinstance(outputs[0], Result):
outputs = self.__gather_structured_result(outputs)
else:
outputs = self.gather(outputs, self.output_device)
return outputs
def __gather_structured_result(self, outputs):
prototype_output = outputs[0]
original_class = prototype_output.__class__
outputs = [dict(x) for x in outputs]
# remove all the meta info
meta = outputs[0]['meta']
for i, output in enumerate(outputs):
del output['meta']
outputs = self.gather(outputs, self.output_device)
# pass minimize to constructor for TrainResult
if 'minimize' in outputs:
result = original_class(outputs['minimize'])
else:
result = original_class()
result.update(outputs)
result['meta'] = meta
return result
def parallel_apply(self, replicas, inputs, kwargs): def parallel_apply(self, replicas, inputs, kwargs):
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
@ -160,6 +188,8 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
if not isinstance(input, (list, tuple)): if not isinstance(input, (list, tuple)):
input = (input,) input = (input,)
module = module.to(device)
# --------------- # ---------------
# CHANGE # CHANGE
if module.training: if module.training:

View File

@ -51,6 +51,36 @@ class TrainerCallbackHookMixin(ABC):
for callback in self.callbacks: for callback in self.callbacks:
callback.on_sanity_check_end(self, self.get_model()) callback.on_sanity_check_end(self, self.get_model())
def on_train_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_train_epoch_start(self, self.get_model())
def on_train_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_train_epoch_end(self, self.get_model())
def on_validation_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_validation_epoch_start(self, self.get_model())
def on_validation_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_validation_epoch_end(self, self.get_model())
def on_test_epoch_start(self):
"""Called when the epoch begins."""
for callback in self.callbacks:
callback.on_test_epoch_start(self, self.get_model())
def on_test_epoch_end(self):
"""Called when the epoch ends."""
for callback in self.callbacks:
callback.on_test_epoch_end(self, self.get_model())
def on_epoch_start(self): def on_epoch_start(self):
"""Called when the epoch begins.""" """Called when the epoch begins."""
for callback in self.callbacks: for callback in self.callbacks:

View File

@ -1,3 +1,4 @@
import os
from abc import ABC from abc import ABC
from typing import Union, Iterable from typing import Union, Iterable
@ -73,6 +74,8 @@ class TrainerLoggingMixin(ABC):
self.logger.agg_and_log_metrics(scalar_metrics, step=step) self.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.logger.save() self.logger.save()
self.dev_debugger.track_logged_metrics_history(scalar_metrics)
def add_progress_bar_metrics(self, metrics): def add_progress_bar_metrics(self, metrics):
for k, v in metrics.items(): for k, v in metrics.items():
if isinstance(v, torch.Tensor): if isinstance(v, torch.Tensor):
@ -80,6 +83,8 @@ class TrainerLoggingMixin(ABC):
self.progress_bar_metrics[k] = v self.progress_bar_metrics[k] = v
self.dev_debugger.track_pbar_metrics_history(metrics)
def metrics_to_scalars(self, metrics): def metrics_to_scalars(self, metrics):
new_metrics = {} new_metrics = {}
for k, v in metrics.items(): for k, v in metrics.items():

View File

@ -76,3 +76,17 @@ class TensorRunningAccum(object):
return getattr(self.memory, how)() return getattr(self.memory, how)()
else: else:
return getattr(self.memory[:self.current_idx], how)() return getattr(self.memory[:self.current_idx], how)()
class Accumulator(object):
    """Running-mean accumulator for numbers or tensors."""

    def __init__(self):
        # running count and sum of everything accumulated so far
        self.num_values = 0
        self.total = 0

    def accumulate(self, x):
        """Fold *x* into the running total without building an autograd graph."""
        with torch.no_grad():
            self.total = self.total + x
            self.num_values = self.num_values + 1

    def mean(self):
        """Return total / count (raises ZeroDivisionError when empty)."""
        return self.total / self.num_values

View File

@ -33,6 +33,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
from pytorch_lightning.utilities.debugging import InternalDebugger
import warnings import warnings
# warnings to ignore in trainer # warnings to ignore in trainer
@ -616,6 +617,9 @@ class Trainer(
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
# tracks internal state for debugging
self.dev_debugger = InternalDebugger(self)
# Callback system # Callback system
self.on_init_end() self.on_init_end()

View File

@ -143,7 +143,7 @@ in your model.
trainer = Trainer(terminate_on_nan=True) trainer = Trainer(terminate_on_nan=True)
""" """
import os
import subprocess import subprocess
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable from typing import Callable
@ -153,17 +153,19 @@ import numpy as np
import torch import torch
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
import torch.distributed as torch_distrib import torch.distributed as torch_distrib
from copy import copy
from pytorch_lightning import _logger as log from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import AttributeDict from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.memory import recursive_detach from pytorch_lightning.utilities.memory import recursive_detach
from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result
try: try:
from apex import amp from apex import amp
@ -251,6 +253,8 @@ class TrainerTrainLoopMixin(ABC):
on_epoch_end: Callable on_epoch_end: Callable
on_validation_end: Callable on_validation_end: Callable
on_keyboard_interrupt: Callable on_keyboard_interrupt: Callable
on_train_epoch_start: Callable
on_train_epoch_end: Callable
@abstractmethod @abstractmethod
def get_model(self) -> LightningModule: def get_model(self) -> LightningModule:
@ -420,6 +424,15 @@ class TrainerTrainLoopMixin(ABC):
if self.is_function_implemented('on_epoch_start'): if self.is_function_implemented('on_epoch_start'):
model.on_epoch_start() model.on_epoch_start()
# Epoch start events
with self.profiler.profile('on_train_epoch_start'):
# callbacks
self.on_train_epoch_start()
# model hooks
if self.is_function_implemented('on_train_epoch_start'):
model.on_train_epoch_start()
def run_training_epoch(self): def run_training_epoch(self):
# get model # get model
@ -435,6 +448,10 @@ class TrainerTrainLoopMixin(ABC):
epoch_output = [] epoch_output = []
should_check_val = False should_check_val = False
# structured result accumulators for callbacks
early_stopping_accumulator = Accumulator()
checkpoint_accumulator = Accumulator()
# run epoch # run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable( for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch" enumerate(_with_is_last(train_dataloader)), "get_train_batch"
@ -453,7 +470,15 @@ class TrainerTrainLoopMixin(ABC):
# only track outputs when user implements training_epoch_end # only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory # otherwise we will build up unnecessary memory
if self.is_overridden('training_epoch_end', model=self.get_model()): step_out = batch_output.training_step_output_for_epoch_end
should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end
if isinstance(step_out, dict) and 'early_stop_on' in step_out:
early_stopping_accumulator.accumulate(step_out['early_stop_on'])
if isinstance(step_out, dict) and 'checkpoint_on' in step_out:
checkpoint_accumulator.accumulate(step_out['checkpoint_on'])
if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result:
epoch_output.append(batch_output.training_step_output_for_epoch_end) epoch_output.append(batch_output.training_step_output_for_epoch_end)
# update LR schedulers # update LR schedulers
@ -496,7 +521,7 @@ class TrainerTrainLoopMixin(ABC):
self.sync_horovod() self.sync_horovod()
# process epoch outputs # process epoch outputs
self.run_training_epoch_end(epoch_output) self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator)
# checkpoint callback # checkpoint callback
self.check_checkpoint_callback(should_check_val) self.check_checkpoint_callback(should_check_val)
@ -525,23 +550,74 @@ class TrainerTrainLoopMixin(ABC):
if self.is_function_implemented('on_epoch_end'): if self.is_function_implemented('on_epoch_end'):
model.on_epoch_end() model.on_epoch_end()
def run_training_epoch_end(self, epoch_output): with self.profiler.profile('on_train_epoch_end'):
# callbacks
self.on_train_epoch_end()
# model hooks
if self.is_function_implemented('on_train_epoch_end'):
model.on_train_epoch_end()
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator):
model = self.get_model() model = self.get_model()
is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result)
epoch_log_metrics = {}
epoch_callback_metrics = {}
epoch_progress_bar_metrics = {}
# -----------------------
# Calculate epoch callback values if given
# -----------------------
if checkpoint_accumulator.num_values > 0:
epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
if early_stopping_accumulator.num_values > 0:
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
# --------------------------
# EPOCH END STEP IF DEFINED
# --------------------------
if self.is_overridden('training_epoch_end', model=model): if self.is_overridden('training_epoch_end', model=model):
self.global_step += 1 self.global_step += 1
# remove the protected keys so the user doesn't have to deal with them
if is_result_obj:
epoch_output = epoch_output[0].__class__.gather(epoch_output)
# run training_epoch_end
epoch_output = model.training_epoch_end(epoch_output) epoch_output = model.training_epoch_end(epoch_output)
_processed_outputs = self.process_output(epoch_output)
log_epoch_metrics = _processed_outputs[2]
callback_epoch_metrics = _processed_outputs[3]
# add the metrics to the loggers if isinstance(epoch_output, Result):
self.log_metrics(log_epoch_metrics, {}) epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
else:
_processed_outputs = self.process_output(epoch_output)
epoch_progress_bar_metrics = _processed_outputs[1]
epoch_log_metrics = _processed_outputs[2]
epoch_callback_metrics = _processed_outputs[3]
# add metrics to callbacks # --------------------------
self.callback_metrics.update(callback_epoch_metrics) # Structured Result (auto epoch end)
# --------------------------
elif is_result_obj:
epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output)
epoch_output.minimize = epoch_output.minimize.mean()
epoch_log_metrics = epoch_output.epoch_log_metrics
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
# add metrics to progress_bar # --------------------------
self.add_progress_bar_metrics(_processed_outputs[1]) # track results
# --------------------------
# add the metrics to the loggers
if epoch_log_metrics and len(epoch_log_metrics) > 0:
self.log_metrics(epoch_log_metrics, {})
# add metrics to callbacks
self.callback_metrics.update(epoch_callback_metrics)
# add metrics to progress_bar
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
def sync_horovod(self): def sync_horovod(self):
if self.use_horovod: if self.use_horovod:
@ -558,7 +634,10 @@ class TrainerTrainLoopMixin(ABC):
should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
if should_log_metrics or self.fast_dev_run: if should_log_metrics or self.fast_dev_run:
# logs user requested information to logger # logs user requested information to logger
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic) metrics = batch_output.batch_log_metrics
grad_norm_dic = batch_output.grad_norm_dic
if len(metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(metrics, grad_norm_dic)
def save_loggers_in_training_loop(self, batch_idx): def save_loggers_in_training_loop(self, batch_idx):
# when loggers should save to disk # when loggers should save to disk
@ -588,6 +667,8 @@ class TrainerTrainLoopMixin(ABC):
# track metrics to log # track metrics to log
batch_log_metrics = [] batch_log_metrics = []
using_results_obj = False
if batch is None: if batch is None:
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic) return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
@ -622,7 +703,7 @@ class TrainerTrainLoopMixin(ABC):
param.requires_grad = True param.requires_grad = True
# ------------------- # -------------------
# calculate loss # calculate loss (train step + train step end)
# ------------------- # -------------------
opt_closure_result = self.optimizer_closure( opt_closure_result = self.optimizer_closure(
split_batch, split_batch,
@ -631,14 +712,26 @@ class TrainerTrainLoopMixin(ABC):
optimizer, optimizer,
self.hiddens self.hiddens
) )
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
# ------------------------------ # ------------------------------
# POST forward bookkeeping # POST forward bookkeeping
# ------------------------------ # ------------------------------
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics) batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics)
self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end) # add metrics to loggers
if using_results_obj:
metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
else:
metrics_to_log = opt_closure_result.training_step_output.log_metrics
batch_log_metrics.append(metrics_to_log)
# add metrics to progress bar
if using_results_obj:
metrics_for_pbar = opt_closure_result.training_step_output.batch_pbar_metrics
else:
metrics_for_pbar = opt_closure_result.training_step_output.pbar_on_batch_end
self.add_progress_bar_metrics(metrics_for_pbar)
# track hiddens # track hiddens
self.hiddens = opt_closure_result.hiddens self.hiddens = opt_closure_result.hiddens
@ -677,7 +770,8 @@ class TrainerTrainLoopMixin(ABC):
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()} batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
# track all metrics for callbacks # track all metrics for callbacks
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()}) if not using_results_obj:
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
result = AttributeDict( result = AttributeDict(
signal=0, signal=0,
@ -764,7 +858,7 @@ class TrainerTrainLoopMixin(ABC):
wrap the forward step in a closure so second order methods work wrap the forward step in a closure so second order methods work
""" """
# --------------------------- # ---------------------------
# FORWARD # FORWARD (TRAINING STEP + TRAIN STEP END)
# --------------------------- # ---------------------------
with self.profiler.profile('model_forward'): with self.profiler.profile('model_forward'):
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu: if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
@ -780,26 +874,38 @@ class TrainerTrainLoopMixin(ABC):
# ---------------------------- # ----------------------------
# format and reduce outputs accordingly # format and reduce outputs accordingly
training_step_output_for_epoch_end = training_step_output training_step_output_for_epoch_end = training_step_output
training_step_output = self.process_output(training_step_output, train=True) is_result_obj = isinstance(training_step_output, Result)
# TODO: temporary part of structured results PR # don't allow EvalResult in the training_step
training_step_output = AttributeDict( if isinstance(training_step_output, EvalResult):
batch_loss=training_step_output[0], raise MisconfigurationException('training_step cannot return EvalResult, '
pbar_on_batch_end=training_step_output[1], 'use a dict or TrainResult instead')
log_metrics=training_step_output[2],
callback_metrics=training_step_output[3], # handle regular dicts
hiddens=training_step_output[4], if not is_result_obj:
) training_step_output = self.process_output(training_step_output, train=True)
training_step_output = AttributeDict(
batch_loss=training_step_output[0],
pbar_on_batch_end=training_step_output[1],
log_metrics=training_step_output[2],
callback_metrics=training_step_output[3],
hiddens=training_step_output[4],
)
# if the user decides to finally reduce things in epoch_end, save raw output without graphs # if the user decides to finally reduce things in epoch_end, save raw output without graphs
if isinstance(training_step_output_for_epoch_end, torch.Tensor): if isinstance(training_step_output_for_epoch_end, torch.Tensor):
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach() training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
elif is_result_obj:
training_step_output_for_epoch_end = copy(training_step_output)
training_step_output_for_epoch_end.detach()
else: else:
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end) training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
# accumulate loss # accumulate loss
# (if accumulate_grad_batches = 1 no effect) # (if accumulate_grad_batches = 1 no effect)
closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
closure_loss = closure_loss / self.accumulate_grad_batches
# the loss will get scaled for amp. avoid any modifications to it # the loss will get scaled for amp. avoid any modifications to it
untouched_loss = closure_loss.detach().clone() untouched_loss = closure_loss.detach().clone()
@ -829,7 +935,11 @@ class TrainerTrainLoopMixin(ABC):
# once backward has been applied, release graph # once backward has been applied, release graph
closure_loss = closure_loss.detach() closure_loss = closure_loss.detach()
training_step_output.batch_loss = training_step_output.batch_loss.detach()
if is_result_obj:
training_step_output.detach()
else:
training_step_output.batch_loss = training_step_output.batch_loss.detach()
if self.use_horovod: if self.use_horovod:
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid # Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
@ -841,6 +951,9 @@ class TrainerTrainLoopMixin(ABC):
with self.profiler.profile('on_after_backward'): with self.profiler.profile('on_after_backward'):
model_ref.on_after_backward() model_ref.on_after_backward()
# when in dev debugging track the losses
self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
result = AttributeDict( result = AttributeDict(
loss=untouched_loss, loss=untouched_loss,
training_step_output=training_step_output, training_step_output=training_step_output,
@ -963,6 +1076,7 @@ class TrainerTrainLoopMixin(ABC):
if self.is_overridden('training_step_end'): if self.is_overridden('training_step_end'):
model_ref = self.get_model() model_ref = self.get_model()
with self.profiler.profile('training_step_end'): with self.profiler.profile('training_step_end'):
# TODO: modify when using result obj
output = model_ref.training_step_end(output) output = model_ref.training_step_end(output)
# allow any mode to define training_end # allow any mode to define training_end

View File

@ -0,0 +1,54 @@
import os
class InternalDebugger(object):
    """Records trainer internals (losses, metrics, callback activity) for tests.

    Every ``track_*`` method is a no-op unless the ``PL_DEV_DEBUG`` environment
    variable is set at construction time, so normal runs pay no cost.
    """

    def __init__(self, trainer):
        # opt-in: only record anything when PL_DEV_DEBUG is present in the env
        self.enabled = 'PL_DEV_DEBUG' in os.environ
        self.trainer = trainer
        self.logged_metrics = []
        self.pbar_added_metrics = []
        self.saved_losses = []
        self.early_stopping_history = []
        self.checkpoint_callback_history = []

    def track_logged_metrics_history(self, scalar_metrics):
        """Stamp the metrics dict with the current global step and store it."""
        if not self.enabled:
            return
        scalar_metrics['global_step'] = self.trainer.global_step
        self.logged_metrics.append(scalar_metrics)

    def track_train_loss_history(self, batch_idx, loss):
        """Record the detached training loss for one batch."""
        if not self.enabled:
            return
        self.saved_losses.append({
            'batch_idx': batch_idx,
            'epoch': self.trainer.current_epoch,
            'loss': loss.detach(),
        })

    def track_pbar_metrics_history(self, metrics):
        """Stamp the progress-bar metrics dict with the epoch and store it."""
        if not self.enabled:
            return
        metrics['debug_epoch'] = self.trainer.current_epoch
        self.pbar_added_metrics.append(metrics)

    def track_early_stopping_history(self, current):
        """Snapshot the early-stopping callback state for the current step."""
        if not self.enabled:
            return
        es = self.trainer.early_stop_callback
        self.early_stopping_history.append({
            'epoch': self.trainer.current_epoch,
            'global_step': self.trainer.global_step,
            'rank': self.trainer.global_rank,
            'current': current,
            'best': es.best_score,
            'patience': es.wait_count,
        })

    def track_checkpointing_history(self, filepath):
        """Snapshot the checkpoint callback state for a save at ``filepath``."""
        if not self.enabled:
            return
        cb = self.trainer.checkpoint_callback
        self.checkpoint_callback_history.append({
            'epoch': self.trainer.current_epoch,
            'global_step': self.trainer.global_step,
            'monitor': cb.monitor,
            'rank': self.trainer.global_rank,
            'filepath': filepath,
        })

View File

@ -1,5 +1,6 @@
import inspect import inspect
from argparse import Namespace from argparse import Namespace
from typing import Dict
def str_to_bool(val): def str_to_bool(val):
@ -93,7 +94,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
return path_args return path_args
class AttributeDict(dict): class AttributeDict(Dict):
"""Extended dictionary accesisable with dot notation. """Extended dictionary accesisable with dot notation.
>>> ad = AttributeDict({'key1': 1, 'key2': 'abc'}) >>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})

View File

@ -2,6 +2,7 @@ import numpy as np
import torch import torch
from torch import nn from torch import nn
from torch.utils.data import Dataset, DataLoader from torch.utils.data import Dataset, DataLoader
from pytorch_lightning import TrainResult
from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.core.lightning import LightningModule
@ -19,6 +20,8 @@ class DeterministicModel(LightningModule):
self.validation_step_end_called = False self.validation_step_end_called = False
self.validation_epoch_end_called = False self.validation_epoch_end_called = False
self.assert_backward = True
self.l1 = nn.Linear(2, 3, bias=False) self.l1 = nn.Linear(2, 3, bias=False)
if weights is None: if weights is None:
weights = torch.tensor([ weights = torch.tensor([
@ -33,13 +36,15 @@ class DeterministicModel(LightningModule):
def step(self, batch, batch_idx): def step(self, batch, batch_idx):
x = batch x = batch
y_hat = self(x) bs = x.size(0)
y_hat = self.l1(x)
print(x.device, self.device, self.l1.weight.device)
test_hat = y_hat.cpu().detach() test_hat = y_hat.cpu().detach()
assert torch.all(test_hat[:, 0] == 15.0) assert torch.all(test_hat[:, 0] == 15.0)
assert torch.all(test_hat[:, 1] == 42.0) assert torch.all(test_hat[:, 1] == 42.0)
out = y_hat.sum() out = y_hat.sum()
assert out == (42.0 * 3) + (15.0 * 3) assert out == (42.0 * bs) + (15.0 * bs)
return out return out
@ -97,6 +102,105 @@ class DeterministicModel(LightningModule):
prototype_loss = outputs[0] prototype_loss = outputs[0]
return prototype_loss return prototype_loss
def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx):
    """
    Early stop and checkpoint only on these values
    """
    loss = self.step(batch, batch_idx)
    result = TrainResult(minimize=loss)
    # by default TrainResult wires up checkpointing but not early stopping
    assert 'early_step_on' not in result
    assert 'checkpoint_on' in result
    return result
def training_step_no_callbacks_result_obj(self, batch, batch_idx):
    """
    Early stop and checkpoint only on these values
    """
    loss = self.step(batch, batch_idx)
    # checkpoint_on=False strips the default checkpoint signal as well
    result = TrainResult(minimize=loss, checkpoint_on=False)
    assert 'early_step_on' not in result
    assert 'checkpoint_on' not in result
    return result
def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx):
    """
    Early stop and checkpoint only on these values
    """
    base_loss = self.step(batch, batch_idx)

    # per-epoch offsets change the loss value, so the exact-value check
    # performed in backward() must be disabled for this variant
    self.assert_backward = False
    offsets = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20]
    loss = base_loss + offsets[self.current_epoch]

    return TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss)
def training_step_result_log_step_only(self, batch, batch_idx):
    """Train step logging step-level metrics only (logger+pbar, logger-only, pbar-only)."""
    loss = self.step(batch, batch_idx)
    result = TrainResult(minimize=loss)

    # step-only metrics, routed per the prog_bar/logger flags
    result.log(f'step_log_and_pbar_acc1_b{batch_idx}', torch.tensor(11).type_as(loss), prog_bar=True)
    result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(loss))
    result.log(f'step_pbar_acc3_b{batch_idx}', torch.tensor(13).type_as(loss), logger=False, prog_bar=True)

    self.training_step_called = True
    return result
def training_step_result_log_epoch_only(self, batch, batch_idx):
    """Train step logging epoch-level metrics only (on_step=False, on_epoch=True)."""
    loss = self.step(batch, batch_idx)
    result = TrainResult(minimize=loss)
    epoch = self.current_epoch

    result.log(f'epoch_log_and_pbar_acc1_e{epoch}', torch.tensor(14).type_as(loss),
               on_epoch=True, prog_bar=True, on_step=False)
    result.log(f'epoch_log_acc2_e{epoch}', torch.tensor(15).type_as(loss),
               on_epoch=True, on_step=False)
    result.log(f'epoch_pbar_acc3_e{epoch}', torch.tensor(16).type_as(loss),
               on_epoch=True, logger=False, prog_bar=True, on_step=False)

    self.training_step_called = True
    return result
def training_step_result_log_epoch_and_step(self, batch, batch_idx):
    """Train step logging metrics on both step and epoch; values depend on batch and epoch."""
    loss = self.step(batch, batch_idx)
    result = TrainResult(minimize=loss)

    scale = self.current_epoch + 1
    val_1 = (5 + batch_idx) * scale
    val_2 = (6 + batch_idx) * scale
    val_3 = (7 + batch_idx) * scale
    result.log('step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(loss),
               on_epoch=True, prog_bar=True)
    result.log('step_epoch_log_acc2', torch.tensor(val_2).type_as(loss),
               on_epoch=True)
    result.log('step_epoch_pbar_acc3', torch.tensor(val_3).type_as(loss),
               on_epoch=True, logger=False, prog_bar=True)

    self.training_step_called = True
    return result
def training_epoch_end_return_for_log_epoch_and_step(self, result):
    """
    There should be an array of scalars without graphs that are all 171 (4 of them)
    """
    self.training_epoch_end_called = True

    if not (self.use_dp or self.use_ddp2):
        # only saw 4 batches
        assert isinstance(result, TrainResult)

    result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
    result.step_epoch_log_acc2 = result.step_epoch_log_acc2.prod()
    result.step_epoch_pbar_acc3 = result.step_epoch_pbar_acc3.prod()
    result.log('epoch_end_log_acc', torch.tensor(1212).type_as(result.step_epoch_log_acc2),
               logger=True, on_epoch=True)
    result.log('epoch_end_pbar_acc', torch.tensor(1213).type_as(result.step_epoch_log_acc2),
               logger=False, prog_bar=True, on_epoch=True)
    result.log('epoch_end_log_pbar_acc', torch.tensor(1214).type_as(result.step_epoch_log_acc2),
               logger=True, prog_bar=True, on_epoch=True)
    return result
# -------------------------- # --------------------------
# dictionary returns # dictionary returns
# -------------------------- # --------------------------
@ -231,10 +335,12 @@ class DeterministicModel(LightningModule):
return torch.optim.Adam(self.parameters(), lr=0) return torch.optim.Adam(self.parameters(), lr=0)
def backward(self, trainer, loss, optimizer, optimizer_idx): def backward(self, trainer, loss, optimizer, optimizer_idx):
if self.trainer.precision == 16: if self.assert_backward:
assert loss > 171 * 1000 if self.trainer.precision == 16:
else: assert loss > 171 * 1000
assert loss == 171.0 else:
assert loss == 171.0
loss.backward() loss.backward()

View File

@ -63,6 +63,9 @@ class EvalModelTemplate(
self.hidden_dim = hidden_dim self.hidden_dim = hidden_dim
self.b1 = b1 self.b1 = b1
self.b2 = b2 self.b2 = b2
self.training_step_called = False
self.training_step_end_called = False
self.training_epoch_end_called = False
# if you specify an example input, the summary will show input/output for each layer # if you specify an example input, the summary will show input/output for each layer
# TODO: to be fixed in #1773 # TODO: to be fixed in #1773

View File

@ -1,6 +1,7 @@
import math import math
from abc import ABC from abc import ABC
from collections import OrderedDict from collections import OrderedDict
from pytorch_lightning import TrainResult
import torch import torch
@ -38,3 +39,35 @@ class TrainingStepVariations(ABC):
else: else:
output /= 0 output /= 0
return output return output
def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
    """
    Full loop flow train step (result obj + dp)
    """
    x, y = batch
    flat_x = x.view(x.size(0), -1)
    y_hat = self(flat_x.to(self.device))
    loss_val = y_hat.sum()

    result = TrainResult(minimize=loss_val)
    result.log('train_step_metric', loss_val + 1)
    self.training_step_called = True
    return result
def training_step_end_full_loop_result_obj_dp(self, result):
    """
    Full loop flow train step (result obj + dp)
    """
    # reduce the per-GPU values gathered by DP down to their means
    for attr in ('minimize', 'checkpoint_on', 'train_step_metric'):
        setattr(result, attr, getattr(result, attr).mean())
    result.log('train_step_end_metric', 1)
    self.training_step_end_called = True
    return result
def training_epoch_end_full_loop_result_obj_dp(self, result):
    """
    Full loop flow train step (result obj + dp)
    """
    self.training_epoch_end_called = True
    result.log('train_epoch_end_metric', 1, on_epoch=True)
    return result

View File

@ -35,7 +35,6 @@ class ValidationEpochEndVariations(ABC):
Args: Args:
outputs: list of individual outputs of each validation step outputs: list of individual outputs of each validation step
""" """
# if returned a scalar from validation_step, outputs is a list of tensor scalars # if returned a scalar from validation_step, outputs is a list of tensor scalars
# we return just the average in this case (if we want) # we return just the average in this case (if we want)
def _mean(res, key): def _mean(res, key):

View File

@ -78,11 +78,11 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
self.count = 0 self.count = 0
self.expected_count = expected_count self.expected_count = expected_count
def _save_model(self, filepath): def _save_model(self, filepath, trainer, pl_module):
# make sure we don't save twice # make sure we don't save twice
assert not os.path.isfile(filepath) assert not os.path.isfile(filepath)
self.count += 1 self.count += 1
super()._save_model(filepath) super()._save_model(filepath, trainer, pl_module)
def on_train_end(self, trainer, pl_module): def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module) super().on_train_end(trainer, pl_module)

View File

@ -1,43 +1,12 @@
import numpy as np import numpy as np
import pytest import pytest
import os
from pytorch_lightning import Trainer from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only
from tests.base import EvalModelTemplate from tests.base import EvalModelTemplate
from tests.base.develop_utils import reset_seed from tests.base.develop_utils import reset_seed
class OnlyMetricsListLogger(LightningLoggerBase):
    """Minimal in-memory logger that records only the metric dicts passed to it."""

    def __init__(self):
        super().__init__()
        # every dict handed to log_metrics, in call order
        self.metrics = []

    @rank_zero_only
    def log_metrics(self, metrics, step):
        # NOTE(review): `step` is ignored here; only the metrics dict is stored
        self.metrics.append(metrics)

    @property
    def experiment(self):
        # dummy experiment object required by the LightningLoggerBase interface
        return 'test'

    @rank_zero_only
    def log_hyperparams(self, params):
        # hyperparameters are not tracked by this test logger
        pass

    @rank_zero_only
    def finalize(self, status):
        pass

    @property
    def name(self):
        return 'name'

    @property
    def version(self):
        return '1'
class ModelWithManualGradTracker(EvalModelTemplate): class ModelWithManualGradTracker(EvalModelTemplate):
def __init__(self, norm_type, *args, **kwargs): def __init__(self, norm_type, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -75,28 +44,29 @@ class ModelWithManualGradTracker(EvalModelTemplate):
@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf']) @pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3): def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
# rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above os.environ['PL_DEV_DEBUG'] = '1'
# rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
reset_seed() reset_seed()
# use a custom grad tracking module and a list logger # use a custom grad tracking module and a list logger
model = ModelWithManualGradTracker(norm_type) model = ModelWithManualGradTracker(norm_type)
logger = OnlyMetricsListLogger()
trainer = Trainer( trainer = Trainer(
default_root_dir=tmpdir, default_root_dir=tmpdir,
max_epochs=3, max_epochs=3,
logger=logger,
track_grad_norm=norm_type, track_grad_norm=norm_type,
row_log_interval=1, # request grad_norms every batch row_log_interval=1, # request grad_norms every batch
) )
result = trainer.fit(model) result = trainer.fit(model)
assert result == 1, "Training failed" assert result == 1, "Training failed"
assert len(logger.metrics) == len(model.stored_grad_norms) logged_metrics = trainer.dev_debugger.logged_metrics
assert len(logged_metrics) == len(model.stored_grad_norms)
# compare the logged metrics against tracked norms on `.backward` # compare the logged metrics against tracked norms on `.backward`
for mod, log in zip(model.stored_grad_norms, logger.metrics): for mod, log in zip(model.stored_grad_norms, logged_metrics):
common = mod.keys() & log.keys() common = mod.keys() & log.keys()
log, mod = [log[k] for k in common], [mod[k] for k in common] log, mod = [log[k] for k in common], [mod[k] for k in common]

View File

@ -589,7 +589,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
with pytest.raises(FileNotFoundError): with pytest.raises(FileNotFoundError):
trainer.test(ckpt_path='random.ckpt') trainer.test(ckpt_path='random.ckpt')
else: else:
ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0].absolute()) ckpt_path = str(list((Path(tmpdir) / f'lightning_logs/version_{trainer.logger.version}/checkpoints').iterdir())[0].absolute())
trainer.test(ckpt_path=ckpt_path) trainer.test(ckpt_path=ckpt_path)
assert trainer.tested_ckpt_path == ckpt_path assert trainer.tested_ckpt_path == ckpt_path

View File

@ -0,0 +1,518 @@
"""
Tests to ensure that the training loop works with a dict
"""
import os
import torch
from pytorch_lightning import Trainer
from tests.base.deterministic_model import DeterministicModel
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
from tests.base import EvalModelTemplate
import pytest
# test with train_step_end
# add logging + row interval tests
def test_training_step_result_log_step_only(tmpdir):
    """
    Tests that only training_step can be used with TrainResult
    Makes sure that things are routed to pbar, loggers and loss accordingly
    Makes sure pbar and logs happen on step only when requested
    """
    # enable internal debugging actions
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_result_log_step_only
    model.training_step_end = None
    model.training_epoch_end = None
    model.val_dataloader = None

    batches = 3
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=batches,
        limit_val_batches=batches,
        row_log_interval=1,
        max_epochs=1,
        weights_summary=None,
    )
    trainer.fit(model)

    # only training_step should have run
    assert model.training_step_called
    assert not model.training_step_end_called
    assert not model.training_epoch_end_called

    # one logged-metrics entry per batch, routed according to the log() flags
    logged = trainer.dev_debugger.logged_metrics
    assert len(logged) == batches
    for batch_idx, logged_metrics in enumerate(logged):
        assert logged_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
        assert logged_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
        assert f'step_pbar_acc3_b{batch_idx}' not in logged_metrics
        assert len(logged_metrics) == 4

    # callback metrics come from the default checkpoint signal
    assert trainer.callback_metrics['checkpoint_on'] == 171

    # pbar metrics are correct and logger-only metrics did not leak in
    for batch_idx in range(batches):
        assert trainer.progress_bar_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11
        assert trainer.progress_bar_metrics[f'step_pbar_acc3_b{batch_idx}'] == 13
        assert f'step_log_acc2_b{batch_idx}' not in trainer.progress_bar_metrics

    # run a single batch manually and inspect the training outputs
    batch_idx, batch = 0, next(iter(model.train_dataloader()))
    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
    assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, TrainResult)
    assert 'minimize' in train_step_out
    assert f'step_log_and_pbar_acc1_b{batch_idx}' in train_step_out
    assert f'step_log_acc2_b{batch_idx}' in train_step_out

    # the optimizer closure should return the raw (unscaled) loss
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_training_step_result_log_epoch_only(tmpdir):
    """
    Train with a ``training_step`` that returns a ``TrainResult`` configured
    to log/pbar on epoch only. Verifies the metrics reach the logger, the
    progress bar and the callback metrics exactly once per epoch, and that
    log-only / pbar-only metrics do not leak into each other.
    """
    # enable internal debugging actions
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_result_log_epoch_only
    model.training_step_end = None
    model.training_epoch_end = None
    model.val_dataloader = None

    epochs = 3
    batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=batches,
        limit_val_batches=batches,
        row_log_interval=1,
        max_epochs=epochs,
        weights_summary=None,
    )
    trainer.fit(model)

    # only training_step should have run
    assert model.training_step_called
    assert not model.training_step_end_called
    assert not model.training_epoch_end_called

    # exactly one logging call per epoch was requested
    assert len(trainer.dev_debugger.logged_metrics) == epochs
    epoch_metrics = trainer.dev_debugger.logged_metrics
    assert len(epoch_metrics) == epochs
    for idx, metrics in enumerate(epoch_metrics):
        assert metrics[f'epoch_log_and_pbar_acc1_e{idx}'] == 14.0
        assert metrics[f'epoch_log_acc2_e{idx}'] == 15.0
        assert f'epoch_pbar_acc3_e{idx}' not in metrics
        assert len(metrics) == 4

    # callback metrics routed correctly
    assert trainer.callback_metrics['checkpoint_on'] == 171

    # pbar metrics are correct and the log-only metric did not leak into the pbar
    for epoch_idx in range(epochs):
        assert trainer.progress_bar_metrics[f'epoch_log_and_pbar_acc1_e{epoch_idx}'] == 14
        assert trainer.progress_bar_metrics[f'epoch_pbar_acc3_e{epoch_idx}'] == 16
        assert f'epoch_log_acc2_e{epoch_idx}' not in trainer.progress_bar_metrics

    # run a single training batch manually and inspect its outputs
    batch_idx, batch = next(enumerate(model.train_dataloader()))
    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 0

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, TrainResult)
    assert 'minimize' in train_step_out
    assert f'epoch_log_and_pbar_acc1_e{trainer.current_epoch}' in train_step_out
    assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out

    # the optimizer closure should return the expected loss
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_training_step_result_log_step_and_epoch(tmpdir):
    """
    Tests that only training_step can be used with TrainResult when logging
    on both step and epoch.
    Makes sure that things are routed to pbar, loggers and loss accordingly.
    Makes sure pbar and logs happen on step AND on epoch when requested.
    """
    # enable internal debugging actions
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_result_log_epoch_and_step
    model.training_step_end = None
    model.training_epoch_end = None
    model.val_dataloader = None

    epochs = 3
    batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=batches,
        limit_val_batches=batches,
        row_log_interval=1,
        max_epochs=epochs,
        weights_summary=None,
    )
    trainer.fit(model)

    # make sure correct steps were called
    assert model.training_step_called
    assert not model.training_step_end_called
    assert not model.training_epoch_end_called

    # one log call per batch plus one aggregated log call per epoch end
    assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs

    # walk the log history in chunks of (batches + 1): the per-batch entries
    # of one epoch followed by that epoch's end-of-epoch aggregate entry
    epoch_metrics = trainer.dev_debugger.logged_metrics
    epoch_idx = -1
    for i_start in range(0, len(epoch_metrics), batches + 1):
        epoch_idx += 1
        epoch_outputs = epoch_metrics[i_start: i_start + batches + 1]
        mean_vals = {
            'step_epoch_log_and_pbar_acc1': [],
            'step_epoch_log_acc2': []
        }

        # make sure each batch logged the expected value; the expected values
        # come from DeterministicModel's step outputs, which scale with both
        # the batch index and the (1-based) epoch index
        for batch_idx in range(len(epoch_outputs) - 1):
            logged_metrics = epoch_outputs[batch_idx]

            expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
            expected_val_2 = (6 + batch_idx) * (epoch_idx + 1)
            mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
            mean_vals['step_epoch_log_acc2'].append(torch.tensor(expected_val_2).float())
            assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
            assert logged_metrics['step_epoch_log_acc2'] == expected_val_2
            assert 'step_epoch_pbar_acc3' not in logged_metrics
            assert len(logged_metrics) == 4

        # make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
        epoch_end_metrics = epoch_outputs[-1]
        eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
        eval_2 = torch.stack(mean_vals['step_epoch_log_acc2']).mean()
        assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
        assert epoch_end_metrics['step_epoch_log_acc2'] == eval_2
        assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
        # NOTE(review): this re-checks the last batch's dict, not epoch_end_metrics —
        # looks like a copy-paste; presumably `len(epoch_end_metrics)` was intended. Confirm.
        assert len(logged_metrics) == 4

    # make sure we are using the correct metrics for callbacks
    assert trainer.callback_metrics['checkpoint_on'] == 171

    # -------------------------------
    # VERIFY PBAR METRICS
    # -------------------------------
    # make sure pbar metrics are correct and log-only metrics did not leak
    all_pbar_metrics = trainer.dev_debugger.pbar_added_metrics
    assert len(all_pbar_metrics) == (epochs * batches) + epochs

    # same chunked walk as above, but over the progress-bar history
    epoch_idx = -1
    for i_start in range(0, len(all_pbar_metrics), batches + 1):
        epoch_idx += 1
        epoch_outputs = all_pbar_metrics[i_start: i_start + batches + 1]
        mean_vals = {
            'step_epoch_log_and_pbar_acc1': [],
            'step_epoch_pbar_acc3': []
        }

        # make sure each batch logged the expected value
        for batch_idx in range(len(epoch_outputs) - 1):
            logged_metrics = epoch_outputs[batch_idx]

            expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
            expected_val_2 = (7 + batch_idx) * (epoch_idx + 1)
            mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
            mean_vals['step_epoch_pbar_acc3'].append(torch.tensor(expected_val_2).float())
            assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
            assert logged_metrics['step_epoch_pbar_acc3'] == expected_val_2
            assert 'step_epoch_log_acc2' not in logged_metrics
            assert len(logged_metrics) == 3

        # make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
        epoch_end_metrics = epoch_outputs[-1]
        eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
        eval_2 = torch.stack(mean_vals['step_epoch_pbar_acc3']).mean()
        assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
        assert epoch_end_metrics['step_epoch_pbar_acc3'] == eval_2
        assert 'step_epoch_log_acc2' not in epoch_end_metrics
        # NOTE(review): same as above — re-asserts the last batch dict; probably
        # `len(epoch_end_metrics)` was intended. Confirm.
        assert len(logged_metrics) == 3

    # -----------------------------------------
    # make sure training outputs what is expected
    # -----------------------------------------
    # grab only the first batch from the dataloader
    for batch_idx, batch in enumerate(model.train_dataloader()):
        break

    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 2

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, TrainResult)
    assert 'minimize' in train_step_out
    assert 'step_epoch_log_and_pbar_acc1' in train_step_out
    assert 'step_epoch_log_acc2' in train_step_out

    # make sure the optimizer closure returns the correct things
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_training_step_epoch_end_result(tmpdir):
    """
    Makes sure training_step and training_epoch_end can both return Results
    (without a training_step_end) and that metrics get routed correctly to
    the logger, the progress bar and the callback metrics.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_result_log_epoch_and_step
    model.training_epoch_end = model.training_epoch_end_return_for_log_epoch_and_step
    model.val_dataloader = None

    batches = 3
    epochs = 1
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        row_log_interval=1,
        limit_train_batches=batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # training_step and training_epoch_end ran; training_step_end did not
    assert model.training_step_called
    assert not model.training_step_end_called
    assert model.training_epoch_end_called

    # one log entry per batch plus one per epoch end
    log_history = trainer.dev_debugger.logged_metrics
    assert len(log_history) == (epochs * batches) + epochs

    # the final entry carries the epoch-end metrics (pbar-only one excluded)
    final_entry = log_history[-1]
    assert final_entry['step_epoch_log_and_pbar_acc1'] == 210.0
    assert final_entry['step_epoch_log_acc2'] == 336.0
    assert final_entry['epoch_end_log_acc'] == 1212.0
    assert final_entry['epoch_end_log_pbar_acc'] == 1214.0
    assert 'epoch_end_pbar_acc' not in final_entry

    # progress bar got its own metrics, without the log-only ones leaking in
    pbar_history = trainer.dev_debugger.pbar_added_metrics
    assert len(pbar_history) == (epochs * batches) + epochs
    assert trainer.progress_bar_metrics['step_epoch_log_and_pbar_acc1'] == 210.0
    assert trainer.progress_bar_metrics['step_epoch_pbar_acc3'] == 504.0
    assert trainer.progress_bar_metrics['epoch_end_pbar_acc'] == 1213.0
    assert trainer.progress_bar_metrics['epoch_end_log_pbar_acc'] == 1214.0
    assert 'epoch_end_log_acc' not in trainer.progress_bar_metrics
    assert 'log_acc2' not in trainer.progress_bar_metrics

    # callback metrics unchanged
    assert trainer.callback_metrics['checkpoint_on'] == 171

    # -----------------------------------------
    # make sure training outputs what is expected
    # -----------------------------------------
    batch_idx, batch = next(enumerate(model.train_dataloader()))
    out = trainer.run_training_batch(batch, batch_idx)
    assert out.signal == 0
    assert len(out.batch_log_metrics) == 2

    train_step_out = out.training_step_output_for_epoch_end
    assert isinstance(train_step_out, TrainResult)
    assert 'minimize' in train_step_out
    assert 'step_epoch_log_and_pbar_acc1' in train_step_out
    assert 'step_epoch_log_acc2' in train_step_out

    # the optimizer closure should return the expected loss
    opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
    assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
def test_no_auto_callbacks_with_train_loop_only(tmpdir):
    """
    With a train-loop-only model, the default checkpoint callback monitors
    'checkpoint_on' and no early-stop callback is created automatically;
    when early stopping is requested explicitly it monitors 'val_loss'.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_no_default_callbacks_for_train_loop
    model.training_epoch_end = None
    model.val_dataloader = None

    batches = 3
    epochs = 3
    common_args = dict(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        row_log_interval=1,
        limit_train_batches=batches,
        weights_summary=None,
    )

    trainer = Trainer(**common_args)
    trainer.fit(model)

    # one loss per batch per epoch was recorded
    assert len(trainer.dev_debugger.saved_losses) == batches * epochs

    assert trainer.checkpoint_callback.monitor == 'checkpoint_on'
    assert trainer.early_stop_callback is None

    # now request early stopping explicitly
    trainer = Trainer(early_stop_callback=True, **common_args)
    trainer.fit(model)
    assert trainer.early_stop_callback.monitor == 'val_loss'
def test_no_callbacks_with_train_loop_only(tmpdir):
    """
    When the result object sets no callback metrics, a train-loop-only run
    should record no checkpointing or early-stopping activity at all.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_no_callbacks_result_obj
    model.training_epoch_end = None
    model.val_dataloader = None

    batches = 3
    epochs = 3
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        row_log_interval=1,
        limit_train_batches=batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # one loss per batch per epoch was still recorded
    assert len(trainer.dev_debugger.saved_losses) == batches * epochs

    # ... but nothing callback-related happened
    assert trainer.early_stop_callback is None
    assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
    assert len(trainer.dev_debugger.early_stopping_history) == 0
def test_use_callbacks_with_train_loop_only(tmpdir):
    """
    Exercises early stopping and checkpointing driven purely by the train
    loop's callback metrics (no validation loop).
    """
    from collections import Counter

    os.environ['PL_DEV_DEBUG'] = '1'

    model = DeterministicModel()
    model.training_step = model.training_step_result_log_epoch_and_step_for_callbacks
    model.training_epoch_end = None
    model.val_dataloader = None

    batches = 3
    epochs = 300
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=epochs,
        early_stop_callback=True,
        row_log_interval=1,
        limit_train_batches=batches,
        weights_summary=None,
    )
    trainer.fit(model)

    num_expected_epochs = 10

    # ----------------------------------
    # VERIFY EARLY STOPPING BEHAVIOR
    # ----------------------------------
    # with a train loop only, early stopping is evaluated on every epoch
    early_stop_vals = trainer.dev_debugger.early_stopping_history
    assert len(early_stop_vals) == num_expected_epochs
    assert min(entry['best'] for entry in early_stop_vals) == 171 + 9

    # each batch index (0..2) should have produced one loss per epoch run
    saved_losses = trainer.dev_debugger.saved_losses
    idx_counts = Counter(entry['batch_idx'] for entry in saved_losses)
    for idx, count in idx_counts.items():
        assert count == num_expected_epochs
        assert idx in [0, 1, 2]

    # ----------------------------------
    # VERIFY CHECKPOINTING BEHAVIOR
    # ----------------------------------
    ckpt_vals = trainer.dev_debugger.checkpoint_callback_history
    assert len(ckpt_vals) == 5, '5 ckpts should have been saved'
    for ckpt_val, expected_epoch in zip(ckpt_vals, [0, 1, 2, 3, 6]):
        assert ckpt_val['epoch'] == expected_epoch
        assert ckpt_val['monitor'] == 'checkpoint_on'
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_full_train_loop_with_results_obj_dp(tmpdir):
    """
    Runs the full train loop (step, step_end and epoch_end all returning
    result objects) under DP on two GPUs and verifies that every phase
    contributed its metric to the logged history.
    """
    os.environ['PL_DEV_DEBUG'] = '1'

    batches = 10
    epochs = 3

    model = EvalModelTemplate()
    model.validation_step = None
    model.test_step = None
    model.training_step = model.training_step_full_loop_result_obj_dp
    model.training_step_end = model.training_step_end_full_loop_result_obj_dp
    model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp
    model.val_dataloader = None
    model.test_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # collect every metric key that was ever logged
    seen_keys = {key for metrics in trainer.dev_debugger.logged_metrics for key in metrics}

    assert 'train_step_metric' in seen_keys
    assert 'train_step_end_metric' in seen_keys
    assert 'train_epoch_end_metric' in seen_keys