# lightning/pytorch_lightning/core/step_result.py

from copy import copy
from typing import Any, Callable, Dict, MutableMapping, Optional, Sequence, Union

import torch
from torch import Tensor


class Result(Dict):

    def __init__(
        self,
        minimize: Optional[Tensor] = None,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Union[Tensor, bool, None] = None,
        hiddens: Optional[Tensor] = None,
    ):
        super().__init__()

        if early_stop_on is not None:
            self.early_stop_on = early_stop_on
        if checkpoint_on is not None and checkpoint_on:
            self.checkpoint_on = checkpoint_on
        if hiddens is not None:
            self.hiddens = hiddens
        if minimize is not None:
            err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
            self._assert_grad_tensor_metric('minimize', minimize, err)
            self.minimize = minimize

        # default to checkpointing on the loss when no explicit metric is given
        if minimize is not None and checkpoint_on is None:
            self.checkpoint_on = minimize.detach()

        self['meta'] = {
            '_internal': {
                '_reduce_on_epoch': False
            }
        }
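    # Example (illustrative sketch, not part of the original file): passing a
    # loss as `minimize` also sets `checkpoint_on` to the detached loss when no
    # explicit checkpoint metric is given:
    #
    #     loss = ...                    # any tensor with a grad_fn, e.g. from a forward pass
    #     result = Result(minimize=loss)
    #     assert not result.checkpoint_on.requires_grad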
    def __getattr__(self, key: str) -> Any:
        try:
            if key == 'callback_metrics':
                return self.get_callback_metrics()
            elif key == 'batch_log_metrics':
                return self.get_batch_log_metrics()
            elif key == 'batch_pbar_metrics':
                return self.get_batch_pbar_metrics()
            elif key == 'epoch_log_metrics':
                return self.get_epoch_log_metrics()
            elif key == 'epoch_pbar_metrics':
                return self.get_epoch_pbar_metrics()
            else:
                return self[key]
        except KeyError:
            return None
    def __setattr__(self, key: str, val: Union[Tensor, Any]):
        # ensure reserved keys are tensors and detached
        if key in {'hiddens', 'checkpoint_on', 'early_stop_on'}:
            self._assert_tensor_metric(key, val)
            if val is not None and isinstance(val, torch.Tensor):
                val = val.detach()

        # ensure anything else that is a tensor is detached
        elif isinstance(val, torch.Tensor) and key != 'minimize':
            val = val.detach()

        self[key] = val
    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
        if potential_metric is not None and not isinstance(potential_metric, bool):
            assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'

    def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
        if x is not None:
            assert isinstance(x, Tensor), f'{name} must be a torch.Tensor'

            m = f'{name} must have a computational graph.'
            if additional_err:
                m += f' {additional_err}'
            assert x.grad_fn is not None, m
    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        if 'meta' not in self:
            self.__setitem__('meta', {})

        # if user requests both step and epoch, then we split the metric in two automatically
        # one will be logged per step, the other per epoch
        if on_step and on_epoch:
            # set step version
            step_name = f'step_{name}'
            self.__set_meta(step_name, value, prog_bar, logger, on_step=True, on_epoch=False, reduce_fx=reduce_fx)
            self.__setitem__(step_name, value)

            # set epoch version
            epoch_name = f'epoch_{name}'
            self.__set_meta(epoch_name, value, prog_bar, logger, on_step=False, on_epoch=True, reduce_fx=reduce_fx)
            self.__setitem__(epoch_name, value)
        else:
            self.__set_meta(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx)

            # set the value
            self.__setitem__(name, value)
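    # Example (illustrative): requesting both on_step and on_epoch splits the
    # metric into 'step_<name>' and 'epoch_<name>' entries:
    #
    #     result.log('acc', acc, on_step=True, on_epoch=True)
    #     # -> result['step_acc'] is logged per step, result['epoch_acc'] is
    #     #    reduced with `reduce_fx` at epoch end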
    def __set_meta(
        self,
        name: str,
        value: Any,
        prog_bar: bool,
        logger: bool,
        on_step: bool,
        on_epoch: bool,
        reduce_fx: Callable,
    ):
        # set the meta for the item
        meta_value = value
        meta = dict(
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            value=meta_value
        )
        self['meta'][name] = meta

        # track whether any input requires reduction on epoch end
        _internal = self['meta']['_internal']
        _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)
    def get_callback_metrics(self) -> dict:
        result = {
            'early_stop_on': self.early_stop_on,
            'checkpoint_on': self.checkpoint_on
        }
        return result
    def get_batch_log_metrics(self) -> dict:
        """
        Gets the metrics to log at the end of the batch step
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['logger'] and options['on_step']:
                result[k] = self[k]
        return result

    def get_epoch_log_metrics(self) -> dict:
        """
        Gets the metrics to log at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['logger'] and options['on_epoch']:
                result[k] = self[k]
        return result

    def get_epoch_pbar_metrics(self) -> dict:
        """
        Gets the metrics to display in the progress bar at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['prog_bar'] and options['on_epoch']:
                result[k] = self[k]
        return result

    def get_batch_pbar_metrics(self) -> dict:
        """
        Gets the metrics to display in the progress bar at the end of the batch step
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue
            if options['prog_bar'] and options['on_step']:
                result[k] = self[k]
        return result
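    # Example (illustrative): the getters above filter on the flags stored in
    # 'meta', so a metric logged with logger=True and on_step=True shows up
    # only in get_batch_log_metrics():
    #
    #     result.log('lr', lr, logger=True, prog_bar=False, on_step=True, on_epoch=False)
    #     result.get_batch_log_metrics()   # -> {'lr': ...}
    #     result.get_epoch_log_metrics()   # -> {}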
    def detach(self):
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__setitem__(k, v.detach())

    def __repr__(self):
        self_copy = self.copy()

        if 'meta' in self_copy:
            del self_copy['meta']

        return str(self_copy)

    def __str__(self):
        # mirror __repr__: guard against a missing 'meta' key and avoid
        # shadowing the imported `copy` function with a local variable
        self_copy = self.copy()

        if 'meta' in self_copy:
            del self_copy['meta']

        return str(self_copy)

    def __copy__(self):
        newone = type(self)()
        for k, v in self.items():
            newone[k] = copy(v)
        return newone
    @classmethod
    def gather(cls, outputs):
        meta = outputs[0].get('meta')
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        if meta:
            result['meta'] = meta
        return result

    @classmethod
    def reduce_on_epoch_end(cls, outputs):
        meta = outputs[0]['meta']
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        for k, option in meta.items():
            if k == '_internal':
                continue

            if option['on_epoch']:
                fx = option['reduce_fx']
                result[k] = fx(result[k])

        result['meta'] = meta
        return result

    @property
    def should_reduce_on_epoch_end(self) -> bool:
        return self['meta']['_internal']['_reduce_on_epoch']
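    # Example (illustrative): reducing a list of per-step results at epoch end.
    # Each metric logged with on_epoch=True is stacked across steps and then
    # reduced with its `reduce_fx` (torch.mean by default):
    #
    #     outputs = [step_result_1, step_result_2, ...]   # one Result per step
    #     epoch_result = Result.reduce_on_epoch_end(outputs)
    #     epoch_result['epoch_acc']   # mean of the per-step 'epoch_acc' tensors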


def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
    # guard the default: without this, `k not in result` below fails on None
    if result is None:
        result = {}

    for out in outputs:
        if 'meta' in out:
            del out['meta']

        for k, v in out.items():
            if isinstance(v, dict):
                v = recursive_gather([v], result)

            if k not in result:
                result[k] = []

            result[k].append(v)

    return result


def recursive_stack(result: MutableMapping):
    for k, v in result.items():
        if isinstance(v, dict):
            recursive_stack(v)

        if isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            v = torch.stack(v)
            result[k] = v
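# Example (illustrative): recursive_gather collects values key-by-key into
# lists, and recursive_stack turns a list of tensors into a single stacked
# tensor, including inside nested dicts:
#
#     gathered = recursive_gather([{'a': torch.tensor(1.)}, {'a': torch.tensor(2.)}], {})
#     recursive_stack(gathered)
#     gathered['a']   # tensor([1., 2.])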


class TrainResult(Result):

    def __init__(
        self,
        minimize: Optional[Tensor] = None,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Union[Tensor, bool, None] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """
        Used in the train loop to auto-log to a logger or progress bar without needing to define
        a training_step_end or training_epoch_end method

        Example::

            def training_step(self, batch, batch_idx):
                loss = ...
                result = pl.TrainResult(loss)
                result.log('train_loss', loss)
                return result

            # without a val/test loop, can still model checkpoint or early stop
            def training_step(self, batch, batch_idx):
                loss = ...
                result = pl.TrainResult(loss, early_stop_on=loss, checkpoint_on=loss)
                result.log('train_loss', loss)
                return result

        Args:
            minimize: metric to minimize (the loss)
            early_stop_on: metric to early stop on
            checkpoint_on: metric to checkpoint on
            hiddens: tensor of hiddens to pass to the next step
        """
        super().__init__(minimize, early_stop_on, checkpoint_on, hiddens)
    def log(
        self,
        name,
        value,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = True,
        on_epoch: bool = False,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """
        Log a key, value

        Example::

            result.log('train_loss', loss)

            # defaults used
            result.log(
                name,
                value,
                on_step=True,
                on_epoch=False,
                logger=True,
                prog_bar=False,
                reduce_fx=torch.mean,
                enable_graph=False
            )

        Args:
            name: key name
            value: value to log
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs the metric at this training step
            on_epoch: if True, logs the metric aggregated over the training epoch
            reduce_fx: reduction function used at epoch end (torch.mean by default)
            enable_graph: if True, will not auto-detach the graph
        """
        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
    def log_dict(
        self,
        dictionary: dict,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """
        Log a dictionary of values at once

        Example::

            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
            result.log_dict(values)

        Args:
            dictionary: key-value pairs to log
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs each metric at this step
            on_epoch: if True, logs each metric aggregated over the epoch
            reduce_fx: reduction function used at epoch end (torch.mean by default)
            enable_graph: if True, will not auto-detach the graph
        """
        for k, v in dictionary.items():
            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
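    # Example (illustrative): TrainResult.log defaults to per-step logging
    # (on_step=True, on_epoch=False), so 'train_loss' below appears in
    # get_batch_log_metrics() but not in get_epoch_log_metrics():
    #
    #     result = TrainResult(minimize=loss)
    #     result.log('train_loss', loss)
    #     result.get_batch_log_metrics()   # -> {'train_loss': ...}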


class EvalResult(Result):

    def __init__(
        self,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Optional[Tensor] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """
        Used in the val/test loop to auto-log to a logger or progress bar without needing to define
        a _step_end or _epoch_end method

        Example::

            def validation_step(self, batch, batch_idx):
                loss = ...
                result = EvalResult()
                result.log('val_loss', loss)
                return result

            def test_step(self, batch, batch_idx):
                loss = ...
                result = EvalResult()
                result.log('test_loss', loss)
                return result

        Args:
            early_stop_on: metric to early stop on
            checkpoint_on: metric to checkpoint on
            hiddens: tensor of hiddens to pass to the next step
        """
        super().__init__(None, early_stop_on, checkpoint_on, hiddens)
    def log(
        self,
        name,
        value,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """
        Log a key, value

        Example::

            result.log('val_loss', loss)

            # defaults used
            result.log(
                name,
                value,
                on_step=False,
                on_epoch=True,
                logger=True,
                prog_bar=False,
                reduce_fx=torch.mean
            )

        Args:
            name: key name
            value: value to log
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs the metric at this validation/test step
            on_epoch: if True, logs the metric aggregated over the validation or test epoch
            reduce_fx: reduction function used at epoch end (torch.mean by default)
            enable_graph: if True, will not auto-detach the graph
        """
        super().log(name, value, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
    def log_dict(
        self,
        dictionary: dict,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        enable_graph: bool = False,
    ):
        """
        Log a dictionary of values at once

        Example::

            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
            result.log_dict(values)

        Args:
            dictionary: key-value pairs to log
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs each metric at this step
            on_epoch: if True, logs each metric aggregated over the epoch
            reduce_fx: reduction function used at epoch end (torch.mean by default)
            enable_graph: if True, will not auto-detach the graph
        """
        for k, v in dictionary.items():
            self.log(k, v, prog_bar, logger, on_step, on_epoch, reduce_fx, enable_graph)
    def get_callback_metrics(self) -> dict:
        result = {
            'val_early_stop_on': self.early_stop_on,
            'val_checkpoint_on': self.checkpoint_on
        }
        return result
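    # Example (illustrative): unlike the base class, EvalResult prefixes its
    # callback metrics with 'val_' so the trainer can tell them apart from the
    # training-loop values:
    #
    #     result = EvalResult(early_stop_on=loss, checkpoint_on=loss)
    #     result.get_callback_metrics()
    #     # -> {'val_early_stop_on': ..., 'val_checkpoint_on': ...}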