Structured results (train loop only. val loop separate PR) (PR 2/5) (#2615)
* r * r * r * patched optimizer closure with sr * patched optimizer closure with sr * patched optimizer closure with sr * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added train step structured result * added autoreduce for train step * added auto reduce on train * added auto reduce on train * added auto reduce on train * added auto reduce on train * added auto reduce on train * added auto reduce on train * added hooks * added hooks * added hooks * added hooks * added hooks * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured 
results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train 
epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * cache * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * 
finished tests for structured results on train epoch * finished tests for structured results on train epoch * Update pytorch_lightning/callbacks/early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py * Update pytorch_lightning/core/step_result.py * finished tests for structured results on train epoch * finished tests for structured results on train epoch * Apply suggestions from code review Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> * simple * finished tests for structured results on train epoch * simple * simple * revert * finished tests for structured results on train epoch * finished tests for structured results on train epoch * Update tests/base/deterministic_model.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * finished tests for structured results on train epoch * docstring typos * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on 
train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * finished tests for structured results on train epoch * Update pytorch_lightning/core/step_result.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * Update pytorch_lightning/overrides/data_parallel.py Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
This commit is contained in:
parent
816d8cff06
commit
6d10ac2ac8
|
@ -82,9 +82,9 @@ jobs:
|
||||||
uses: actions/cache@v1
|
uses: actions/cache@v1
|
||||||
with:
|
with:
|
||||||
path: ${{ steps.pip-cache.outputs.dir }}
|
path: ${{ steps.pip-cache.outputs.dir }}
|
||||||
key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
|
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/extra.txt') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
|
${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.requires }}-pip-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|
|
@ -55,6 +55,7 @@ else:
|
||||||
from pytorch_lightning.trainer import Trainer
|
from pytorch_lightning.trainer import Trainer
|
||||||
from pytorch_lightning.utilities.seed import seed_everything
|
from pytorch_lightning.utilities.seed import seed_everything
|
||||||
from pytorch_lightning import metrics
|
from pytorch_lightning import metrics
|
||||||
|
from pytorch_lightning.core.step_result import TrainResult, EvalResult
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'Trainer',
|
'Trainer',
|
||||||
|
@ -62,7 +63,9 @@ else:
|
||||||
'Callback',
|
'Callback',
|
||||||
'data_loader',
|
'data_loader',
|
||||||
'seed_everything',
|
'seed_everything',
|
||||||
'metrics'
|
'metrics',
|
||||||
|
'EvalResult',
|
||||||
|
'TrainResult'
|
||||||
]
|
]
|
||||||
|
|
||||||
# necessary for regular bolts imports. Skip exception since bolts is not always installed
|
# necessary for regular bolts imports. Skip exception since bolts is not always installed
|
||||||
|
|
|
@ -46,6 +46,30 @@ class Callback(abc.ABC):
|
||||||
"""Called when the validation sanity check ends."""
|
"""Called when the validation sanity check ends."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def on_train_epoch_start(self, trainer, pl_module):
|
||||||
|
"""Called when the train epoch begins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
|
"""Called when the train epoch ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_epoch_start(self, trainer, pl_module):
|
||||||
|
"""Called when the val epoch begins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self, trainer, pl_module):
|
||||||
|
"""Called when the val epoch ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_test_epoch_start(self, trainer, pl_module):
|
||||||
|
"""Called when the test epoch begins."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_test_epoch_end(self, trainer, pl_module):
|
||||||
|
"""Called when the test epoch ends."""
|
||||||
|
pass
|
||||||
|
|
||||||
def on_epoch_start(self, trainer, pl_module):
|
def on_epoch_start(self, trainer, pl_module):
|
||||||
"""Called when the epoch begins."""
|
"""Called when the epoch begins."""
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -7,6 +7,7 @@ Monitor a validation metric and stop training when it stops improving.
|
||||||
"""
|
"""
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
@ -140,12 +141,33 @@ class EarlyStopping(Callback):
|
||||||
def on_validation_end(self, trainer, pl_module):
|
def on_validation_end(self, trainer, pl_module):
|
||||||
self._run_early_stopping_check(trainer, pl_module)
|
self._run_early_stopping_check(trainer, pl_module)
|
||||||
|
|
||||||
|
def on_train_epoch_end(self, trainer, pl_module):
|
||||||
|
# early stopping can also work in the train loop when there is no val loop and when using structured results
|
||||||
|
should_check_early_stop = False
|
||||||
|
train_es_key = 'early_stop_on'
|
||||||
|
if trainer.callback_metrics.get(train_es_key, None) is not None:
|
||||||
|
self.monitor = train_es_key
|
||||||
|
should_check_early_stop = True
|
||||||
|
|
||||||
|
val_es_key = 'val_early_stop_on'
|
||||||
|
if trainer.callback_metrics.get(val_es_key, None) is not None:
|
||||||
|
self.monitor = val_es_key
|
||||||
|
should_check_early_stop = True
|
||||||
|
|
||||||
|
if should_check_early_stop:
|
||||||
|
self._run_early_stopping_check(trainer, pl_module)
|
||||||
|
|
||||||
def _run_early_stopping_check(self, trainer, pl_module):
|
def _run_early_stopping_check(self, trainer, pl_module):
|
||||||
logs = trainer.callback_metrics
|
logs = trainer.callback_metrics
|
||||||
|
|
||||||
if not self._validate_condition_metric(logs):
|
if not self._validate_condition_metric(logs):
|
||||||
return # short circuit if metric not present
|
return # short circuit if metric not present
|
||||||
|
|
||||||
current = logs.get(self.monitor)
|
current = logs.get(self.monitor)
|
||||||
|
|
||||||
|
# when in dev debugging
|
||||||
|
trainer.dev_debugger.track_early_stopping_history(current)
|
||||||
|
|
||||||
if not isinstance(current, torch.Tensor):
|
if not isinstance(current, torch.Tensor):
|
||||||
current = torch.tensor(current, device=pl_module.device)
|
current = torch.tensor(current, device=pl_module.device)
|
||||||
|
|
||||||
|
|
|
@ -159,7 +159,11 @@ class ModelCheckpoint(Callback):
|
||||||
if os.path.isfile(filepath):
|
if os.path.isfile(filepath):
|
||||||
os.remove(filepath)
|
os.remove(filepath)
|
||||||
|
|
||||||
def _save_model(self, filepath):
|
def _save_model(self, filepath, trainer, pl_module):
|
||||||
|
|
||||||
|
# in debugging, track when we save checkpoints
|
||||||
|
trainer.dev_debugger.track_checkpointing_history(filepath)
|
||||||
|
|
||||||
# make paths
|
# make paths
|
||||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||||
|
|
||||||
|
@ -270,6 +274,11 @@ class ModelCheckpoint(Callback):
|
||||||
|
|
||||||
metrics = trainer.callback_metrics
|
metrics = trainer.callback_metrics
|
||||||
epoch = trainer.current_epoch
|
epoch = trainer.current_epoch
|
||||||
|
|
||||||
|
# support structured results
|
||||||
|
if metrics.get('checkpoint_on') is not None:
|
||||||
|
self.monitor = 'checkpoint_on'
|
||||||
|
|
||||||
if self.save_top_k == 0:
|
if self.save_top_k == 0:
|
||||||
# no models are saved
|
# no models are saved
|
||||||
return
|
return
|
||||||
|
@ -281,7 +290,7 @@ class ModelCheckpoint(Callback):
|
||||||
|
|
||||||
if self.save_last:
|
if self.save_last:
|
||||||
filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
|
filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
|
||||||
self._save_model(filepath)
|
self._save_model(filepath, trainer, pl_module)
|
||||||
|
|
||||||
filepath = self.format_checkpoint_name(epoch, metrics)
|
filepath = self.format_checkpoint_name(epoch, metrics)
|
||||||
version_cnt = 0
|
version_cnt = 0
|
||||||
|
@ -306,7 +315,7 @@ class ModelCheckpoint(Callback):
|
||||||
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
|
f'Can save best model only with {self.monitor} available, skipping.', RuntimeWarning
|
||||||
)
|
)
|
||||||
elif self.check_monitor_top_k(current):
|
elif self.check_monitor_top_k(current):
|
||||||
self._do_check_save(filepath, current, epoch)
|
self._do_check_save(filepath, current, epoch, trainer, pl_module)
|
||||||
elif self.verbose > 0:
|
elif self.verbose > 0:
|
||||||
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
|
log.info(f'\nEpoch {epoch:05d}: {self.monitor} was not in top {self.save_top_k}')
|
||||||
|
|
||||||
|
@ -315,9 +324,9 @@ class ModelCheckpoint(Callback):
|
||||||
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
|
log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
|
||||||
|
|
||||||
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
|
assert trainer.global_rank == 0, 'tried to make a checkpoint from non global_rank=0'
|
||||||
self._save_model(filepath)
|
self._save_model(filepath, trainer, pl_module)
|
||||||
|
|
||||||
def _do_check_save(self, filepath, current, epoch):
|
def _do_check_save(self, filepath, current, epoch, trainer, pl_module):
|
||||||
# remove kth
|
# remove kth
|
||||||
|
|
||||||
del_list = []
|
del_list = []
|
||||||
|
@ -343,7 +352,7 @@ class ModelCheckpoint(Callback):
|
||||||
f'\nEpoch {epoch:05d}: {self.monitor} reached'
|
f'\nEpoch {epoch:05d}: {self.monitor} reached'
|
||||||
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
|
f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
|
||||||
f' {filepath} as top {self.save_top_k}')
|
f' {filepath} as top {self.save_top_k}')
|
||||||
self._save_model(filepath)
|
self._save_model(filepath, trainer, pl_module)
|
||||||
|
|
||||||
for cur_path in del_list:
|
for cur_path in del_list:
|
||||||
if cur_path != filepath:
|
if cur_path != filepath:
|
||||||
|
|
|
@ -115,6 +115,42 @@ class ModelHooks(Module):
|
||||||
"""
|
"""
|
||||||
# do something when the epoch ends
|
# do something when the epoch ends
|
||||||
|
|
||||||
|
def on_train_epoch_start(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the training loop at the very beginning of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch starts
|
||||||
|
|
||||||
|
def on_train_epoch_end(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the training loop at the very end of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch ends
|
||||||
|
|
||||||
|
def on_validation_epoch_start(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the validation loop at the very beginning of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch starts
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the validation loop at the very end of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch ends
|
||||||
|
|
||||||
|
def on_test_epoch_start(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the test loop at the very beginning of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch starts
|
||||||
|
|
||||||
|
def on_test_epoch_end(self) -> None:
|
||||||
|
"""
|
||||||
|
Called in the test loop at the very end of the epoch.
|
||||||
|
"""
|
||||||
|
# do something when the epoch ends
|
||||||
|
|
||||||
def on_pre_performance_check(self) -> None:
|
def on_pre_performance_check(self) -> None:
|
||||||
"""
|
"""
|
||||||
Called at the very beginning of the validation loop.
|
Called at the very beginning of the validation loop.
|
||||||
|
|
|
@ -0,0 +1,336 @@
|
||||||
|
from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any
|
||||||
|
from torch import Tensor
|
||||||
|
import torch
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
|
|
||||||
|
class Result(Dict):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
minimize: Optional[Tensor] = None,
|
||||||
|
early_stop_on: Optional[Tensor] = None,
|
||||||
|
checkpoint_on: Union[Tensor, bool, None] = None,
|
||||||
|
hiddens: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
if early_stop_on is not None:
|
||||||
|
self.early_stop_on = early_stop_on
|
||||||
|
if checkpoint_on is not None and checkpoint_on:
|
||||||
|
self.checkpoint_on = checkpoint_on
|
||||||
|
if hiddens is not None:
|
||||||
|
self.hiddens = hiddens
|
||||||
|
if minimize is not None:
|
||||||
|
err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
|
||||||
|
self._assert_grad_tensor_metric('minimize', minimize, err)
|
||||||
|
self.minimize = minimize
|
||||||
|
|
||||||
|
if minimize is not None and checkpoint_on is None:
|
||||||
|
self.checkpoint_on = minimize.detach()
|
||||||
|
|
||||||
|
self['meta'] = {
|
||||||
|
'_internal': {
|
||||||
|
'_reduce_on_epoch': False
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
def __getattr__(self, key: str) -> Any:
|
||||||
|
try:
|
||||||
|
if key == 'callback_metrics':
|
||||||
|
return self.get_callback_metrics()
|
||||||
|
elif key == 'batch_log_metrics':
|
||||||
|
return self.get_batch_log_metrics()
|
||||||
|
elif key == 'batch_pbar_metrics':
|
||||||
|
return self.get_batch_pbar_metrics()
|
||||||
|
elif key == 'epoch_log_metrics':
|
||||||
|
return self.get_epoch_log_metrics()
|
||||||
|
elif key == 'epoch_pbar_metrics':
|
||||||
|
return self.get_epoch_pbar_metrics()
|
||||||
|
else:
|
||||||
|
return self[key]
|
||||||
|
except KeyError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def __setattr__(self, key: str, val: Union[Tensor, Any]):
|
||||||
|
# ensure reserve keys are tensors and detached
|
||||||
|
if key in {'hiddens', 'checkpoint_on', 'early_stop_on'}:
|
||||||
|
self._assert_tensor_metric(key, val)
|
||||||
|
if val is not None and isinstance(val, torch.Tensor):
|
||||||
|
val = val.detach()
|
||||||
|
|
||||||
|
# ensure anything else that is a tensor is detached
|
||||||
|
elif isinstance(val, torch.Tensor) and key != 'minimize':
|
||||||
|
val = val.detach()
|
||||||
|
|
||||||
|
self[key] = val
|
||||||
|
|
||||||
|
def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
|
||||||
|
if potential_metric is not None and not isinstance(potential_metric, bool):
|
||||||
|
assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'
|
||||||
|
|
||||||
|
def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
|
||||||
|
if x is not None:
|
||||||
|
assert isinstance(x, Tensor), f'{name} must be a torch.Tensor'
|
||||||
|
m = f'{name} must have a computational graph.'
|
||||||
|
|
||||||
|
if additional_err:
|
||||||
|
m += f' {additional_err}'
|
||||||
|
assert x.grad_fn is not None, m
|
||||||
|
|
||||||
|
def log(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
value: Any,
|
||||||
|
prog_bar: bool = False,
|
||||||
|
logger: bool = True,
|
||||||
|
on_step: bool = False,
|
||||||
|
on_epoch: bool = True,
|
||||||
|
reduce_fx: Callable = torch.mean,
|
||||||
|
enable_graph: bool = False,
|
||||||
|
):
|
||||||
|
# no metrics should be logged with graphs
|
||||||
|
if not enable_graph and isinstance(value, torch.Tensor):
|
||||||
|
value = value.detach()
|
||||||
|
|
||||||
|
if 'meta' not in self:
|
||||||
|
self.__setitem__('meta', {})
|
||||||
|
|
||||||
|
self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)
|
||||||
|
|
||||||
|
# set the value
|
||||||
|
self.__setitem__(name, value)
|
||||||
|
|
||||||
|
def __set_meta(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
value: Any,
|
||||||
|
prog_bar: bool,
|
||||||
|
logger: bool,
|
||||||
|
on_step: bool,
|
||||||
|
on_epoch: bool,
|
||||||
|
reduce_fx: Callable,
|
||||||
|
):
|
||||||
|
# set the meta for the item
|
||||||
|
meta_value = value
|
||||||
|
meta = dict(
|
||||||
|
prog_bar=prog_bar,
|
||||||
|
logger=logger,
|
||||||
|
on_step=on_step,
|
||||||
|
on_epoch=on_epoch,
|
||||||
|
reduce_fx=reduce_fx,
|
||||||
|
value=meta_value
|
||||||
|
)
|
||||||
|
self['meta'][name] = meta
|
||||||
|
|
||||||
|
# track whether any input requires reduction on epoch end
|
||||||
|
_internal = self['meta']['_internal']
|
||||||
|
_internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
|
||||||
|
|
||||||
|
def get_callback_metrics(self) -> dict:
|
||||||
|
result = {
|
||||||
|
'early_stop_on': self.early_stop_on,
|
||||||
|
'checkpoint_on': self.checkpoint_on
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_batch_log_metrics(self) -> dict:
|
||||||
|
"""
|
||||||
|
Gets the metrics to log at the end of the batch step
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
meta = self['meta']
|
||||||
|
for k, options in meta.items():
|
||||||
|
if k == '_internal':
|
||||||
|
continue
|
||||||
|
if options['logger'] and options['on_step']:
|
||||||
|
result[k] = self[k]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_epoch_log_metrics(self) -> dict:
|
||||||
|
"""
|
||||||
|
Gets the metrics to log at the end of the batch step
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
meta = self['meta']
|
||||||
|
for k, options in meta.items():
|
||||||
|
if k == '_internal':
|
||||||
|
continue
|
||||||
|
if options['logger'] and options['on_epoch']:
|
||||||
|
result[k] = self[k]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_epoch_pbar_metrics(self):
|
||||||
|
"""
|
||||||
|
Gets the metrics to log at the end of the batch step
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
meta = self['meta']
|
||||||
|
for k, options in meta.items():
|
||||||
|
if k == '_internal':
|
||||||
|
continue
|
||||||
|
if options['prog_bar'] and options['on_epoch']:
|
||||||
|
result[k] = self[k]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def get_batch_pbar_metrics(self):
|
||||||
|
"""
|
||||||
|
Gets the metrics to log at the end of the batch step
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
|
||||||
|
meta = self['meta']
|
||||||
|
for k, options in meta.items():
|
||||||
|
if k == '_internal':
|
||||||
|
continue
|
||||||
|
if options['prog_bar'] and options['on_step']:
|
||||||
|
result[k] = self[k]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def detach(self):
|
||||||
|
for k, v in self.items():
|
||||||
|
if isinstance(v, torch.Tensor):
|
||||||
|
self.__setitem__(k, v.detach())
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
self_copy = self.copy()
|
||||||
|
|
||||||
|
if 'meta' in self_copy:
|
||||||
|
del self_copy['meta']
|
||||||
|
|
||||||
|
return str(self_copy)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
copy = self.copy()
|
||||||
|
del copy['meta']
|
||||||
|
|
||||||
|
return str(copy)
|
||||||
|
|
||||||
|
def __copy__(self):
|
||||||
|
newone = type(self)()
|
||||||
|
for k, v in self.items():
|
||||||
|
newone[k] = copy(v)
|
||||||
|
return newone
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def gather(cls, outputs):
|
||||||
|
meta = outputs[0]['meta']
|
||||||
|
result = cls()
|
||||||
|
result = recursive_gather(outputs, result)
|
||||||
|
recursive_stack(result)
|
||||||
|
result['meta'] = meta
|
||||||
|
return result
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def reduce_on_epoch_end(cls, outputs):
|
||||||
|
meta = outputs[0]['meta']
|
||||||
|
result = cls()
|
||||||
|
result = recursive_gather(outputs, result)
|
||||||
|
recursive_stack(result)
|
||||||
|
|
||||||
|
for k, option in meta.items():
|
||||||
|
if k == '_internal':
|
||||||
|
continue
|
||||||
|
|
||||||
|
if option['on_epoch']:
|
||||||
|
fx = option['reduce_fx']
|
||||||
|
result[k] = fx(result[k])
|
||||||
|
|
||||||
|
result['meta'] = meta
|
||||||
|
return result
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_reduce_on_epoch_end(self) -> bool:
|
||||||
|
return self['meta']['_internal']['_reduce_on_epoch']
|
||||||
|
|
||||||
|
|
||||||
|
def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
|
||||||
|
for out in outputs:
|
||||||
|
if 'meta' in out:
|
||||||
|
del out['meta']
|
||||||
|
|
||||||
|
for k, v in out.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
v = recursive_gather([v], result)
|
||||||
|
|
||||||
|
if k not in result:
|
||||||
|
result[k] = []
|
||||||
|
|
||||||
|
result[k].append(v)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def recursive_stack(result: MutableMapping):
|
||||||
|
for k, v in result.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
recursive_stack(v)
|
||||||
|
|
||||||
|
if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
|
||||||
|
v = torch.stack(v)
|
||||||
|
result[k] = v
|
||||||
|
|
||||||
|
|
||||||
|
class TrainResult(Result):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
minimize: Optional[Tensor] = None,
|
||||||
|
early_stop_on: Tensor = None,
|
||||||
|
checkpoint_on: Union[Tensor, bool] = None,
|
||||||
|
hiddens: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(minimize, early_stop_on, checkpoint_on, hiddens)
|
||||||
|
|
||||||
|
def log(
|
||||||
|
self,
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
prog_bar: bool = False,
|
||||||
|
logger: bool = True,
|
||||||
|
on_step: bool = True,
|
||||||
|
on_epoch: bool = False,
|
||||||
|
reduce_fx: Callable = torch.mean,
|
||||||
|
enable_graph: bool = False,
|
||||||
|
):
|
||||||
|
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
|
||||||
|
|
||||||
|
|
||||||
|
class EvalResult(Result):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
early_stop_on: Optional[Tensor] = None,
|
||||||
|
checkpoint_on: Optional[Tensor] = None,
|
||||||
|
hiddens: Optional[Tensor] = None,
|
||||||
|
):
|
||||||
|
|
||||||
|
super().__init__(None, early_stop_on, checkpoint_on, hiddens)
|
||||||
|
|
||||||
|
def log(
|
||||||
|
self,
|
||||||
|
name,
|
||||||
|
value,
|
||||||
|
prog_bar: bool = False,
|
||||||
|
logger: bool = True,
|
||||||
|
on_step: bool = False,
|
||||||
|
on_epoch: bool = True,
|
||||||
|
reduce_fx: Callable = torch.mean,
|
||||||
|
enable_graph: bool = False,
|
||||||
|
):
|
||||||
|
super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
|
||||||
|
|
||||||
|
|
||||||
|
# if __name__ == '__main__':
|
||||||
|
# import torch
|
||||||
|
# result = TrainResult()
|
||||||
|
# result.hiddens = torch.tensor(1)
|
||||||
|
# result.log('some', 123)
|
||||||
|
# print(result)
|
||||||
|
# result.minimize = torch.tensor(1)
|
|
@ -6,6 +6,7 @@ import torch
|
||||||
from torch.cuda._utils import _get_device_index
|
from torch.cuda._utils import _get_device_index
|
||||||
from torch.nn import DataParallel
|
from torch.nn import DataParallel
|
||||||
from torch.nn.parallel import DistributedDataParallel
|
from torch.nn.parallel import DistributedDataParallel
|
||||||
|
from pytorch_lightning.core.step_result import Result
|
||||||
|
|
||||||
|
|
||||||
def _find_tensors(obj): # pragma: no-cover
|
def _find_tensors(obj): # pragma: no-cover
|
||||||
|
@ -63,7 +64,34 @@ class LightningDataParallel(DataParallel):
|
||||||
|
|
||||||
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
|
||||||
outputs = self.parallel_apply(replicas, inputs, kwargs)
|
outputs = self.parallel_apply(replicas, inputs, kwargs)
|
||||||
return self.gather(outputs, self.output_device)
|
|
||||||
|
if isinstance(outputs[0], Result):
|
||||||
|
outputs = self.__gather_structured_result(outputs)
|
||||||
|
else:
|
||||||
|
outputs = self.gather(outputs, self.output_device)
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
def __gather_structured_result(self, outputs):
|
||||||
|
prototype_output = outputs[0]
|
||||||
|
original_class = prototype_output.__class__
|
||||||
|
outputs = [dict(x) for x in outputs]
|
||||||
|
|
||||||
|
# remove all the meta info
|
||||||
|
meta = outputs[0]['meta']
|
||||||
|
for i, output in enumerate(outputs):
|
||||||
|
del output['meta']
|
||||||
|
|
||||||
|
outputs = self.gather(outputs, self.output_device)
|
||||||
|
|
||||||
|
# pass minimize to constructor for TrainResult
|
||||||
|
if 'minimize' in outputs:
|
||||||
|
result = original_class(outputs['minimize'])
|
||||||
|
else:
|
||||||
|
result = original_class()
|
||||||
|
|
||||||
|
result.update(outputs)
|
||||||
|
result['meta'] = meta
|
||||||
|
return result
|
||||||
|
|
||||||
def parallel_apply(self, replicas, inputs, kwargs):
|
def parallel_apply(self, replicas, inputs, kwargs):
|
||||||
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||||
|
@ -160,6 +188,8 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
|
||||||
if not isinstance(input, (list, tuple)):
|
if not isinstance(input, (list, tuple)):
|
||||||
input = (input,)
|
input = (input,)
|
||||||
|
|
||||||
|
module = module.to(device)
|
||||||
|
|
||||||
# ---------------
|
# ---------------
|
||||||
# CHANGE
|
# CHANGE
|
||||||
if module.training:
|
if module.training:
|
||||||
|
|
|
@ -51,6 +51,36 @@ class TrainerCallbackHookMixin(ABC):
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
callback.on_sanity_check_end(self, self.get_model())
|
callback.on_sanity_check_end(self, self.get_model())
|
||||||
|
|
||||||
|
def on_train_epoch_start(self):
|
||||||
|
"""Called when the epoch begins."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_train_epoch_start(self, self.get_model())
|
||||||
|
|
||||||
|
def on_train_epoch_end(self):
|
||||||
|
"""Called when the epoch ends."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_train_epoch_end(self, self.get_model())
|
||||||
|
|
||||||
|
def on_validation_epoch_start(self):
|
||||||
|
"""Called when the epoch begins."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_validation_epoch_start(self, self.get_model())
|
||||||
|
|
||||||
|
def on_validation_epoch_end(self):
|
||||||
|
"""Called when the epoch ends."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_validation_epoch_end(self, self.get_model())
|
||||||
|
|
||||||
|
def on_test_epoch_start(self):
|
||||||
|
"""Called when the epoch begins."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_test_epoch_start(self, self.get_model())
|
||||||
|
|
||||||
|
def on_test_epoch_end(self):
|
||||||
|
"""Called when the epoch ends."""
|
||||||
|
for callback in self.callbacks:
|
||||||
|
callback.on_test_epoch_end(self, self.get_model())
|
||||||
|
|
||||||
def on_epoch_start(self):
|
def on_epoch_start(self):
|
||||||
"""Called when the epoch begins."""
|
"""Called when the epoch begins."""
|
||||||
for callback in self.callbacks:
|
for callback in self.callbacks:
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import os
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from typing import Union, Iterable
|
from typing import Union, Iterable
|
||||||
|
|
||||||
|
@ -73,6 +74,8 @@ class TrainerLoggingMixin(ABC):
|
||||||
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
|
self.logger.agg_and_log_metrics(scalar_metrics, step=step)
|
||||||
self.logger.save()
|
self.logger.save()
|
||||||
|
|
||||||
|
self.dev_debugger.track_logged_metrics_history(scalar_metrics)
|
||||||
|
|
||||||
def add_progress_bar_metrics(self, metrics):
|
def add_progress_bar_metrics(self, metrics):
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
if isinstance(v, torch.Tensor):
|
if isinstance(v, torch.Tensor):
|
||||||
|
@ -80,6 +83,8 @@ class TrainerLoggingMixin(ABC):
|
||||||
|
|
||||||
self.progress_bar_metrics[k] = v
|
self.progress_bar_metrics[k] = v
|
||||||
|
|
||||||
|
self.dev_debugger.track_pbar_metrics_history(metrics)
|
||||||
|
|
||||||
def metrics_to_scalars(self, metrics):
|
def metrics_to_scalars(self, metrics):
|
||||||
new_metrics = {}
|
new_metrics = {}
|
||||||
for k, v in metrics.items():
|
for k, v in metrics.items():
|
||||||
|
|
|
@ -76,3 +76,17 @@ class TensorRunningAccum(object):
|
||||||
return getattr(self.memory, how)()
|
return getattr(self.memory, how)()
|
||||||
else:
|
else:
|
||||||
return getattr(self.memory[:self.current_idx], how)()
|
return getattr(self.memory[:self.current_idx], how)()
|
||||||
|
|
||||||
|
|
||||||
|
class Accumulator(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.num_values = 0
|
||||||
|
self.total = 0
|
||||||
|
|
||||||
|
def accumulate(self, x):
|
||||||
|
with torch.no_grad():
|
||||||
|
self.total += x
|
||||||
|
self.num_values += 1
|
||||||
|
|
||||||
|
def mean(self):
|
||||||
|
return self.total / self.num_values
|
||||||
|
|
|
@ -33,6 +33,7 @@ from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
|
||||||
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
|
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
|
from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info, rank_zero_only
|
||||||
|
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# warnings to ignore in trainer
|
# warnings to ignore in trainer
|
||||||
|
@ -616,6 +617,9 @@ class Trainer(
|
||||||
|
|
||||||
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
|
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
|
||||||
|
|
||||||
|
# tracks internal state for debugging
|
||||||
|
self.dev_debugger = InternalDebugger(self)
|
||||||
|
|
||||||
# Callback system
|
# Callback system
|
||||||
self.on_init_end()
|
self.on_init_end()
|
||||||
|
|
||||||
|
|
|
@ -143,7 +143,7 @@ in your model.
|
||||||
trainer = Trainer(terminate_on_nan=True)
|
trainer = Trainer(terminate_on_nan=True)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
import os
|
||||||
import subprocess
|
import subprocess
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
|
@ -153,17 +153,19 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import torch.distributed as torch_distrib
|
import torch.distributed as torch_distrib
|
||||||
|
from copy import copy
|
||||||
|
|
||||||
from pytorch_lightning import _logger as log
|
from pytorch_lightning import _logger as log
|
||||||
from pytorch_lightning.callbacks.base import Callback
|
from pytorch_lightning.callbacks.base import Callback
|
||||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
from pytorch_lightning.core.lightning import LightningModule
|
||||||
from pytorch_lightning.loggers import LightningLoggerBase
|
from pytorch_lightning.loggers import LightningLoggerBase
|
||||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
from pytorch_lightning.trainer.supporters import TensorRunningAccum, Accumulator
|
||||||
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
|
from pytorch_lightning.utilities import rank_zero_warn, NATIVE_AMP_AVALAIBLE
|
||||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||||
from pytorch_lightning.utilities.parsing import AttributeDict
|
from pytorch_lightning.utilities.parsing import AttributeDict
|
||||||
from pytorch_lightning.utilities.memory import recursive_detach
|
from pytorch_lightning.utilities.memory import recursive_detach
|
||||||
|
from pytorch_lightning.core.step_result import EvalResult, TrainResult, Result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from apex import amp
|
from apex import amp
|
||||||
|
@ -251,6 +253,8 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
on_epoch_end: Callable
|
on_epoch_end: Callable
|
||||||
on_validation_end: Callable
|
on_validation_end: Callable
|
||||||
on_keyboard_interrupt: Callable
|
on_keyboard_interrupt: Callable
|
||||||
|
on_train_epoch_start: Callable
|
||||||
|
on_train_epoch_end: Callable
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_model(self) -> LightningModule:
|
def get_model(self) -> LightningModule:
|
||||||
|
@ -420,6 +424,15 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
if self.is_function_implemented('on_epoch_start'):
|
if self.is_function_implemented('on_epoch_start'):
|
||||||
model.on_epoch_start()
|
model.on_epoch_start()
|
||||||
|
|
||||||
|
# Epoch start events
|
||||||
|
with self.profiler.profile('on_train_epoch_start'):
|
||||||
|
# callbacks
|
||||||
|
self.on_train_epoch_start()
|
||||||
|
|
||||||
|
# model hooks
|
||||||
|
if self.is_function_implemented('on_train_epoch_start'):
|
||||||
|
model.on_train_epoch_start()
|
||||||
|
|
||||||
def run_training_epoch(self):
|
def run_training_epoch(self):
|
||||||
|
|
||||||
# get model
|
# get model
|
||||||
|
@ -435,6 +448,10 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
epoch_output = []
|
epoch_output = []
|
||||||
should_check_val = False
|
should_check_val = False
|
||||||
|
|
||||||
|
# structured result accumulators for callbacks
|
||||||
|
early_stopping_accumulator = Accumulator()
|
||||||
|
checkpoint_accumulator = Accumulator()
|
||||||
|
|
||||||
# run epoch
|
# run epoch
|
||||||
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
|
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
|
||||||
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
|
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
|
||||||
|
@ -453,7 +470,15 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
|
|
||||||
# only track outputs when user implements training_epoch_end
|
# only track outputs when user implements training_epoch_end
|
||||||
# otherwise we will build up unnecessary memory
|
# otherwise we will build up unnecessary memory
|
||||||
if self.is_overridden('training_epoch_end', model=self.get_model()):
|
step_out = batch_output.training_step_output_for_epoch_end
|
||||||
|
should_auto_reduce_train_result = isinstance(step_out, Result) and step_out.should_reduce_on_epoch_end
|
||||||
|
if isinstance(step_out, dict) and 'early_stop_on' in step_out:
|
||||||
|
early_stopping_accumulator.accumulate(step_out['early_stop_on'])
|
||||||
|
|
||||||
|
if isinstance(step_out, dict) and 'checkpoint_on' in step_out:
|
||||||
|
checkpoint_accumulator.accumulate(step_out['checkpoint_on'])
|
||||||
|
|
||||||
|
if self.is_overridden('training_epoch_end', model=self.get_model()) or should_auto_reduce_train_result:
|
||||||
epoch_output.append(batch_output.training_step_output_for_epoch_end)
|
epoch_output.append(batch_output.training_step_output_for_epoch_end)
|
||||||
|
|
||||||
# update LR schedulers
|
# update LR schedulers
|
||||||
|
@ -496,7 +521,7 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
self.sync_horovod()
|
self.sync_horovod()
|
||||||
|
|
||||||
# process epoch outputs
|
# process epoch outputs
|
||||||
self.run_training_epoch_end(epoch_output)
|
self.run_training_epoch_end(epoch_output, checkpoint_accumulator, early_stopping_accumulator)
|
||||||
|
|
||||||
# checkpoint callback
|
# checkpoint callback
|
||||||
self.check_checkpoint_callback(should_check_val)
|
self.check_checkpoint_callback(should_check_val)
|
||||||
|
@ -525,23 +550,74 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
if self.is_function_implemented('on_epoch_end'):
|
if self.is_function_implemented('on_epoch_end'):
|
||||||
model.on_epoch_end()
|
model.on_epoch_end()
|
||||||
|
|
||||||
def run_training_epoch_end(self, epoch_output):
|
with self.profiler.profile('on_train_epoch_end'):
|
||||||
|
# callbacks
|
||||||
|
self.on_train_epoch_end()
|
||||||
|
|
||||||
|
# model hooks
|
||||||
|
if self.is_function_implemented('on_train_epoch_end'):
|
||||||
|
model.on_train_epoch_end()
|
||||||
|
|
||||||
|
def run_training_epoch_end(self, epoch_output, checkpoint_accumulator, early_stopping_accumulator):
|
||||||
model = self.get_model()
|
model = self.get_model()
|
||||||
|
is_result_obj = len(epoch_output) > 0 and isinstance(epoch_output[0], Result)
|
||||||
|
|
||||||
|
epoch_log_metrics = {}
|
||||||
|
epoch_callback_metrics = {}
|
||||||
|
epoch_progress_bar_metrics = {}
|
||||||
|
|
||||||
|
# -----------------------
|
||||||
|
# Calculate epoch callback values if given
|
||||||
|
# -----------------------
|
||||||
|
if checkpoint_accumulator.num_values > 0:
|
||||||
|
epoch_callback_metrics['checkpoint_on'] = checkpoint_accumulator.mean()
|
||||||
|
|
||||||
|
if early_stopping_accumulator.num_values > 0:
|
||||||
|
epoch_callback_metrics['early_stop_on'] = early_stopping_accumulator.mean()
|
||||||
|
|
||||||
|
# --------------------------
|
||||||
|
# EPOCH END STEP IF DEFINED
|
||||||
|
# --------------------------
|
||||||
if self.is_overridden('training_epoch_end', model=model):
|
if self.is_overridden('training_epoch_end', model=model):
|
||||||
self.global_step += 1
|
self.global_step += 1
|
||||||
|
|
||||||
|
# remove the protected keys so the user doesn't have to deal with them
|
||||||
|
if is_result_obj:
|
||||||
|
epoch_output = epoch_output[0].__class__.gather(epoch_output)
|
||||||
|
|
||||||
|
# run training_epoch_end
|
||||||
epoch_output = model.training_epoch_end(epoch_output)
|
epoch_output = model.training_epoch_end(epoch_output)
|
||||||
_processed_outputs = self.process_output(epoch_output)
|
|
||||||
log_epoch_metrics = _processed_outputs[2]
|
|
||||||
callback_epoch_metrics = _processed_outputs[3]
|
|
||||||
|
|
||||||
# add the metrics to the loggers
|
if isinstance(epoch_output, Result):
|
||||||
self.log_metrics(log_epoch_metrics, {})
|
epoch_log_metrics = epoch_output.epoch_log_metrics
|
||||||
|
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
|
||||||
|
else:
|
||||||
|
_processed_outputs = self.process_output(epoch_output)
|
||||||
|
epoch_progress_bar_metrics = _processed_outputs[1]
|
||||||
|
epoch_log_metrics = _processed_outputs[2]
|
||||||
|
epoch_callback_metrics = _processed_outputs[3]
|
||||||
|
|
||||||
# add metrics to callbacks
|
# --------------------------
|
||||||
self.callback_metrics.update(callback_epoch_metrics)
|
# Structured Result (auto epoch end)
|
||||||
|
# --------------------------
|
||||||
|
elif is_result_obj:
|
||||||
|
epoch_output = epoch_output[0].__class__.reduce_on_epoch_end(epoch_output)
|
||||||
|
epoch_output.minimize = epoch_output.minimize.mean()
|
||||||
|
epoch_log_metrics = epoch_output.epoch_log_metrics
|
||||||
|
epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
|
||||||
|
|
||||||
# add metrics to progress_bar
|
# --------------------------
|
||||||
self.add_progress_bar_metrics(_processed_outputs[1])
|
# track results
|
||||||
|
# --------------------------
|
||||||
|
# add the metrics to the loggers
|
||||||
|
if epoch_log_metrics and len(epoch_log_metrics) > 0:
|
||||||
|
self.log_metrics(epoch_log_metrics, {})
|
||||||
|
|
||||||
|
# add metrics to callbacks
|
||||||
|
self.callback_metrics.update(epoch_callback_metrics)
|
||||||
|
|
||||||
|
# add metrics to progress_bar
|
||||||
|
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
|
||||||
|
|
||||||
def sync_horovod(self):
|
def sync_horovod(self):
|
||||||
if self.use_horovod:
|
if self.use_horovod:
|
||||||
|
@ -558,7 +634,10 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
|
should_log_metrics = batch_idx % self.row_log_interval == 0 or self.should_stop
|
||||||
if should_log_metrics or self.fast_dev_run:
|
if should_log_metrics or self.fast_dev_run:
|
||||||
# logs user requested information to logger
|
# logs user requested information to logger
|
||||||
self.log_metrics(batch_output.batch_log_metrics, batch_output.grad_norm_dic)
|
metrics = batch_output.batch_log_metrics
|
||||||
|
grad_norm_dic = batch_output.grad_norm_dic
|
||||||
|
if len(metrics) > 0 or len(grad_norm_dic) > 0:
|
||||||
|
self.log_metrics(metrics, grad_norm_dic)
|
||||||
|
|
||||||
def save_loggers_in_training_loop(self, batch_idx):
|
def save_loggers_in_training_loop(self, batch_idx):
|
||||||
# when loggers should save to disk
|
# when loggers should save to disk
|
||||||
|
@ -588,6 +667,8 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
# track metrics to log
|
# track metrics to log
|
||||||
batch_log_metrics = []
|
batch_log_metrics = []
|
||||||
|
|
||||||
|
using_results_obj = False
|
||||||
|
|
||||||
if batch is None:
|
if batch is None:
|
||||||
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
|
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)
|
||||||
|
|
||||||
|
@ -622,7 +703,7 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
param.requires_grad = True
|
param.requires_grad = True
|
||||||
|
|
||||||
# -------------------
|
# -------------------
|
||||||
# calculate loss
|
# calculate loss (train step + train step end)
|
||||||
# -------------------
|
# -------------------
|
||||||
opt_closure_result = self.optimizer_closure(
|
opt_closure_result = self.optimizer_closure(
|
||||||
split_batch,
|
split_batch,
|
||||||
|
@ -631,14 +712,26 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
optimizer,
|
optimizer,
|
||||||
self.hiddens
|
self.hiddens
|
||||||
)
|
)
|
||||||
|
using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
|
||||||
|
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
# POST forward bookkeeping
|
# POST forward bookkeeping
|
||||||
# ------------------------------
|
# ------------------------------
|
||||||
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
|
batch_callback_metrics.append(opt_closure_result.training_step_output.callback_metrics)
|
||||||
batch_log_metrics.append(opt_closure_result.training_step_output.log_metrics)
|
|
||||||
|
|
||||||
self.add_progress_bar_metrics(opt_closure_result.training_step_output.pbar_on_batch_end)
|
# add metrics to loggers
|
||||||
|
if using_results_obj:
|
||||||
|
metrics_to_log = opt_closure_result.training_step_output.batch_log_metrics
|
||||||
|
else:
|
||||||
|
metrics_to_log = opt_closure_result.training_step_output.log_metrics
|
||||||
|
batch_log_metrics.append(metrics_to_log)
|
||||||
|
|
||||||
|
# add metrics to progress bar
|
||||||
|
if using_results_obj:
|
||||||
|
metrics_for_pbar = opt_closure_result.training_step_output.batch_pbar_metrics
|
||||||
|
else:
|
||||||
|
metrics_for_pbar = opt_closure_result.training_step_output.pbar_on_batch_end
|
||||||
|
self.add_progress_bar_metrics(metrics_for_pbar)
|
||||||
|
|
||||||
# track hiddens
|
# track hiddens
|
||||||
self.hiddens = opt_closure_result.hiddens
|
self.hiddens = opt_closure_result.hiddens
|
||||||
|
@ -677,7 +770,8 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
|
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}
|
||||||
|
|
||||||
# track all metrics for callbacks
|
# track all metrics for callbacks
|
||||||
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
|
if not using_results_obj:
|
||||||
|
self.callback_metrics.update({k: v for d in batch_callback_metrics for k, v in d.items()})
|
||||||
|
|
||||||
result = AttributeDict(
|
result = AttributeDict(
|
||||||
signal=0,
|
signal=0,
|
||||||
|
@ -764,7 +858,7 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
wrap the forward step in a closure so second order methods work
|
wrap the forward step in a closure so second order methods work
|
||||||
"""
|
"""
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
# FORWARD
|
# FORWARD (TRAINING STEP + TRAIN STEP END)
|
||||||
# ---------------------------
|
# ---------------------------
|
||||||
with self.profiler.profile('model_forward'):
|
with self.profiler.profile('model_forward'):
|
||||||
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
|
if self.use_amp and NATIVE_AMP_AVALAIBLE and not self.use_tpu:
|
||||||
|
@ -780,26 +874,38 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
# ----------------------------
|
# ----------------------------
|
||||||
# format and reduce outputs accordingly
|
# format and reduce outputs accordingly
|
||||||
training_step_output_for_epoch_end = training_step_output
|
training_step_output_for_epoch_end = training_step_output
|
||||||
training_step_output = self.process_output(training_step_output, train=True)
|
is_result_obj = isinstance(training_step_output, Result)
|
||||||
|
|
||||||
# TODO: temporary part of structured results PR
|
# don't allow EvalResult in the training_step
|
||||||
training_step_output = AttributeDict(
|
if isinstance(training_step_output, EvalResult):
|
||||||
batch_loss=training_step_output[0],
|
raise MisconfigurationException('training_step cannot return EvalResult, '
|
||||||
pbar_on_batch_end=training_step_output[1],
|
'use a dict or TrainResult instead')
|
||||||
log_metrics=training_step_output[2],
|
|
||||||
callback_metrics=training_step_output[3],
|
# handle regular dicts
|
||||||
hiddens=training_step_output[4],
|
if not is_result_obj:
|
||||||
)
|
training_step_output = self.process_output(training_step_output, train=True)
|
||||||
|
|
||||||
|
training_step_output = AttributeDict(
|
||||||
|
batch_loss=training_step_output[0],
|
||||||
|
pbar_on_batch_end=training_step_output[1],
|
||||||
|
log_metrics=training_step_output[2],
|
||||||
|
callback_metrics=training_step_output[3],
|
||||||
|
hiddens=training_step_output[4],
|
||||||
|
)
|
||||||
|
|
||||||
# if the user decides to finally reduce things in epoch_end, save raw output without graphs
|
# if the user decides to finally reduce things in epoch_end, save raw output without graphs
|
||||||
if isinstance(training_step_output_for_epoch_end, torch.Tensor):
|
if isinstance(training_step_output_for_epoch_end, torch.Tensor):
|
||||||
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
|
training_step_output_for_epoch_end = training_step_output_for_epoch_end.detach()
|
||||||
|
elif is_result_obj:
|
||||||
|
training_step_output_for_epoch_end = copy(training_step_output)
|
||||||
|
training_step_output_for_epoch_end.detach()
|
||||||
else:
|
else:
|
||||||
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
|
training_step_output_for_epoch_end = recursive_detach(training_step_output_for_epoch_end)
|
||||||
|
|
||||||
# accumulate loss
|
# accumulate loss
|
||||||
# (if accumulate_grad_batches = 1 no effect)
|
# (if accumulate_grad_batches = 1 no effect)
|
||||||
closure_loss = training_step_output.batch_loss / self.accumulate_grad_batches
|
closure_loss = training_step_output.minimize if is_result_obj else training_step_output.batch_loss
|
||||||
|
closure_loss = closure_loss / self.accumulate_grad_batches
|
||||||
|
|
||||||
# the loss will get scaled for amp. avoid any modifications to it
|
# the loss will get scaled for amp. avoid any modifications to it
|
||||||
untouched_loss = closure_loss.detach().clone()
|
untouched_loss = closure_loss.detach().clone()
|
||||||
|
@ -829,7 +935,11 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
|
|
||||||
# once backward has been applied, release graph
|
# once backward has been applied, release graph
|
||||||
closure_loss = closure_loss.detach()
|
closure_loss = closure_loss.detach()
|
||||||
training_step_output.batch_loss = training_step_output.batch_loss.detach()
|
|
||||||
|
if is_result_obj:
|
||||||
|
training_step_output.detach()
|
||||||
|
else:
|
||||||
|
training_step_output.batch_loss = training_step_output.batch_loss.detach()
|
||||||
|
|
||||||
if self.use_horovod:
|
if self.use_horovod:
|
||||||
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
|
# Synchronize Horovod to ensure gradient manipulations (e.g., loss scaling) are valid
|
||||||
|
@ -841,6 +951,9 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
with self.profiler.profile('on_after_backward'):
|
with self.profiler.profile('on_after_backward'):
|
||||||
model_ref.on_after_backward()
|
model_ref.on_after_backward()
|
||||||
|
|
||||||
|
# when in dev debugging track the losses
|
||||||
|
self.dev_debugger.track_train_loss_history(batch_idx, untouched_loss.detach())
|
||||||
|
|
||||||
result = AttributeDict(
|
result = AttributeDict(
|
||||||
loss=untouched_loss,
|
loss=untouched_loss,
|
||||||
training_step_output=training_step_output,
|
training_step_output=training_step_output,
|
||||||
|
@ -963,6 +1076,7 @@ class TrainerTrainLoopMixin(ABC):
|
||||||
if self.is_overridden('training_step_end'):
|
if self.is_overridden('training_step_end'):
|
||||||
model_ref = self.get_model()
|
model_ref = self.get_model()
|
||||||
with self.profiler.profile('training_step_end'):
|
with self.profiler.profile('training_step_end'):
|
||||||
|
# TODO: modify when using result obj
|
||||||
output = model_ref.training_step_end(output)
|
output = model_ref.training_step_end(output)
|
||||||
|
|
||||||
# allow any mode to define training_end
|
# allow any mode to define training_end
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class InternalDebugger(object):
|
||||||
|
|
||||||
|
def __init__(self, trainer):
|
||||||
|
|
||||||
|
self.enabled = 'PL_DEV_DEBUG' in os.environ
|
||||||
|
self.trainer = trainer
|
||||||
|
self.logged_metrics = []
|
||||||
|
self.pbar_added_metrics = []
|
||||||
|
self.saved_losses = []
|
||||||
|
self.early_stopping_history = []
|
||||||
|
self.checkpoint_callback_history = []
|
||||||
|
|
||||||
|
def track_logged_metrics_history(self, scalar_metrics):
|
||||||
|
if self.enabled:
|
||||||
|
scalar_metrics['global_step'] = self.trainer.global_step
|
||||||
|
self.logged_metrics.append(scalar_metrics)
|
||||||
|
|
||||||
|
def track_train_loss_history(self, batch_idx, loss):
|
||||||
|
if self.enabled:
|
||||||
|
loss_dict = {'batch_idx': batch_idx, 'epoch': self.trainer.current_epoch, 'loss': loss.detach()}
|
||||||
|
self.saved_losses.append(loss_dict)
|
||||||
|
|
||||||
|
def track_pbar_metrics_history(self, metrics):
|
||||||
|
if self.enabled:
|
||||||
|
metrics['debug_epoch'] = self.trainer.current_epoch
|
||||||
|
self.pbar_added_metrics.append(metrics)
|
||||||
|
|
||||||
|
def track_early_stopping_history(self, current):
|
||||||
|
if self.enabled:
|
||||||
|
es = self.trainer.early_stop_callback
|
||||||
|
debug_dict = {
|
||||||
|
'epoch': self.trainer.current_epoch,
|
||||||
|
'global_step': self.trainer.global_step,
|
||||||
|
'rank': self.trainer.global_rank,
|
||||||
|
'current': current,
|
||||||
|
'best': es.best_score,
|
||||||
|
'patience': es.wait_count
|
||||||
|
}
|
||||||
|
self.early_stopping_history.append(debug_dict)
|
||||||
|
|
||||||
|
def track_checkpointing_history(self, filepath):
|
||||||
|
if self.enabled:
|
||||||
|
cb = self.trainer.checkpoint_callback
|
||||||
|
debug_dict = {
|
||||||
|
'epoch': self.trainer.current_epoch,
|
||||||
|
'global_step': self.trainer.global_step,
|
||||||
|
'monitor': cb.monitor,
|
||||||
|
'rank': self.trainer.global_rank,
|
||||||
|
'filepath': filepath
|
||||||
|
}
|
||||||
|
self.checkpoint_callback_history.append(debug_dict)
|
|
@ -1,5 +1,6 @@
|
||||||
import inspect
|
import inspect
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
def str_to_bool(val):
|
def str_to_bool(val):
|
||||||
|
@ -93,7 +94,7 @@ def collect_init_args(frame, path_args: list, inside: bool = False) -> list:
|
||||||
return path_args
|
return path_args
|
||||||
|
|
||||||
|
|
||||||
class AttributeDict(dict):
|
class AttributeDict(Dict):
|
||||||
"""Extended dictionary accesisable with dot notation.
|
"""Extended dictionary accesisable with dot notation.
|
||||||
|
|
||||||
>>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})
|
>>> ad = AttributeDict({'key1': 1, 'key2': 'abc'})
|
||||||
|
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils.data import Dataset, DataLoader
|
from torch.utils.data import Dataset, DataLoader
|
||||||
|
from pytorch_lightning import TrainResult
|
||||||
|
|
||||||
from pytorch_lightning.core.lightning import LightningModule
|
from pytorch_lightning.core.lightning import LightningModule
|
||||||
|
|
||||||
|
@ -19,6 +20,8 @@ class DeterministicModel(LightningModule):
|
||||||
self.validation_step_end_called = False
|
self.validation_step_end_called = False
|
||||||
self.validation_epoch_end_called = False
|
self.validation_epoch_end_called = False
|
||||||
|
|
||||||
|
self.assert_backward = True
|
||||||
|
|
||||||
self.l1 = nn.Linear(2, 3, bias=False)
|
self.l1 = nn.Linear(2, 3, bias=False)
|
||||||
if weights is None:
|
if weights is None:
|
||||||
weights = torch.tensor([
|
weights = torch.tensor([
|
||||||
|
@ -33,13 +36,15 @@ class DeterministicModel(LightningModule):
|
||||||
|
|
||||||
def step(self, batch, batch_idx):
|
def step(self, batch, batch_idx):
|
||||||
x = batch
|
x = batch
|
||||||
y_hat = self(x)
|
bs = x.size(0)
|
||||||
|
y_hat = self.l1(x)
|
||||||
|
print(x.device, self.device, self.l1.weight.device)
|
||||||
|
|
||||||
test_hat = y_hat.cpu().detach()
|
test_hat = y_hat.cpu().detach()
|
||||||
assert torch.all(test_hat[:, 0] == 15.0)
|
assert torch.all(test_hat[:, 0] == 15.0)
|
||||||
assert torch.all(test_hat[:, 1] == 42.0)
|
assert torch.all(test_hat[:, 1] == 42.0)
|
||||||
out = y_hat.sum()
|
out = y_hat.sum()
|
||||||
assert out == (42.0 * 3) + (15.0 * 3)
|
assert out == (42.0 * bs) + (15.0 * bs)
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -97,6 +102,105 @@ class DeterministicModel(LightningModule):
|
||||||
prototype_loss = outputs[0]
|
prototype_loss = outputs[0]
|
||||||
return prototype_loss
|
return prototype_loss
|
||||||
|
|
||||||
|
def training_step_no_default_callbacks_for_train_loop(self, batch, batch_idx):
|
||||||
|
"""
|
||||||
|
Early stop and checkpoint only on these values
|
||||||
|
"""
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
result = TrainResult(minimize=acc)
|
||||||
|
assert 'early_step_on' not in result
|
||||||
|
assert 'checkpoint_on' in result
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_no_callbacks_result_obj(self, batch, batch_idx):
|
||||||
|
"""
|
||||||
|
Early stop and checkpoint only on these values
|
||||||
|
"""
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
result = TrainResult(minimize=acc, checkpoint_on=False)
|
||||||
|
assert 'early_step_on' not in result
|
||||||
|
assert 'checkpoint_on' not in result
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_result_log_epoch_and_step_for_callbacks(self, batch, batch_idx):
|
||||||
|
"""
|
||||||
|
Early stop and checkpoint only on these values
|
||||||
|
"""
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
|
||||||
|
self.assert_backward = False
|
||||||
|
losses = [20, 19, 18, 10, 15, 14, 9, 11, 11, 20]
|
||||||
|
idx = self.current_epoch
|
||||||
|
loss = acc + losses[idx]
|
||||||
|
result = TrainResult(minimize=loss, early_stop_on=loss, checkpoint_on=loss)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_result_log_step_only(self, batch, batch_idx):
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
result = TrainResult(minimize=acc)
|
||||||
|
|
||||||
|
# step only metrics
|
||||||
|
result.log(f'step_log_and_pbar_acc1_b{batch_idx}', torch.tensor(11).type_as(acc), prog_bar=True)
|
||||||
|
result.log(f'step_log_acc2_b{batch_idx}', torch.tensor(12).type_as(acc))
|
||||||
|
result.log(f'step_pbar_acc3_b{batch_idx}', torch.tensor(13).type_as(acc), logger=False, prog_bar=True)
|
||||||
|
|
||||||
|
self.training_step_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_result_log_epoch_only(self, batch, batch_idx):
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
result = TrainResult(minimize=acc)
|
||||||
|
|
||||||
|
result.log(f'epoch_log_and_pbar_acc1_e{self.current_epoch}', torch.tensor(14).type_as(acc),
|
||||||
|
on_epoch=True, prog_bar=True, on_step=False)
|
||||||
|
result.log(f'epoch_log_acc2_e{self.current_epoch}', torch.tensor(15).type_as(acc),
|
||||||
|
on_epoch=True, on_step=False)
|
||||||
|
result.log(f'epoch_pbar_acc3_e{self.current_epoch}', torch.tensor(16).type_as(acc),
|
||||||
|
on_epoch=True, logger=False, prog_bar=True, on_step=False)
|
||||||
|
|
||||||
|
self.training_step_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_result_log_epoch_and_step(self, batch, batch_idx):
|
||||||
|
acc = self.step(batch, batch_idx)
|
||||||
|
result = TrainResult(minimize=acc)
|
||||||
|
|
||||||
|
val_1 = (5 + batch_idx) * (self.current_epoch + 1)
|
||||||
|
val_2 = (6 + batch_idx) * (self.current_epoch + 1)
|
||||||
|
val_3 = (7 + batch_idx) * (self.current_epoch + 1)
|
||||||
|
result.log(f'step_epoch_log_and_pbar_acc1', torch.tensor(val_1).type_as(acc),
|
||||||
|
on_epoch=True, prog_bar=True)
|
||||||
|
result.log(f'step_epoch_log_acc2', torch.tensor(val_2).type_as(acc),
|
||||||
|
on_epoch=True)
|
||||||
|
result.log(f'step_epoch_pbar_acc3', torch.tensor(val_3).type_as(acc),
|
||||||
|
on_epoch=True, logger=False, prog_bar=True)
|
||||||
|
|
||||||
|
self.training_step_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_epoch_end_return_for_log_epoch_and_step(self, result):
|
||||||
|
"""
|
||||||
|
There should be an array of scalars without graphs that are all 171 (4 of them)
|
||||||
|
"""
|
||||||
|
self.training_epoch_end_called = True
|
||||||
|
|
||||||
|
if self.use_dp or self.use_ddp2:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# only saw 4 batches
|
||||||
|
assert isinstance(result, TrainResult)
|
||||||
|
|
||||||
|
result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
|
||||||
|
result.step_epoch_log_acc2 = result.step_epoch_log_acc2.prod()
|
||||||
|
result.step_epoch_pbar_acc3 = result.step_epoch_pbar_acc3.prod()
|
||||||
|
result.log('epoch_end_log_acc', torch.tensor(1212).type_as(result.step_epoch_log_acc2),
|
||||||
|
logger=True, on_epoch=True)
|
||||||
|
result.log('epoch_end_pbar_acc', torch.tensor(1213).type_as(result.step_epoch_log_acc2),
|
||||||
|
logger=False, prog_bar=True, on_epoch=True)
|
||||||
|
result.log('epoch_end_log_pbar_acc', torch.tensor(1214).type_as(result.step_epoch_log_acc2),
|
||||||
|
logger=True, prog_bar=True, on_epoch=True)
|
||||||
|
return result
|
||||||
|
|
||||||
# --------------------------
|
# --------------------------
|
||||||
# dictionary returns
|
# dictionary returns
|
||||||
# --------------------------
|
# --------------------------
|
||||||
|
@ -231,10 +335,12 @@ class DeterministicModel(LightningModule):
|
||||||
return torch.optim.Adam(self.parameters(), lr=0)
|
return torch.optim.Adam(self.parameters(), lr=0)
|
||||||
|
|
||||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||||
if self.trainer.precision == 16:
|
if self.assert_backward:
|
||||||
assert loss > 171 * 1000
|
if self.trainer.precision == 16:
|
||||||
else:
|
assert loss > 171 * 1000
|
||||||
assert loss == 171.0
|
else:
|
||||||
|
assert loss == 171.0
|
||||||
|
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,9 @@ class EvalModelTemplate(
|
||||||
self.hidden_dim = hidden_dim
|
self.hidden_dim = hidden_dim
|
||||||
self.b1 = b1
|
self.b1 = b1
|
||||||
self.b2 = b2
|
self.b2 = b2
|
||||||
|
self.training_step_called = False
|
||||||
|
self.training_step_end_called = False
|
||||||
|
self.training_epoch_end_called = False
|
||||||
|
|
||||||
# if you specify an example input, the summary will show input/output for each layer
|
# if you specify an example input, the summary will show input/output for each layer
|
||||||
# TODO: to be fixed in #1773
|
# TODO: to be fixed in #1773
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import math
|
import math
|
||||||
from abc import ABC
|
from abc import ABC
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from pytorch_lightning import TrainResult
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
@ -38,3 +39,35 @@ class TrainingStepVariations(ABC):
|
||||||
else:
|
else:
|
||||||
output /= 0
|
output /= 0
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def training_step_full_loop_result_obj_dp(self, batch, batch_idx, optimizer_idx=None):
|
||||||
|
"""
|
||||||
|
Full loop flow train step (result obj + dp)
|
||||||
|
"""
|
||||||
|
x, y = batch
|
||||||
|
x = x.view(x.size(0), -1)
|
||||||
|
y_hat = self(x.to(self.device))
|
||||||
|
loss_val = y_hat.sum()
|
||||||
|
result = TrainResult(minimize=loss_val)
|
||||||
|
result.log('train_step_metric', loss_val + 1)
|
||||||
|
self.training_step_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_step_end_full_loop_result_obj_dp(self, result):
|
||||||
|
"""
|
||||||
|
Full loop flow train step (result obj + dp)
|
||||||
|
"""
|
||||||
|
result.minimize = result.minimize.mean()
|
||||||
|
result.checkpoint_on = result.checkpoint_on.mean()
|
||||||
|
result.train_step_metric = result.train_step_metric.mean()
|
||||||
|
result.log('train_step_end_metric', 1)
|
||||||
|
self.training_step_end_called = True
|
||||||
|
return result
|
||||||
|
|
||||||
|
def training_epoch_end_full_loop_result_obj_dp(self, result):
|
||||||
|
"""
|
||||||
|
Full loop flow train step (result obj + dp)
|
||||||
|
"""
|
||||||
|
result.log('train_epoch_end_metric', 1, on_epoch=True)
|
||||||
|
self.training_epoch_end_called = True
|
||||||
|
return result
|
||||||
|
|
|
@ -35,7 +35,6 @@ class ValidationEpochEndVariations(ABC):
|
||||||
Args:
|
Args:
|
||||||
outputs: list of individual outputs of each validation step
|
outputs: list of individual outputs of each validation step
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# if returned a scalar from validation_step, outputs is a list of tensor scalars
|
# if returned a scalar from validation_step, outputs is a list of tensor scalars
|
||||||
# we return just the average in this case (if we want)
|
# we return just the average in this case (if we want)
|
||||||
def _mean(res, key):
|
def _mean(res, key):
|
||||||
|
|
|
@ -78,11 +78,11 @@ class ModelCheckpointTestInvocations(ModelCheckpoint):
|
||||||
self.count = 0
|
self.count = 0
|
||||||
self.expected_count = expected_count
|
self.expected_count = expected_count
|
||||||
|
|
||||||
def _save_model(self, filepath):
|
def _save_model(self, filepath, trainer, pl_module):
|
||||||
# make sure we don't save twice
|
# make sure we don't save twice
|
||||||
assert not os.path.isfile(filepath)
|
assert not os.path.isfile(filepath)
|
||||||
self.count += 1
|
self.count += 1
|
||||||
super()._save_model(filepath)
|
super()._save_model(filepath, trainer, pl_module)
|
||||||
|
|
||||||
def on_train_end(self, trainer, pl_module):
|
def on_train_end(self, trainer, pl_module):
|
||||||
super().on_train_end(trainer, pl_module)
|
super().on_train_end(trainer, pl_module)
|
||||||
|
|
|
@ -1,43 +1,12 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import os
|
||||||
|
|
||||||
from pytorch_lightning import Trainer
|
from pytorch_lightning import Trainer
|
||||||
from pytorch_lightning.loggers import LightningLoggerBase
|
|
||||||
from pytorch_lightning.utilities import rank_zero_only
|
|
||||||
from tests.base import EvalModelTemplate
|
from tests.base import EvalModelTemplate
|
||||||
from tests.base.develop_utils import reset_seed
|
from tests.base.develop_utils import reset_seed
|
||||||
|
|
||||||
|
|
||||||
class OnlyMetricsListLogger(LightningLoggerBase):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.metrics = []
|
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def log_metrics(self, metrics, step):
|
|
||||||
self.metrics.append(metrics)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def experiment(self):
|
|
||||||
return 'test'
|
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def log_hyperparams(self, params):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@rank_zero_only
|
|
||||||
def finalize(self, status):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return 'name'
|
|
||||||
|
|
||||||
@property
|
|
||||||
def version(self):
|
|
||||||
return '1'
|
|
||||||
|
|
||||||
|
|
||||||
class ModelWithManualGradTracker(EvalModelTemplate):
|
class ModelWithManualGradTracker(EvalModelTemplate):
|
||||||
def __init__(self, norm_type, *args, **kwargs):
|
def __init__(self, norm_type, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
@ -75,28 +44,29 @@ class ModelWithManualGradTracker(EvalModelTemplate):
|
||||||
|
|
||||||
@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
|
@pytest.mark.parametrize("norm_type", [1., 1.25, 1.5, 2, 3, 5, 10, 'inf'])
|
||||||
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
|
def test_grad_tracking(tmpdir, norm_type, rtol=5e-3):
|
||||||
# rtol=5e-3 respects the 3 decmials rounding in `.grad_norms` and above
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
# rtol=5e-3 respects the 3 decimals rounding in `.grad_norms` and above
|
||||||
|
|
||||||
reset_seed()
|
reset_seed()
|
||||||
|
|
||||||
# use a custom grad tracking module and a list logger
|
# use a custom grad tracking module and a list logger
|
||||||
model = ModelWithManualGradTracker(norm_type)
|
model = ModelWithManualGradTracker(norm_type)
|
||||||
logger = OnlyMetricsListLogger()
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
default_root_dir=tmpdir,
|
default_root_dir=tmpdir,
|
||||||
max_epochs=3,
|
max_epochs=3,
|
||||||
logger=logger,
|
|
||||||
track_grad_norm=norm_type,
|
track_grad_norm=norm_type,
|
||||||
row_log_interval=1, # request grad_norms every batch
|
row_log_interval=1, # request grad_norms every batch
|
||||||
)
|
)
|
||||||
result = trainer.fit(model)
|
result = trainer.fit(model)
|
||||||
|
|
||||||
assert result == 1, "Training failed"
|
assert result == 1, "Training failed"
|
||||||
assert len(logger.metrics) == len(model.stored_grad_norms)
|
logged_metrics = trainer.dev_debugger.logged_metrics
|
||||||
|
assert len(logged_metrics) == len(model.stored_grad_norms)
|
||||||
|
|
||||||
# compare the logged metrics against tracked norms on `.backward`
|
# compare the logged metrics against tracked norms on `.backward`
|
||||||
for mod, log in zip(model.stored_grad_norms, logger.metrics):
|
for mod, log in zip(model.stored_grad_norms, logged_metrics):
|
||||||
common = mod.keys() & log.keys()
|
common = mod.keys() & log.keys()
|
||||||
|
|
||||||
log, mod = [log[k] for k in common], [mod[k] for k in common]
|
log, mod = [log[k] for k in common], [mod[k] for k in common]
|
||||||
|
|
|
@ -589,7 +589,7 @@ def test_test_checkpoint_path(tmpdir, ckpt_path, save_top_k):
|
||||||
with pytest.raises(FileNotFoundError):
|
with pytest.raises(FileNotFoundError):
|
||||||
trainer.test(ckpt_path='random.ckpt')
|
trainer.test(ckpt_path='random.ckpt')
|
||||||
else:
|
else:
|
||||||
ckpt_path = str(list((Path(tmpdir) / 'lightning_logs/version_0/checkpoints').iterdir())[0].absolute())
|
ckpt_path = str(list((Path(tmpdir) / f'lightning_logs/version_{trainer.logger.version}/checkpoints').iterdir())[0].absolute())
|
||||||
trainer.test(ckpt_path=ckpt_path)
|
trainer.test(ckpt_path=ckpt_path)
|
||||||
assert trainer.tested_ckpt_path == ckpt_path
|
assert trainer.tested_ckpt_path == ckpt_path
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,518 @@
|
||||||
|
"""
|
||||||
|
Tests to ensure that the training loop works with a dict
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import torch
|
||||||
|
from pytorch_lightning import Trainer
|
||||||
|
from tests.base.deterministic_model import DeterministicModel
|
||||||
|
from pytorch_lightning.core.step_result import Result, TrainResult, EvalResult
|
||||||
|
from tests.base import EvalModelTemplate
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
# test with train_step_end
|
||||||
|
# add logging + row interval tests
|
||||||
|
|
||||||
|
def test_training_step_result_log_step_only(tmpdir):
|
||||||
|
"""
|
||||||
|
Tests that only training_step can be used with TrainResult
|
||||||
|
Makes sure that things are routed to pbar, loggers and loss accordingly
|
||||||
|
|
||||||
|
Makes sure pbar and logs happen on step only when requested
|
||||||
|
"""
|
||||||
|
# enable internal debugging actions
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_result_log_step_only
|
||||||
|
model.training_step_end = None
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
batches = 3
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
limit_val_batches=batches,
|
||||||
|
row_log_interval=1,
|
||||||
|
max_epochs=1,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# make sure correct steps were called
|
||||||
|
assert model.training_step_called
|
||||||
|
assert not model.training_step_end_called
|
||||||
|
assert not model.training_epoch_end_called
|
||||||
|
|
||||||
|
# make sure correct metrics are logged (one per batch step as requested)
|
||||||
|
assert len(trainer.dev_debugger.logged_metrics) == batches
|
||||||
|
for batch_idx, logged_metrics in enumerate(trainer.dev_debugger.logged_metrics):
|
||||||
|
assert logged_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
|
||||||
|
assert logged_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
|
||||||
|
assert f'step_pbar_acc3_b{batch_idx}' not in logged_metrics
|
||||||
|
assert len(logged_metrics) == 4
|
||||||
|
|
||||||
|
# make sure we are using the correct metrics for callbacks
|
||||||
|
assert trainer.callback_metrics['checkpoint_on'] == 171
|
||||||
|
|
||||||
|
# make sure pbar metrics are correct ang log metrics did not leak
|
||||||
|
for batch_idx in range(batches):
|
||||||
|
assert trainer.progress_bar_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11
|
||||||
|
assert trainer.progress_bar_metrics[f'step_pbar_acc3_b{batch_idx}'] == 13
|
||||||
|
assert f'step_log_acc2_b{batch_idx}' not in trainer.progress_bar_metrics
|
||||||
|
|
||||||
|
# make sure training outputs what is expected
|
||||||
|
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||||
|
break
|
||||||
|
|
||||||
|
out = trainer.run_training_batch(batch, batch_idx)
|
||||||
|
assert out.signal == 0
|
||||||
|
assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
|
||||||
|
assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
|
||||||
|
|
||||||
|
train_step_out = out.training_step_output_for_epoch_end
|
||||||
|
assert isinstance(train_step_out, TrainResult)
|
||||||
|
|
||||||
|
assert 'minimize' in train_step_out
|
||||||
|
assert f'step_log_and_pbar_acc1_b{batch_idx}' in train_step_out
|
||||||
|
assert f'step_log_acc2_b{batch_idx}' in train_step_out
|
||||||
|
|
||||||
|
# make sure the optimizer closure returns the correct things
|
||||||
|
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
|
||||||
|
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_step_result_log_epoch_only(tmpdir):
|
||||||
|
"""
|
||||||
|
Tests that only training_step can be used with TrainResult
|
||||||
|
Makes sure that things are routed to pbar, loggers and loss accordingly
|
||||||
|
|
||||||
|
Makes sure pbar and logs happen on epoch only when requested
|
||||||
|
"""
|
||||||
|
# enable internal debugging actions
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_result_log_epoch_only
|
||||||
|
model.training_step_end = None
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
epochs = 3
|
||||||
|
batches = 2
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
limit_val_batches=batches,
|
||||||
|
row_log_interval=1,
|
||||||
|
max_epochs=epochs,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# make sure correct steps were called
|
||||||
|
assert model.training_step_called
|
||||||
|
assert not model.training_step_end_called
|
||||||
|
assert not model.training_epoch_end_called
|
||||||
|
|
||||||
|
# make sure correct metrics are logged (one per batch step as requested)
|
||||||
|
assert len(trainer.dev_debugger.logged_metrics) == epochs
|
||||||
|
epoch_metrics = trainer.dev_debugger.logged_metrics
|
||||||
|
assert len(epoch_metrics) == epochs
|
||||||
|
for batch_idx, logged_metrics in enumerate(epoch_metrics):
|
||||||
|
assert logged_metrics[f'epoch_log_and_pbar_acc1_e{batch_idx}'] == 14.0
|
||||||
|
assert logged_metrics[f'epoch_log_acc2_e{batch_idx}'] == 15.0
|
||||||
|
assert f'epoch_pbar_acc3_e{batch_idx}' not in logged_metrics
|
||||||
|
assert len(logged_metrics) == 4
|
||||||
|
|
||||||
|
# make sure we are using the correct metrics for callbacks
|
||||||
|
assert trainer.callback_metrics['checkpoint_on'] == 171
|
||||||
|
|
||||||
|
# make sure pbar metrics are correct ang log metrics did not leak
|
||||||
|
for epoch_idx in range(epochs):
|
||||||
|
assert trainer.progress_bar_metrics[f'epoch_log_and_pbar_acc1_e{epoch_idx}'] == 14
|
||||||
|
assert trainer.progress_bar_metrics[f'epoch_pbar_acc3_e{epoch_idx}'] == 16
|
||||||
|
assert f'epoch_log_acc2_e{epoch_idx}' not in trainer.progress_bar_metrics
|
||||||
|
|
||||||
|
# make sure training outputs what is expected
|
||||||
|
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||||
|
break
|
||||||
|
|
||||||
|
out = trainer.run_training_batch(batch, batch_idx)
|
||||||
|
assert out.signal == 0
|
||||||
|
assert len(out.batch_log_metrics) == 0
|
||||||
|
|
||||||
|
train_step_out = out.training_step_output_for_epoch_end
|
||||||
|
assert isinstance(train_step_out, TrainResult)
|
||||||
|
|
||||||
|
assert 'minimize' in train_step_out
|
||||||
|
assert f'epoch_log_and_pbar_acc1_e{trainer.current_epoch}' in train_step_out
|
||||||
|
assert f'epoch_log_acc2_e{trainer.current_epoch}' in train_step_out
|
||||||
|
|
||||||
|
# make sure the optimizer closure returns the correct things
|
||||||
|
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
|
||||||
|
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_step_result_log_step_and_epoch(tmpdir):
|
||||||
|
"""
|
||||||
|
Tests that only training_step can be used with TrainResult
|
||||||
|
Makes sure that things are routed to pbar, loggers and loss accordingly
|
||||||
|
|
||||||
|
Makes sure pbar and logs happen on epoch only when requested
|
||||||
|
"""
|
||||||
|
# enable internal debugging actions
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_result_log_epoch_and_step
|
||||||
|
model.training_step_end = None
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
epochs = 3
|
||||||
|
batches = 2
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
limit_val_batches=batches,
|
||||||
|
row_log_interval=1,
|
||||||
|
max_epochs=epochs,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# make sure correct steps were called
|
||||||
|
assert model.training_step_called
|
||||||
|
assert not model.training_step_end_called
|
||||||
|
assert not model.training_epoch_end_called
|
||||||
|
|
||||||
|
# make sure correct metrics are logged (one per batch step as requested)
|
||||||
|
assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
|
||||||
|
epoch_metrics = trainer.dev_debugger.logged_metrics
|
||||||
|
epoch_idx = -1
|
||||||
|
for i_start in range(0, len(epoch_metrics), batches + 1):
|
||||||
|
epoch_idx += 1
|
||||||
|
epoch_outputs = epoch_metrics[i_start: i_start + batches + 1]
|
||||||
|
mean_vals = {
|
||||||
|
'step_epoch_log_and_pbar_acc1': [],
|
||||||
|
'step_epoch_log_acc2': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# make sure each batch logged the expected value
|
||||||
|
for batch_idx in range(len(epoch_outputs) - 1):
|
||||||
|
logged_metrics = epoch_outputs[batch_idx]
|
||||||
|
|
||||||
|
expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
|
||||||
|
expected_val_2 = (6 + batch_idx) * (epoch_idx + 1)
|
||||||
|
mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
|
||||||
|
mean_vals['step_epoch_log_acc2'].append(torch.tensor(expected_val_2).float())
|
||||||
|
assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
|
||||||
|
assert logged_metrics['step_epoch_log_acc2'] == expected_val_2
|
||||||
|
assert 'step_epoch_pbar_acc3' not in logged_metrics
|
||||||
|
assert len(logged_metrics) == 4
|
||||||
|
|
||||||
|
# make sure the metrics for the epoch end are actual means (the default reduce fx) or all the batches
|
||||||
|
epoch_end_metrics = epoch_outputs[-1]
|
||||||
|
eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
|
||||||
|
eval_2 = torch.stack(mean_vals['step_epoch_log_acc2']).mean()
|
||||||
|
assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
|
||||||
|
assert epoch_end_metrics['step_epoch_log_acc2'] == eval_2
|
||||||
|
assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
|
||||||
|
assert len(logged_metrics) == 4
|
||||||
|
|
||||||
|
# make sure we are using the correct metrics for callbacks
|
||||||
|
assert trainer.callback_metrics['checkpoint_on'] == 171
|
||||||
|
|
||||||
|
# -------------------------------
|
||||||
|
# VERIFY PBAR METRICS
|
||||||
|
# -------------------------------
|
||||||
|
# make sure pbar metrics are correct ang log metrics did not leak
|
||||||
|
all_pbar_metrics = trainer.dev_debugger.pbar_added_metrics
|
||||||
|
assert len(all_pbar_metrics) == (epochs * batches) + epochs
|
||||||
|
|
||||||
|
epoch_idx = -1
|
||||||
|
for i_start in range(0, len(all_pbar_metrics), batches + 1):
|
||||||
|
epoch_idx += 1
|
||||||
|
epoch_outputs = all_pbar_metrics[i_start: i_start + batches + 1]
|
||||||
|
mean_vals = {
|
||||||
|
'step_epoch_log_and_pbar_acc1': [],
|
||||||
|
'step_epoch_pbar_acc3': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# make sure each batch logged the expected value
|
||||||
|
for batch_idx in range(len(epoch_outputs) - 1):
|
||||||
|
logged_metrics = epoch_outputs[batch_idx]
|
||||||
|
|
||||||
|
expected_val_1 = (5 + batch_idx) * (epoch_idx + 1)
|
||||||
|
expected_val_2 = (7 + batch_idx) * (epoch_idx + 1)
|
||||||
|
mean_vals['step_epoch_log_and_pbar_acc1'].append(torch.tensor(expected_val_1).float())
|
||||||
|
mean_vals['step_epoch_pbar_acc3'].append(torch.tensor(expected_val_2).float())
|
||||||
|
assert logged_metrics['step_epoch_log_and_pbar_acc1'] == expected_val_1
|
||||||
|
assert logged_metrics['step_epoch_pbar_acc3'] == expected_val_2
|
||||||
|
assert 'step_epoch_log_acc2' not in logged_metrics
|
||||||
|
assert len(logged_metrics) == 3
|
||||||
|
|
||||||
|
# make sure the metrics for the epoch end are actual means (the default reduce fx) or all the batches
|
||||||
|
epoch_end_metrics = epoch_outputs[-1]
|
||||||
|
eval_1 = torch.stack(mean_vals['step_epoch_log_and_pbar_acc1']).mean()
|
||||||
|
eval_2 = torch.stack(mean_vals['step_epoch_pbar_acc3']).mean()
|
||||||
|
assert epoch_end_metrics['step_epoch_log_and_pbar_acc1'] == eval_1
|
||||||
|
assert epoch_end_metrics['step_epoch_pbar_acc3'] == eval_2
|
||||||
|
assert 'step_epoch_log_acc2' not in epoch_end_metrics
|
||||||
|
assert len(logged_metrics) == 3
|
||||||
|
|
||||||
|
# -----------------------------------------
|
||||||
|
# make sure training outputs what is expected
|
||||||
|
# -----------------------------------------
|
||||||
|
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||||
|
break
|
||||||
|
|
||||||
|
out = trainer.run_training_batch(batch, batch_idx)
|
||||||
|
assert out.signal == 0
|
||||||
|
assert len(out.batch_log_metrics) == 2
|
||||||
|
|
||||||
|
train_step_out = out.training_step_output_for_epoch_end
|
||||||
|
assert isinstance(train_step_out, TrainResult)
|
||||||
|
|
||||||
|
assert 'minimize' in train_step_out
|
||||||
|
assert 'step_epoch_log_and_pbar_acc1' in train_step_out
|
||||||
|
assert 'step_epoch_log_acc2' in train_step_out
|
||||||
|
|
||||||
|
# make sure the optimizer closure returns the correct things
|
||||||
|
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
|
||||||
|
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_training_step_epoch_end_result(tmpdir):
|
||||||
|
"""
|
||||||
|
Makes sure training_step and epoch_end can be used with Results (without batch_end)
|
||||||
|
"""
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_result_log_epoch_and_step
|
||||||
|
model.training_epoch_end = model.training_epoch_end_return_for_log_epoch_and_step
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
batches = 3
|
||||||
|
epochs = 1
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=epochs,
|
||||||
|
row_log_interval=1,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
# make sure correct steps were called
|
||||||
|
assert model.training_step_called
|
||||||
|
assert not model.training_step_end_called
|
||||||
|
assert model.training_epoch_end_called
|
||||||
|
|
||||||
|
# make sure correct metrics were logged
|
||||||
|
logged_metrics = trainer.dev_debugger.logged_metrics
|
||||||
|
assert len(logged_metrics) == (epochs * batches) + epochs
|
||||||
|
last_logged = logged_metrics[-1]
|
||||||
|
|
||||||
|
assert last_logged['step_epoch_log_and_pbar_acc1'] == 210.0
|
||||||
|
assert last_logged['step_epoch_log_acc2'] == 336.0
|
||||||
|
assert last_logged['epoch_end_log_acc'] == 1212.0
|
||||||
|
assert last_logged['epoch_end_log_pbar_acc'] == 1214.0
|
||||||
|
assert 'epoch_end_pbar_acc' not in last_logged
|
||||||
|
|
||||||
|
# make sure pbar metrics are correct
|
||||||
|
logged_pbar = trainer.dev_debugger.pbar_added_metrics
|
||||||
|
assert len(logged_pbar) == (epochs * batches) + epochs
|
||||||
|
|
||||||
|
assert trainer.progress_bar_metrics['step_epoch_log_and_pbar_acc1'] == 210.0
|
||||||
|
assert trainer.progress_bar_metrics['step_epoch_pbar_acc3'] == 504.0
|
||||||
|
assert trainer.progress_bar_metrics['epoch_end_pbar_acc'] == 1213.0
|
||||||
|
assert trainer.progress_bar_metrics['epoch_end_log_pbar_acc'] == 1214.0
|
||||||
|
assert 'epoch_end_log_acc' not in trainer.progress_bar_metrics
|
||||||
|
assert 'log_acc2' not in trainer.progress_bar_metrics
|
||||||
|
|
||||||
|
# make sure callback metrics didn't change
|
||||||
|
assert trainer.callback_metrics['checkpoint_on'] == 171
|
||||||
|
|
||||||
|
# -----------------------------------------
|
||||||
|
# make sure training outputs what is expected
|
||||||
|
# -----------------------------------------
|
||||||
|
for batch_idx, batch in enumerate(model.train_dataloader()):
|
||||||
|
break
|
||||||
|
|
||||||
|
out = trainer.run_training_batch(batch, batch_idx)
|
||||||
|
assert out.signal == 0
|
||||||
|
assert len(out.batch_log_metrics) == 2
|
||||||
|
|
||||||
|
train_step_out = out.training_step_output_for_epoch_end
|
||||||
|
assert isinstance(train_step_out, TrainResult)
|
||||||
|
|
||||||
|
assert 'minimize' in train_step_out
|
||||||
|
assert 'step_epoch_log_and_pbar_acc1' in train_step_out
|
||||||
|
assert 'step_epoch_log_acc2' in train_step_out
|
||||||
|
|
||||||
|
# make sure the optimizer closure returns the correct things
|
||||||
|
opt_closure_result = trainer.optimizer_closure(batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens)
|
||||||
|
assert opt_closure_result['loss'] == (42.0 * 3) + (15.0 * 3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_auto_callbacks_with_train_loop_only(tmpdir):
|
||||||
|
"""
|
||||||
|
Make sure early stop + checkpoint work with only a train loop
|
||||||
|
"""
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_no_default_callbacks_for_train_loop
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
batches = 3
|
||||||
|
epochs = 3
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=epochs,
|
||||||
|
row_log_interval=1,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
all_losses = trainer.dev_debugger.saved_losses
|
||||||
|
assert len(all_losses) == batches * epochs
|
||||||
|
|
||||||
|
assert trainer.checkpoint_callback.monitor == 'checkpoint_on'
|
||||||
|
assert trainer.early_stop_callback is None
|
||||||
|
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
early_stop_callback=True,
|
||||||
|
max_epochs=epochs,
|
||||||
|
row_log_interval=1,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
assert trainer.early_stop_callback.monitor == 'val_loss'
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_callbacks_with_train_loop_only(tmpdir):
|
||||||
|
"""
|
||||||
|
Make sure early stop + checkpoint work with only a train loop
|
||||||
|
"""
|
||||||
|
os.environ['PL_DEV_DEBUG'] = '1'
|
||||||
|
|
||||||
|
model = DeterministicModel()
|
||||||
|
model.training_step = model.training_step_no_callbacks_result_obj
|
||||||
|
model.training_epoch_end = None
|
||||||
|
model.val_dataloader = None
|
||||||
|
|
||||||
|
batches = 3
|
||||||
|
epochs = 3
|
||||||
|
trainer = Trainer(
|
||||||
|
default_root_dir=tmpdir,
|
||||||
|
max_epochs=epochs,
|
||||||
|
row_log_interval=1,
|
||||||
|
limit_train_batches=batches,
|
||||||
|
weights_summary=None,
|
||||||
|
)
|
||||||
|
trainer.fit(model)
|
||||||
|
|
||||||
|
all_losses = trainer.dev_debugger.saved_losses
|
||||||
|
assert len(all_losses) == batches * epochs
|
||||||
|
|
||||||
|
assert trainer.early_stop_callback is None
|
||||||
|
|
||||||
|
assert len(trainer.dev_debugger.checkpoint_callback_history) == 0
|
||||||
|
assert len(trainer.dev_debugger.early_stopping_history) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_use_callbacks_with_train_loop_only(tmpdir):
    """Fit with a train loop only and verify early stopping + checkpointing behavior."""
    os.environ['PL_DEV_DEBUG'] = '1'

    # model whose training_step logs epoch- and step-level values for the callbacks
    model = DeterministicModel()
    model.training_step = model.training_step_result_log_epoch_and_step_for_callbacks
    model.training_epoch_end = None
    model.val_dataloader = None

    num_batches = 3
    max_allowed_epochs = 300
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=max_allowed_epochs,
        early_stop_callback=True,
        row_log_interval=1,
        limit_train_batches=num_batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # early stopping should have cut training well short of max_epochs
    num_expected_epochs = 10

    # ----------------------------------
    # VERIFY EARLY STOPPING BEHAVIOR
    # ----------------------------------
    # with a train loop only, the early-stop check happens on every epoch
    early_stop_history = trainer.dev_debugger.early_stopping_history
    assert len(early_stop_history) == num_expected_epochs
    best_seen = min(entry['best'] for entry in early_stop_history)
    assert best_seen == 171 + 9

    # every batch index (0..2) should appear once per epoch in the saved losses
    from collections import Counter
    idx_counts = Counter(loss['batch_idx'] for loss in trainer.dev_debugger.saved_losses)
    for batch_idx, count in idx_counts.items():
        assert count == num_expected_epochs
        assert batch_idx in [0, 1, 2]

    # ----------------------------------
    # VERIFY CHECKPOINTING BEHAVIOR
    # ----------------------------------
    ckpt_history = trainer.dev_debugger.checkpoint_callback_history
    assert len(ckpt_history) == 5, '5 ckpts should have been saved'
    for ckpt, expected_epoch in zip(ckpt_history, [0, 1, 2, 3, 6]):
        assert ckpt['epoch'] == expected_epoch
        assert ckpt['monitor'] == 'checkpoint_on'
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_full_train_loop_with_results_obj_dp(tmpdir):
    """Run the full train loop (step / step_end / epoch_end) with Result objects under DP."""
    os.environ['PL_DEV_DEBUG'] = '1'

    num_batches, num_epochs = 10, 3

    # template model wired to the DP result-object hooks; no val/test loops
    model = EvalModelTemplate()
    model.validation_step = None
    model.test_step = None
    model.training_step = model.training_step_full_loop_result_obj_dp
    model.training_step_end = model.training_step_end_full_loop_result_obj_dp
    model.training_epoch_end = model.training_epoch_end_full_loop_result_obj_dp
    model.val_dataloader = None
    model.test_dataloader = None

    trainer = Trainer(
        default_root_dir=tmpdir,
        distributed_backend='dp',
        gpus=[0, 1],
        max_epochs=num_epochs,
        early_stop_callback=True,
        row_log_interval=2,
        limit_train_batches=num_batches,
        weights_summary=None,
    )
    trainer.fit(model)

    # gather every metric key that was logged during the run
    seen_keys = set()
    for logged in trainer.dev_debugger.logged_metrics:
        seen_keys |= set(logged.keys())

    # each hook (step, step_end, epoch_end) must have contributed its metric
    for expected_key in ('train_step_metric', 'train_step_end_metric', 'train_epoch_end_metric'):
        assert expected_key in seen_keys
|
Loading…
Reference in New Issue