# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""

import numbers
import os
from copy import copy
from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

import torch
from torch import Tensor

from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import sync_ddp_if_available


class Result(Dict):
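    """Dict-like result object returned from a LightningModule step, used to drive logging,
    checkpointing and early stopping.

    A minimal sketch of the intended use inside ``training_step`` (the loss here is an
    illustrative placeholder):

    Example::

        def training_step(self, batch, batch_idx):
            loss = ...  # a tensor with a grad_fn
            result = Result(minimize=loss)
            result.log('train_loss', loss)
            return result
    """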

    def __init__(
        self,
        minimize: Optional[Tensor] = None,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Optional[Union[Tensor, bool]] = None,
        hiddens: Optional[Tensor] = None,
    ):
        super().__init__()

        # temporary until dict results are deprecated
        os.environ['PL_USING_RESULT_OBJ'] = '1'

        if early_stop_on is not None:
            self.early_stop_on = early_stop_on
        if checkpoint_on is not None and checkpoint_on:
            self.checkpoint_on = checkpoint_on
        if hiddens is not None:
            self.hiddens = hiddens.detach()
        if minimize is not None:
            err = 'Minimize can only be used in training_step, training_step_end, training_epoch_end'
            self._assert_grad_tensor_metric('minimize', minimize, err)
            self.minimize = minimize

        if minimize is not None and checkpoint_on is None:
            self.checkpoint_on = minimize.detach()

        self['meta'] = {'_internal': {'_reduce_on_epoch': False, 'batch_sizes': []}}

    def __getitem__(self, key: Union[str, Any]) -> Any:
        try:
            return super().__getitem__(key)
        except KeyError:
            return super().__getitem__(f'{key}_step')

    def __getattr__(self, key: str) -> Any:
        try:
            if key == 'callback_metrics':
                return self.get_callback_metrics()
            elif key == 'batch_log_metrics':
                return self.get_batch_log_metrics()
            elif key == 'batch_pbar_metrics':
                return self.get_batch_pbar_metrics()
            elif key == 'epoch_log_metrics':
                return self.get_epoch_log_metrics()
            elif key == 'epoch_pbar_metrics':
                return self.get_epoch_pbar_metrics()
            else:
                return self[key]
        except KeyError:
            return None

    def __setattr__(self, key: str, val: Union[Tensor, Any]):
        # ensure reserved keys are tensors and detached
        if key in {'checkpoint_on', 'early_stop_on'}:
            self._assert_tensor_metric(key, val)
            if val is not None and isinstance(val, torch.Tensor):
                val = val.detach()

        # ensure anything else that is a tensor is detached
        elif isinstance(val, torch.Tensor) and key != 'minimize':
            val = val.detach()

        self[key] = val

    def __getstate__(self):
        return self

    def __setstate__(self, d):
        self.update(d)

    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
        if potential_metric is not None and not isinstance(potential_metric, bool):
            assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'

    def _assert_grad_tensor_metric(self, name: str, x: Union[torch.Tensor, Any], additional_err: str = ''):
        if x is not None:
            assert isinstance(x, Tensor), f'{name} must be a torch.Tensor'
            m = f'{name} must have a computational graph.'

            if additional_err:
                m += f' {additional_err}'
            assert x.grad_fn is not None, m

    def log(
        self,
        name: str,
        value: Any,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
        sync_fn: Optional[Callable] = None,
        dataloader_idx: Optional[int] = None,
        device: Optional[torch.device] = None,
    ):
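        """
        Log the metric ``name`` for the logger and/or the progress bar, at the step and/or the
        epoch level. When both ``on_step=True`` and ``on_epoch=True``, the metric is automatically
        forked into a ``{name}_step`` and a ``{name}_epoch`` version.

        A minimal sketch of the forking behavior (the key names shown are what ``log`` produces):

        Example::

            result.log('loss', loss, on_step=True, on_epoch=True)
            # populates 'loss', 'loss_step' and 'loss_epoch' in the result
        """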
        # no metrics should be logged with graphs
        if not enable_graph and isinstance(value, torch.Tensor):
            value = value.detach()

        # sync across workers when using distributed training
        sync_fn = sync_fn or sync_ddp_if_available
        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
            # TODO: Find a way to make the reduction only once, so we don't need to clone.
            if is_dist_initialized and isinstance(value, torch.Tensor):
                value = value.clone()
            else:
                value = torch.tensor(value, device=device, dtype=torch.float)
            value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

        if 'meta' not in self:
            self.__setitem__('meta', {})

        # if user requests both step and epoch, then we split the metric in two automatically
        # one will be logged per step. the other per epoch
        was_forked = False
        if on_step and on_epoch:
            was_forked = True

            # set step version
            step_name = f'{name}_step'

            self.__set_meta(
                step_name,
                value,
                prog_bar,
                logger,
                on_step=True,
                on_epoch=False,
                reduce_fx=reduce_fx,
                tbptt_reduce_fx=tbptt_reduce_fx,
                tbptt_pad_token=tbptt_pad_token,
                forked=False,
                dataloader_idx=dataloader_idx,
            )

            self.__setitem__(step_name, value)

            # set epoch version
            epoch_name = f'{name}_epoch'

            self.__set_meta(
                epoch_name,
                value,
                prog_bar,
                logger,
                on_step=False,
                on_epoch=True,
                reduce_fx=reduce_fx,
                tbptt_reduce_fx=tbptt_reduce_fx,
                tbptt_pad_token=tbptt_pad_token,
                forked=False,
                dataloader_idx=dataloader_idx,
            )
            self.__setitem__(epoch_name, value)

        # always log the original metric
        self.__set_meta(
            name,
            value,
            prog_bar,
            logger,
            on_step,
            on_epoch,
            reduce_fx,
            tbptt_reduce_fx=tbptt_reduce_fx,
            tbptt_pad_token=tbptt_pad_token,
            forked=was_forked,
            dataloader_idx=dataloader_idx,
        )

        # set the value
        self.__setitem__(name, value)

    def __set_meta(
        self,
        name: str,
        value: Any,
        prog_bar: bool,
        logger: bool,
        on_step: bool,
        on_epoch: bool,
        reduce_fx: Callable,
        tbptt_pad_token: int,
        tbptt_reduce_fx: Callable,
        forked: bool,
        dataloader_idx: Union[int, None],
    ):
        # set the meta for the item
        meta_value = value
        meta = dict(
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            value=meta_value,
            tbptt_reduce_fx=tbptt_reduce_fx,
            tbptt_pad_token=tbptt_pad_token,
            forked=forked,
            dataloader_idx=dataloader_idx,
        )

        self['meta'][name] = meta

        # track whether any input requires reduction on epoch end
        _internal = self['meta']['_internal']
        _internal['_reduce_on_epoch'] = max(_internal['_reduce_on_epoch'], on_epoch)

    def track_batch_size(self, batch):
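        """Infer the batch size from ``batch`` and record it for weighted epoch-end reduction.

        A small sketch (the tensor shape is illustrative):

        Example::

            result.track_batch_size(torch.zeros(32, 10))  # records a batch size of 32
        """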
        batch_size = Result.extract_batch_size(batch)
        Result.attach_batch_size(batch_size, self)

    @staticmethod
    def extract_batch_size(batch):
        try:
            batch_size = Result.unpack_batch_size(batch)
        except RecursionError:
            batch_size = 1
        return batch_size

    @staticmethod
    def attach_batch_size(batch_size: Union[int, None], result: 'Result') -> None:
        if batch_size is not None:
            meta = result['meta']
            meta['_internal']['batch_sizes'].append(batch_size)

    def get_batch_sizes(self):
        meta = self['meta']
        return torch.tensor(meta['_internal']['batch_sizes'])

    def get_callback_metrics(self) -> dict:
        result = {'early_stop_on': self.early_stop_on, 'checkpoint_on': self.checkpoint_on}

        return result

    def _add_dataloader_idx(self, k: str, dataloader_idx: Union[int, None], add_dataloader_idx: bool) -> str:
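        """Suffix the key with the dataloader index when multiple dataloaders are in play.

        A small sketch (the key and index are illustrative):

        Example::

            self._add_dataloader_idx('val_acc', dataloader_idx=0, add_dataloader_idx=True)
            # -> 'val_acc/dataloader_idx_0'
        """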
        if dataloader_idx is not None and add_dataloader_idx:
            return f"{k}/dataloader_idx_{dataloader_idx}"
        return k

    def get_batch_log_metrics(self, include_forked_originals=True, add_dataloader_idx=False) -> dict:
        """
        Gets the metrics to log at the end of the batch step
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['forked'] and not include_forked_originals:
                continue

            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['logger'] and options['on_step']:
                if isinstance(self[k], Metric):
                    result[dl_key] = self[k]._forward_cache.detach()
                else:
                    result[dl_key] = self[k]

        return result

    def get_epoch_log_metrics(self, add_dataloader_idx=False) -> dict:
        """
        Gets the metrics to log at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['forked']:
                continue

            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['logger'] and options['on_epoch']:
                if isinstance(self[k], Metric):
                    result[dl_key] = self[k].compute().detach()
                else:
                    result[dl_key] = self[k]

            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
                # compute metric on epoch anyway so state does not accumulate
                self[k].compute()

        return result

    def get_epoch_pbar_metrics(self, add_dataloader_idx=False):
        """
        Gets the metrics to include in the progress bar at the end of the epoch
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['forked']:
                continue

            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['prog_bar'] and options['on_epoch']:
                if isinstance(self[k], Metric):
                    result[dl_key] = self[k].compute().detach()
                else:
                    result[dl_key] = self[k]

            if k in self and not options['on_epoch'] and isinstance(self[k], Metric):
                # compute metric on epoch anyway so state does not accumulate
                self[k].compute()

        return result

    def get_forked_metrics(self, add_dataloader_idx=False):
        """
        Gets the metrics that were forked into ``_step``/``_epoch`` versions
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['forked']:
                if isinstance(self[k], Metric):
                    result[dl_key] = self[k].compute().detach()
                else:
                    result[dl_key] = self[k]

        return result

    def get_batch_pbar_metrics(self, include_forked_originals=True, add_dataloader_idx=False):
        """
        Gets the metrics to include in the progress bar at the end of the batch step
        """
        result = {}

        meta = self['meta']
        for k, options in meta.items():
            if k == '_internal':
                continue

            if options['forked'] and not include_forked_originals:
                continue

            dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx)

            if options['prog_bar'] and options['on_step']:
                if isinstance(self[k], Metric):
                    result[dl_key] = self[k]._forward_cache
                else:
                    result[dl_key] = self[k]

        return result

    def detach(self):
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__setitem__(k, v.detach())

    def to(self, *args, **kwargs):
        """Move all self attributes to the given device."""
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                self.__setitem__(k, v.to(*args, **kwargs))

    def cpu(self):
        """Move all self attributes to CPU."""
        self.to(torch.device("cpu"))

    def __repr__(self):
        self_copy = self.copy()

        if 'meta' in self_copy:
            del self_copy['meta']

        return str(self_copy)

    def __str__(self):
        copy = self.copy()
        del copy['meta']

        return str(copy)

    def __copy__(self):
        newone = type(self)()
        for k, v in self.items():
            if isinstance(v, torch.Tensor):
                v = v.detach()
            newone[k] = copy(v)
        return newone

    @staticmethod
    def unpack_batch_size(sample):
        """
        Recursively unpack ``sample`` to find a ``torch.Tensor``.
        Returns the tensor's first dimension when one is found, the length for strings,
        or 1 when it hits an empty or non-iterable value.
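
        A small sketch (the shapes are illustrative):

        Example::

            batch = {'x': torch.zeros(32, 3, 28, 28), 'y': torch.zeros(32)}
            Result.unpack_batch_size(batch)  # -> 32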
        """
        if isinstance(sample, torch.Tensor):
            size = sample.size(0)
        elif isinstance(sample, str):
            return len(sample)
        elif isinstance(sample, dict):
            sample = next(iter(sample.values()), 1)
            size = Result.unpack_batch_size(sample)
        elif isinstance(sample, Iterable):
            sample = next(iter(sample), 1)
            size = Result.unpack_batch_size(sample)
        else:
            size = 1
        return size

    @classmethod
    def gather(cls, outputs):
        meta = outputs[0].get('meta')
        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        if meta:
            result['meta'] = meta
        return result

    @classmethod
    def padded_gather(cls, outputs):
        meta = outputs[0].get('meta')
        result = cls()
        result = recursive_gather(outputs, result)

        # find the padding used for other values
        default_padding_idx = 0
        for name, value in result.items():
            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):
                if name not in {'checkpoint_on', 'early_stop_on', 'minimize'}:
                    default_padding_idx = meta[name]['tbptt_pad_token']
                    break

        # pad across each key individually
        for name, value in result.items():
            is_reserved = name in {'checkpoint_on', 'early_stop_on', 'minimize'}
            if isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor):

                if is_reserved:
                    padding_key = default_padding_idx
                else:
                    padding_key = meta[name]['tbptt_pad_token']

                padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
                result[name] = padded

                # also update the result
                if meta and not is_reserved:
                    meta[name]['value'] = padded
        if meta:
            result['meta'] = meta
        return result

    @classmethod
    def reduce_on_epoch_end(cls, outputs):
        # get the batch sizes for all outputs
        batch_sizes = []
        meta = {}
        for x in outputs:
            batch_sizes.append(x.get_batch_sizes())
            meta.update(x['meta'])

        batch_sizes = torch.stack(batch_sizes).view(-1)

        result = cls()
        result = recursive_gather(outputs, result)
        recursive_stack(result)

        for k, option in meta.items():
            if k == '_internal' or isinstance(result[k], Metric):
                continue

            # for forked metrics don't reduce, just take the last val
            if option['forked']:
                result[k] = choose_last(result[k])
                continue

            if option['on_epoch']:
                fx = option['reduce_fx']
                if fx == torch.mean:
                    if isinstance(result[k], list):
                        result[k] = torch.tensor(result[k]).float()
                    try:
                        reduced_val = weighted_mean(result[k], batch_sizes)
                    # todo: specify the expected Exceptions to come
                    except Exception:
                        reduced_val = torch.mean(result[k])
                else:
                    reduced_val = fx(result[k])

                result[k] = reduced_val
            else:
                del result[k]

        result['meta'] = meta
        return result

    @classmethod
    def reduce_across_time(cls, time_outputs):
        # auto-reduce across time for tbptt
        meta = time_outputs[0]['meta']

        # in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this
        if 'extra' in time_outputs[0]:
            [x.pop('extra', None) for x in time_outputs]

        result = cls()
        result = recursive_gather(time_outputs, result)
        recursive_stack(result)

        for k, value in result.items():
            if k in ['meta', 'extra'] or isinstance(value, Metric):
                continue

            # pick the reduce fx
            if k in ['checkpoint_on', 'early_stop_on', 'minimize']:
                tbptt_reduce_fx = torch.mean
            else:
                tbptt_reduce_fx = meta[k]['tbptt_reduce_fx']

            if isinstance(value, list):
                value = torch.tensor(value)

            if isinstance(value, dict):
                # TODO: recursive reduce:
                _recursive_fx_apply(value, tbptt_reduce_fx)
            else:
                result[k] = tbptt_reduce_fx(value.float())

        result['meta'] = meta
        return result

    def dp_reduce(self):
        for k, value in self.items():
            if k == 'meta' or isinstance(value, Metric):
                continue

            if isinstance(value, list):
                value = torch.tensor(value)

            self[k] = value.mean(dim=-1)

    @property
    def should_reduce_on_epoch_end(self) -> bool:
        return self['meta']['_internal']['_reduce_on_epoch']

    def drop_hiddens(self):
        if 'hiddens' in self:
            del self['hiddens']

    def rename_keys(self, map_dict: dict):
        """
        Maps key values to the target values. Useful when renaming variables en masse.

        Args:
            map_dict: dict mapping each existing key name to its new name
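
        A small sketch (the key names are illustrative):

        Example::

            result.rename_keys({'val_loss': 'loss'})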
        """
        meta = self.meta
        for source, dest in map_dict.items():
            # map the main keys
            self[dest] = self[source]
            del self[source]

            # map meta
            meta[dest] = meta[source]
            del meta[source]


def choose_last(x):
    if isinstance(x, (torch.Tensor, list)):
        return x[-1]
    if isinstance(x, dict):
        for k, v in x.items():
            x[k] = x[k][-1]
        # return the dict so callers that reassign the result keep the values
        return x


def recursive_gather(outputs: Sequence[dict], result: Optional[MutableMapping] = None) -> Optional[MutableMapping]:
    # honor the Optional default by starting from an empty mapping
    if result is None:
        result = {}

    for out in outputs:
        if 'meta' in out:
            del out['meta']

        for k, v in out.items():
            # support manual opt where the user does not return a minimize key
            if k == 'minimize' and v is None:
                continue

            if isinstance(v, dict):
                in_d = result.get(k, {})
                v = recursive_gather([v], in_d)
                result[k] = v
            else:
                if isinstance(v, Metric):
                    # if v is a metric, just keep one of them,
                    # don't keep on adding a list of them
                    result[k] = v
                else:
                    if k not in result:
                        result[k] = []
                    result[k].append(v)

    return result


def recursive_stack(result: MutableMapping):
    for k, v in result.items():
        if isinstance(v, dict):
            recursive_stack(v)

        result[k] = collate_tensors(v)


def _recursive_fx_apply(input: dict, fx):
    for k, v in input.items():
        if isinstance(v, list):
            v = torch.tensor(v)

        if isinstance(v, torch.Tensor):
            v = fx(v.float())
            input[k] = v
        else:
            _recursive_fx_apply(v, fx)


def collate_tensors(items: Union[List, Tuple]) -> Union[Tensor, List, Tuple]:
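    """Collate a list/tuple of tensors into a single tensor when their shapes allow it.

    A small sketch of the behavior (the shapes are illustrative):

    Example::

        collate_tensors([torch.tensor(1.), torch.tensor(2.)])    # scalars -> stacked, shape [2]
        collate_tensors([torch.zeros(2, 3), torch.zeros(4, 3)])  # matching trailing dims -> cat, shape [6, 3]
        collate_tensors([torch.zeros(2, 3), torch.zeros(2, 4)])  # mismatched -> returned unchanged
    """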
    if not items or not isinstance(items, (list, tuple)) or any(not isinstance(item, Tensor) for item in items):
        # items is not a sequence, empty, or contains non-tensors
        return items

    if all(item.ndim == 0 for item in items):
        # all tensors are scalars, we need to stack
        return torch.stack(items)

    if all(item.ndim >= 1 and item.shape[1:] == items[0].shape[1:] for item in items):
        # we can concatenate along the first dimension
        return torch.cat(items)

    return items


class EvalResult(Result):
    def __init__(
        self,
        early_stop_on: Optional[Tensor] = None,
        checkpoint_on: Optional[Tensor] = None,
        hiddens: Optional[Tensor] = None,
    ):
        """
        Used in val/train loop to auto-log to a logger or progress bar without needing to define
        a _step_end or _epoch_end method

        Example::

            def validation_step(self, batch, batch_idx):
                loss = ...
                result = EvalResult()
                result.log('val_loss', loss)
                return result

            def test_step(self, batch, batch_idx):
                loss = ...
                result = EvalResult()
                result.log('test_loss', loss)
                return result

        Args:
            early_stop_on: Metric to early stop on.
                Should be a one element tensor if combined with default
                :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping`.
                If this result is returned by
                :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
                the specified value will be averaged across all steps.
            checkpoint_on: Metric to checkpoint on.
                Should be a one element tensor if combined with default checkpoint callback.
                If this result is returned by
                :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
                the specified value will be averaged across all steps.
            hiddens: Hidden states to carry between truncated back-propagation steps.
        """
        super().__init__(None, early_stop_on, checkpoint_on, hiddens)

    def log(
        self,
        name,
        value,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
    ):
        """
        Log a key/value pair.

        Example::

            result.log('val_loss', loss)

            # defaults used
            result.log(
                name,
                value,
                on_step=False,
                on_epoch=True,
                logger=True,
                prog_bar=False,
                reduce_fx=torch.mean
            )

        Args:
            name: key name
            value: value to log
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs the output of validation_step or test_step
            on_epoch: if True, logs the output of the training loop aggregated
            reduce_fx: reduction function for epoch aggregation; torch.mean by default
            tbptt_reduce_fx: function to reduce on truncated back prop
            tbptt_pad_token: token to use for padding
            enable_graph: if True, will not auto detach the graph
            sync_dist: if True, reduces the metric across GPUs/TPUs
            sync_dist_op: the op to sync across
            sync_dist_group: the ddp group
        """
        super().log(
            name=name,
            value=value,
            prog_bar=prog_bar,
            logger=logger,
            on_step=on_step,
            on_epoch=on_epoch,
            reduce_fx=reduce_fx,
            enable_graph=enable_graph,
            sync_dist=sync_dist,
            sync_dist_group=sync_dist_group,
            sync_dist_op=sync_dist_op,
            tbptt_pad_token=tbptt_pad_token,
            tbptt_reduce_fx=tbptt_reduce_fx,
        )

    def log_dict(
        self,
        dictionary: dict,
        prog_bar: bool = False,
        logger: bool = True,
        on_step: bool = False,
        on_epoch: bool = True,
        reduce_fx: Callable = torch.mean,
        tbptt_reduce_fx: Callable = torch.mean,
        tbptt_pad_token: int = 0,
        enable_graph: bool = False,
        sync_dist: bool = False,
        sync_dist_op: Union[Any, str] = 'mean',
        sync_dist_group: Optional[Any] = None,
    ):
        """
        Log a dictionary of values at once.

        Example::

            values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
            result.log_dict(values)

        Args:
            dictionary: key value pairs (str, tensors)
            prog_bar: if True logs to the progress bar
            logger: if True logs to the logger
            on_step: if True logs the output of validation_step or test_step
            on_epoch: if True, logs the output of the training loop aggregated
            reduce_fx: reduction function for epoch aggregation; torch.mean by default
            tbptt_reduce_fx: function to reduce on truncated back prop
            tbptt_pad_token: token to use for padding
            enable_graph: if True, will not auto detach the graph
            sync_dist: if True, reduces the metric across GPUs/TPUs
            sync_dist_op: the op to sync across
            sync_dist_group: the ddp group
        """
        for k, v in dictionary.items():
            self.log(
                name=k,
                value=v,
                prog_bar=prog_bar,
                logger=logger,
                on_step=on_step,
                on_epoch=on_epoch,
                reduce_fx=reduce_fx,
                enable_graph=enable_graph,
                sync_dist=sync_dist,
                sync_dist_group=sync_dist_group,
                sync_dist_op=sync_dist_op,
                tbptt_pad_token=tbptt_pad_token,
                tbptt_reduce_fx=tbptt_reduce_fx,
            )

    def get_callback_metrics(self) -> dict:
        result = {}
        if self.early_stop_on:
            result['early_stop_on'] = self.early_stop_on
        if self.checkpoint_on:
            result['checkpoint_on'] = self.checkpoint_on
        return result

    def write(self, name: str, values: Union[Tensor, list], filename: str = 'predictions.pt'):
        """Add a feature name and value pair to the collection of predictions that will be written to
        disk on `validation_end` or `test_end`. If running on multiple GPUs, you will get `n_gpu`
        separate prediction files with the rank prepended onto the filename.

        Example::

            result = pl.EvalResult()
            result.write('ids', [0, 1, 2])
            result.write('preds', ['cat', 'dog', 'dog'])

        Args:
            name: Feature name that will turn into column header of predictions file
            values: Flat tensor or list of row values for given feature column 'name'.
            filename: Filepath where your predictions will be saved. Defaults to 'predictions.pt'.
        """
        # Type check the incoming arguments
        if not isinstance(name, str):
            raise ValueError(f"Expected str for 'name' but got {type(name)}")
        if not isinstance(filename, str):
            raise ValueError(f"Expected str for 'filename' but got {type(filename)}")

        if isinstance(values, Tensor):
            values = values.detach()

        preds = getattr(self, 'predictions', None)
        if preds is None:
            self.predictions = {filename: {name: values}}
        elif filename not in preds:
            preds[filename] = {name: values}
        elif name not in preds[filename]:
            preds[filename][name] = values
        elif isinstance(values, Tensor):
            preds[filename][name] = torch.cat((preds[filename][name], values))
        elif isinstance(values, list):
            preds[filename][name].extend(values)

    def write_dict(self, predictions_dict, filename='predictions.pt'):
        """Calls EvalResult.write() for each key-value pair in predictions_dict.

        It is recommended that you use this function call instead of .write if you need to
        store more than one column of predictions in your output file.

        Example::

            predictions_to_write = {'preds': ['cat', 'dog'], 'ids': tensor([0, 1])}
            result.write_dict(predictions_to_write)

        Args:
            predictions_dict: Dict of predictions to store and then write to filename at eval end.
            filename: File where your predictions will be stored. Defaults to 'predictions.pt'.
        """
        for k, v in predictions_dict.items():
            self.write(k, v, filename)


def weighted_mean(result, weights):
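    """Reduce ``result`` by a batch-size-weighted mean.

    A worked example with illustrative numbers:

    Example::

        weighted_mean(torch.tensor([1., 3.]), torch.tensor([1., 3.]))
        # (1*1 + 3*3) / (1 + 3) = tensor(2.5)
    """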
    if isinstance(result, dict):
        _process_dataloader_aggregated_steps(result, weights)
    else:
        if isinstance(result, list):
            result = torch.tensor(result)

        weights = weights.to(result.device)[:result.size(0)]
        numerator = torch.dot(result.float(), weights.transpose(-1, 0).float())
        result = numerator / weights.sum().float()
    return result


def _process_dataloader_aggregated_steps(result, weights):
    internal_keys = {'meta'}

    moved = False

    for k, v in result.items():
        if k in internal_keys:
            continue

        # make sure v is a tensor
        if not isinstance(v, torch.Tensor):
            v = torch.tensor(v)

        # move the weights to the value's device only once
        if not moved:
            weights = weights.to(v.device)
            moved = True

        # truncate the weights to match the number of reduced values
        weights_t = weights[:v.size(0)]

        # weighted mean (divide by the truncated weights so the mean stays normalized,
        # matching the behavior of ``weighted_mean`` above)
        numerator = torch.dot(v.float(), weights_t.transpose(-1, 0).float())
        v = numerator / weights_t.sum().float()
        result[k] = v