New logger connector code (#7882)

* New logger connector code

* Update CHANGELOG

* Update requirements

* Fix import path

* Add new suffix

* Tests

* Minor changes

* Rename and reorder

* Fix import

* Formatting

* Fix with seed_everything?

* Fix test?

* Fix test?

* Minor change

* Minor changes

* Minor changes

* Force float

* Fix minimal bug

* Fix minimal bug

* Update with latest changes

* Fix import

* bad merge

* update typing

Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
Carlos Mocholí 2021-06-08 22:20:17 +02:00 committed by GitHub
parent b74f8ac149
commit b214442e74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 1001 additions and 56 deletions

View File

@ -107,7 +107,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored logging
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
* Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
* Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))

View File

@ -141,7 +141,7 @@ class LoggerConnector:
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
return should_log_every_n_steps or self.trainer.should_stop
def configure_logger(self, logger: LightningLoggerBase) -> None:
def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
if logger is True:
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)

View File

@ -0,0 +1,311 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pprint import pprint
from typing import Any, Dict, Iterable, Mapping, Optional, Union
import torch
import pytorch_lightning as pl
from pytorch_lightning.core import memory
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities import DeviceType
from pytorch_lightning.utilities.metrics import metrics_to_scalars
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
# TODO(@carmocca): Remove `New` suffix
class LoggerConnectorNew:
def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None:
self.trainer = trainer
self.log_gpu_memory = log_gpu_memory
self.eval_loop_results = []
self._val_log_step: int = 0
self._test_log_step: int = 0
self._progress_bar_metrics: Dict[str, float] = {}
self._logged_metrics: Dict[str, _METRIC] = {}
self._callback_metrics: Dict[str, _METRIC] = {}
self._epoch_end_reached = False
self._current_fx: Optional[str] = None
self._batch_idx: Optional[int] = None
self._split_idx: Optional[int] = None
def on_trainer_init(
self,
logger: LightningLoggerBase,
flush_logs_every_n_steps: int,
log_every_n_steps: int,
move_metrics_to_cpu: bool,
) -> None:
self.configure_logger(logger)
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
@property
def should_flush_logs(self) -> bool:
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
return should_flush or self.trainer.should_stop
@property
def should_update_logs(self) -> bool:
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
return should_log_every_n_steps or self.trainer.should_stop
def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
if logger is True:
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
# default logger
self.trainer.logger = TensorBoardLogger(
save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs'
)
elif logger is False:
self.trainer.logger = None
else:
if isinstance(logger, Iterable):
self.trainer.logger = LoggerCollection(logger)
else:
self.trainer.logger = logger
def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
"""Logs the metric dict passed in.
If `step` parameter is None and `step` key is presented is metrics,
uses metrics["step"] as a step
Args:
metrics: Metric values
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
the total validation / test log step count during validation and testing.
"""
# add gpu memory
if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
mem_map = memory.get_memory_profile(self.log_gpu_memory)
metrics.update(mem_map)
# turn all tensors to scalars
scalar_metrics = metrics_to_scalars(metrics)
if "step" in scalar_metrics and step is None:
step = scalar_metrics.pop("step")
elif step is None:
# added metrics by Lightning for convenience
scalar_metrics['epoch'] = self.trainer.current_epoch
step = self.trainer.global_step
# log actual metrics
if self.trainer.logger is not None:
if self.trainer.is_global_zero:
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
self.trainer.logger.save()
self._logged_metrics.update(scalar_metrics)
"""
Evaluation metric updates
"""
@property
def _eval_log_step(self) -> Optional[int]:
if self.trainer.state.stage is RunningStage.VALIDATING:
return self._val_log_step
elif self.trainer.state.stage is RunningStage.TESTING:
return self._test_log_step
else:
return None
def _increment_eval_log_step(self) -> None:
if self.trainer.state.stage is RunningStage.VALIDATING:
self._val_log_step += 1
elif self.trainer.state.stage is RunningStage.TESTING:
self._test_log_step += 1
def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
model = self.trainer.lightning_module
# set dataloader_idx only if multiple ones
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
# track batch_size
self.trainer.result_collection.extract_batch_size(batch)
self._batch_idx = batch_idx
def update_eval_step_metrics(self) -> None:
if self.trainer.sanity_checking:
return
# logs user requested information to logger
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics, step=self._eval_log_step)
# increment the step even if nothing was logged
self._increment_eval_log_step()
def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None:
if self.trainer.sanity_checking:
return
num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
# remove callback metrics that don't belong to this dataloader
callback_metrics = {
k: v
for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
}
if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
self.eval_loop_results.append(callback_metrics)
def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
assert self._epoch_end_reached
metrics = self.metrics
if not self.trainer.sanity_checking:
# log all the metrics as a single dict
log_metrics = metrics[MetricSource.LOG]
if log_metrics:
self.log_metrics(log_metrics)
self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK])
# log results of evaluation
if (
self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
and self.trainer.verbose_evaluate
):
print('-' * 80)
for result_idx, results in enumerate(self.eval_loop_results):
print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS')
pprint({
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
for k, v in results.items()
})
print('-' * 80)
results = self.eval_loop_results
# clear mem
self.eval_loop_results = []
return results
"""
Train metric updates
"""
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
self.trainer.results.extract_batch_size(split_batch)
self._batch_idx = batch_idx
self._split_idx = split_idx
def update_train_step_metrics(self) -> None:
if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
return
# when metrics should be logged
assert not self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if self.should_update_logs or self.trainer.fast_dev_run is True and metrics:
self.log_metrics(metrics)
def update_train_epoch_metrics(self) -> None:
# add the metrics to the loggers
assert self._epoch_end_reached
metrics = self.metrics[MetricSource.LOG]
if metrics:
self.log_metrics(metrics)
# reset result collection for next epoch
self.trainer.results.reset(metrics=True)
"""
Utilities and properties
"""
def on_epoch_start(self) -> None:
self._epoch_end_reached = False
def on_batch_start(self) -> None:
self._epoch_end_reached = False
def epoch_end_reached(self):
self.trainer.logger_connector._epoch_end_reached = True
self.trainer.logger_connector._batch_idx = None
self.trainer.logger_connector._split_idx = None
def on_epoch_end(self) -> None:
assert self._epoch_end_reached
metrics = self.metrics
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
self._logged_metrics.update(metrics[MetricSource.LOG])
self._current_fx = None
def on_batch_end(self) -> None:
assert not self._epoch_end_reached
metrics = self.metrics
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
self._logged_metrics.update(metrics[MetricSource.LOG])
def should_reset_tensors(self, fx: str) -> bool:
is_different_fx = self._current_fx != fx
if self._split_idx is None:
is_first_batch = self._batch_idx in (None, 0)
else:
is_first_batch = self._batch_idx + self._split_idx == 0
return is_different_fx and is_first_batch
def reset(self, metrics: Optional[bool] = None) -> None:
self.trainer.results.reset(metrics=metrics)
self._batch_idx = None
self._split_idx = None
self._current_fx = None
@property
def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]:
"""This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
on_step = not self._epoch_end_reached
return self.trainer.results.metrics(on_step)
@property
def callback_metrics(self) -> Dict[str, _METRIC]:
if self.trainer.results:
metrics = self.metrics[MetricSource.CALLBACK]
self._callback_metrics.update(metrics)
return self._callback_metrics
@property
def logged_metrics(self) -> Dict[str, _METRIC]:
if self.trainer.results:
metrics = self.metrics[MetricSource.LOG]
self._logged_metrics.update(metrics)
return self._logged_metrics
@property
def progress_bar_metrics(self) -> Dict[str, float]:
if self.trainer.results:
metrics = self.metrics[MetricSource.PBAR]
self._progress_bar_metrics.update(metrics)
return self._progress_bar_metrics
def teardown(self):
self.trainer.train_loop.results.cpu()
self.trainer.evaluation_loop._val_results.cpu()
self.trainer.evaluation_loop._test_results.cpu()

View File

@ -0,0 +1,499 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Generator
from dataclasses import dataclass, field
from functools import partial, wraps
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
import torch
from torchmetrics import Metric
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
from pytorch_lightning.utilities.enums import LightningEnum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.metrics import metrics_to_scalars
# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
_METRIC = Union[Metric, torch.Tensor]
_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
class MetricSource(LightningEnum):
CALLBACK = "callback"
PBAR = "pbar"
LOG = "log"
@dataclass
class _Sync:
fn: Callable
should: bool = False
op: Union[Any, str] = 'mean'
group: Optional[Any] = None
@property
def __call__(self) -> Any:
return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op
@staticmethod
def no_op(value: Any, *_, **__) -> Any:
return value
@dataclass
class _Metadata:
fx: str
name: str
prog_bar: bool = False
logger: bool = True
on_step: bool = False
on_epoch: bool = True
reduce_fx: Callable = torch.mean
enable_graph: bool = False
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
sync: _Sync = field(default_factory=_Sync)
@property
def forked(self) -> bool:
return self.on_step and self.on_epoch
def forked_name(self, on_step: bool) -> str:
if self.forked:
return f'{self.name}_{"step" if on_step else "epoch"}'
return self.name
@property
def is_mean_reduction(self) -> bool:
return self.reduce_fx == torch.mean
@property
def is_max_reduction(self) -> bool:
return self.reduce_fx in (torch.max, max)
@property
def is_min_reduction(self) -> bool:
return self.reduce_fx in (torch.min, min)
@property
def is_custom_reduction(self) -> bool:
return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction)
class ResultMetric(Metric, DeviceDtypeModuleMixin):
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
super().__init__()
self.is_tensor = is_tensor
self.meta = metadata
self.has_reset = False
if is_tensor:
self.add_state("value", torch.tensor(0, dtype=torch.float))
if self.meta.is_mean_reduction:
self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float))
def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
if self.is_tensor:
value = value.float()
self._forward_cache = value
# performance: no need to accumulate on values only logged on_step
if self.meta.on_step and not self.meta.on_epoch:
self.value = self.meta.sync(value)
return
# perform accumulation with reduction
if self.meta.is_mean_reduction:
self.value += value.mean() * batch_size
self.cumulated_batch_size += batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
self.value = self.meta.reduce_fx(self.value, value.mean())
else:
self.value = value # noqa: attribute-defined-outside-init
self._forward_cache = value._forward_cache
def compute(self) -> torch.Tensor:
if self.is_tensor:
value = self.meta.sync(self.value)
if self.meta.is_mean_reduction:
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
return value / cumulated_batch_size
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
return value
raise MisconfigurationException(
f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}"
)
return self.value.compute()
def reset(self) -> None:
if self.is_tensor:
super().reset()
else:
self.value.reset()
self.has_reset = True
def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
if self.meta.enable_graph:
with torch.no_grad():
self.update(value, batch_size)
else:
# performance: skip the `torch.no_grad` context manager by calling `update` directly
self.update(value, batch_size)
def _wrap_compute(self, compute: Any) -> Any:
# Override to avoid syncing - we handle it ourselves.
@wraps(compute)
def wrapped_func(*args, **kwargs):
if not self._update_called:
rank_zero_warn(
f"The ``compute`` method of metric {self.__class__.__name__}"
" was called before the ``update`` method which may lead to errors,"
" as metric states have not yet been updated.", UserWarning
)
# return cached value
if self._computed is not None:
return self._computed
self._computed = compute(*args, **kwargs)
return self._computed
return wrapped_func
def __setattr__(self, key: str, value: Any) -> None:
# performance: skip the `torch.nn.Module.__setattr__` checks
object.__setattr__(self, key, value)
def __repr__(self) -> str:
state = f"value={self.value}"
if self.is_tensor and self.meta.is_mean_reduction:
state += f", cumulated_batch_size={self.cumulated_batch_size}"
return f"{self.__class__.__name__}({state})"
class ResultMetricCollection(dict):
"""
Dict wrapper for easy access to metadata.
All of the leaf items should be instances of
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
with the same metadata.
"""
def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
super().__init__(*args)
self.meta = metadata
class ResultCollection(dict):
"""
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection`
Example:
# `device` needs to be provided before logging
result = ResultCollection(True, torch.device("cpu"))
# you can log to a specific collection.
# arguments: fx, key, value, metadata
result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True)
result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True)
"""
DATALOADER_SUFFIX = "/dataloader_idx_{}"
def __init__(self, training: bool, device: Optional[torch.device] = None) -> None:
super().__init__()
self.training = training
self._minimize = None
self._batch_size = torch.tensor(1, device=device)
self.device: Optional[Union[str, torch.device]] = device
self.fx_validator = FxValidator()
@property
def batch_size(self) -> torch.Tensor:
# performance: cache the `batch_size` tensor instead of re-creating it
return self._batch_size
@batch_size.setter
def batch_size(self, value: int) -> None:
self._batch_size = torch.tensor(value, device=self.device)
@property
def minimize(self) -> Optional[torch.Tensor]:
"""
The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss
will be saved as the ``minimize`` attribute.
"""
return self._minimize
@minimize.setter
def minimize(self, loss: Optional[torch.Tensor]) -> None:
if loss is not None:
if not isinstance(loss, torch.Tensor):
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
if loss.grad_fn is None:
raise RuntimeError("`Result.minimize` must have a `grad_fn`")
self._minimize = loss
@property
def extra(self) -> Dict[str, Any]:
"""
Extras are any keys other than the loss returned by
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
"""
return self.get('_extra', {})
@extra.setter
def extra(self, extra: Mapping[str, Any]) -> None:
def check_fn(v):
if v.grad_fn is not None:
raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}')
apply_to_collection(extra, torch.Tensor, check_fn)
self['_extra'] = extra
def log(
self,
fx: str,
name: str,
value: _METRIC_COLLECTION,
prog_bar: bool = False,
logger: bool = True,
on_step: bool = False,
on_epoch: bool = True,
reduce_fx: Callable = torch.mean,
enable_graph: bool = False,
sync_dist: bool = False,
sync_dist_fn: Callable = _Sync.no_op,
sync_dist_op: Union[Any, str] = 'mean',
sync_dist_group: Optional[Any] = None,
dataloader_idx: Optional[int] = None,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
) -> None:
"""See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
# no metrics should be logged with graphs
if not enable_graph and isinstance(value, torch.Tensor):
value = value.detach()
# move metrics to cpu on TPU.
if isinstance(value, torch.Tensor) and value.device.type == "xla":
value = value.cpu()
# storage key
key = f"{fx}.{name}"
# add dataloader_suffix to both key and fx
if dataloader_idx is not None:
key += f'.{dataloader_idx}'
fx += f'.{dataloader_idx}'
meta = _Metadata(
fx=fx,
name=name,
prog_bar=prog_bar,
logger=logger,
on_step=on_step,
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
sync=_Sync(
should=sync_dist,
fn=sync_dist_fn,
op=sync_dist_op,
group=sync_dist_group,
)
)
if key not in self:
if meta.is_custom_reduction:
raise MisconfigurationException(
'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.'
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`'
)
self.register_key(key, meta, value)
elif meta != self[key].meta:
raise MisconfigurationException(
f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed'
)
if batch_size is not None:
self.batch_size = batch_size
self.update_metrics(key, value)
def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value. Value can be provided as a nested collection"""
def fn(v: _METRIC) -> ResultMetric:
metric = ResultMetric(meta, isinstance(v, torch.Tensor))
return metric.to(self.device)
value = apply_to_collection(value, (torch.Tensor, Metric), fn)
if isinstance(value, dict):
value = ResultMetricCollection(value, metadata=meta)
self[key] = value
def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
def fn(result_metric, v):
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), self.batch_size)
result_metric.has_reset = False
apply_to_collections(self[key], value, ResultMetric, fn)
@staticmethod
def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
cache = None
if on_step and result_metric.meta.on_step:
cache = result_metric._forward_cache
elif not on_step and result_metric.meta.on_epoch:
if not result_metric._computed:
result_metric.compute()
cache = result_metric._computed
if cache is not None and not result_metric.meta.enable_graph:
return cache.detach()
return cache
def valid_items(self) -> Generator:
"""This function is used to iterate over current valid metrics."""
return ((k, v) for k, v in self.items()
if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset))
def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
name = result_metric.meta.name
forked_name = result_metric.meta.forked_name(on_step)
dl_idx = result_metric.meta.dataloader_idx
if dl_idx is not None:
dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
name += dataloader_suffix
forked_name += dataloader_suffix
return name, forked_name
def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
metrics = {k: {} for k in MetricSource}
for key, result_metric in self.valid_items():
# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
# check if the collection is empty
has_tensor = False
def any_tensor(_):
nonlocal has_tensor
has_tensor = True
apply_to_collection(value, torch.Tensor, any_tensor)
if not has_tensor:
continue
name, forked_name = self._forked_name(result_metric, on_step)
# populate logging metrics
if result_metric.meta.logger:
metrics[MetricSource.LOG][forked_name] = value
# populate callback metrics. callback metrics don't take `_step` forked metrics
if self.training or result_metric.meta.on_epoch and not on_step:
metrics[MetricSource.CALLBACK][name] = value
metrics[MetricSource.CALLBACK][forked_name] = value
# populate progress_bar metrics. convert tensors to numbers
if result_metric.meta.prog_bar:
metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
return metrics
def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None:
"""
Reset the result collection
Args:
metrics: If True, only ``torchmetrics.Metric`` results are reset,
if False, only ``torch.Tensors`` are reset,
if ``None``, both are.
fx: Function to reset
"""
def fn(item: ResultMetric) -> None:
requested_type = metrics is None or metrics ^ item.is_tensor
same_fx = fx is None or fx == item.meta.fx
if requested_type and same_fx:
item.reset()
apply_to_collection(self, ResultMetric, fn)
def extract_batch_size(self, batch: Any) -> None:
try:
self.batch_size = self._extract_batch_size(batch)
except RecursionError:
self.batch_size = 1
def _extract_batch_size(self, batch: Any) -> int:
"""
Recursively unpack a batch to find a torch.Tensor.
Returns:
``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
"""
if isinstance(batch, torch.Tensor):
size = batch.size(0)
elif isinstance(batch, str):
return len(batch)
elif isinstance(batch, dict):
sample = next(iter(batch.values()), 1)
size = self._extract_batch_size(sample)
elif isinstance(batch, Iterable):
sample = next(iter(batch), 1)
size = self._extract_batch_size(sample)
else:
size = 1
return size
def to(self, *args, **kwargs) -> 'ResultCollection':
"""Move all data to the given device."""
def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
return item.to(*args, **kwargs)
apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)
if self.minimize is not None:
self.minimize = self.minimize.to(*args, **kwargs)
self._batch_size = self._batch_size.to(*args, **kwargs)
if 'device' in kwargs:
self.device = kwargs['device']
return self
def cpu(self) -> 'ResultCollection':
"""Move all data to CPU."""
return self.to(device="cpu")
def __str__(self) -> str:
return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})'
def __getstate__(self) -> dict:
d = self.__dict__.copy()
# can't deepcopy tensors with grad_fn
minimize = d.get('_minimize')
if minimize is not None:
d['_minimize'] = minimize.detach()
return d

View File

@ -7,7 +7,7 @@ tqdm>=4.41.0
PyYAML>=5.1,<=5.4.1
fsspec[http]>=2021.05.0, !=2021.06.0
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
torchmetrics>=0.2.0
torchmetrics>=0.3.2
pyDeprecate==0.3.1
packaging
typing-extensions # TypedDict support for python<3.8

View File

@ -11,14 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchmetrics import Metric
import tests.helpers.utils as tutils
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection
from tests.helpers.runif import RunIf
@ -52,12 +51,14 @@ def _ddp_test_fn(rank, worldsize):
metric_b = DummyMetric()
metric_c = DummyMetric()
# dist_sync_on_step is False by default
result = Result()
metric_a = metric_a.to(f"cuda:{rank}")
metric_b = metric_b.to(f"cuda:{rank}")
metric_c = metric_c.to(f"cuda:{rank}")
for epoch in range(3):
result = ResultCollection(True, torch.device(f"cuda:{rank}"))
for _ in range(3):
cumulative_sum = 0
for i in range(5):
metric_a(i)
metric_b(i)
@ -65,32 +66,25 @@ def _ddp_test_fn(rank, worldsize):
cumulative_sum += i
result.log('a', metric_a, on_step=True, on_epoch=True)
result.log('b', metric_b, on_step=False, on_epoch=True)
result.log('c', metric_c, on_step=True, on_epoch=False)
result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
batch_log = result.get_batch_log_metrics()
batch_expected = {"a_step": i, "a": i, "c": i}
assert set(batch_log.keys()) == set(batch_expected.keys())
for k in batch_expected.keys():
assert batch_expected[k] == batch_log[k]
batch_log = result.metrics(True)[MetricSource.LOG]
assert batch_log == {"a_step": i, "c": i}
epoch_log = result.get_epoch_log_metrics()
epoch_log = result.metrics(False)[MetricSource.LOG]
result.reset()
# assert metric state reset to default values
assert metric_a.x == metric_a._defaults['x']
assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x'])
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']
epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
assert epoch_expected[k] == epoch_log[k]
assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
@RunIf(skip_windows=True)
@RunIf(skip_windows=True, min_gpus=2)
def test_result_reduce_ddp():
"""Make sure result logging works with DDP"""
tutils.set_random_master_port()
@ -104,11 +98,10 @@ def test_result_metric_integration():
metric_b = DummyMetric()
metric_c = DummyMetric()
result = Result()
result = ResultCollection(True, torch.device("cpu"))
for epoch in range(3):
for _ in range(3):
cumulative_sum = 0
for i in range(5):
metric_a(i)
metric_b(i)
@ -116,17 +109,14 @@ def test_result_metric_integration():
cumulative_sum += i
result.log('a', metric_a, on_step=True, on_epoch=True)
result.log('b', metric_b, on_step=False, on_epoch=True)
result.log('c', metric_c, on_step=True, on_epoch=False)
result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
batch_log = result.get_batch_log_metrics()
batch_expected = {"a_step": i, "a": i, "c": i}
assert set(batch_log.keys()) == set(batch_expected.keys())
for k in batch_expected.keys():
assert batch_expected[k] == batch_log[k]
batch_log = result.metrics(True)[MetricSource.LOG]
assert batch_log == {"a_step": i, "c": i}
epoch_log = result.get_epoch_log_metrics()
epoch_log = result.metrics(False)[MetricSource.LOG]
result.reset()
# assert metric state reset to default values
@ -134,8 +124,54 @@ def test_result_metric_integration():
assert metric_b.x == metric_b._defaults['x']
assert metric_c.x == metric_c._defaults['x']
epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum}
assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}
assert set(epoch_log.keys()) == set(epoch_expected.keys())
for k in epoch_expected.keys():
assert epoch_expected[k] == epoch_log[k]
assert str(result) == (
"ResultCollection(True, cpu, {"
"'h.a': ResultMetric(value=DummyMetric()), "
"'h.b': ResultMetric(value=DummyMetric()), "
"'h.c': ResultMetric(value=DummyMetric())"
"})"
)
def test_result_collection_simple_loop():
result = ResultCollection(True, torch.device("cpu"))
current_fx_name = None
batch_idx = None
def lightning_log(fx, *args, **kwargs):
nonlocal current_fx_name
if current_fx_name != fx and batch_idx in (None, 0):
result.reset(metrics=False, fx=fx)
result.log(fx, *args, **kwargs)
current_fx_name = fx
lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
for epoch in range(2):
lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
for batch_idx in range(2):
lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
batch_idx = None
lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)
lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)
for k in ('a0.a', 'a1.a'):
assert result[k].value == torch.tensor(0.), k
assert result[k].cumulated_batch_size == torch.tensor(1.), k
for k in ('b0.a', 'b1.a'):
assert result[k].value == torch.tensor(1.) + epoch, k
assert result[k].cumulated_batch_size == torch.tensor(1.), k
for k in ('c0.a', 'c1.a', 'c2.a'):
assert result[k].value == torch.tensor(4.) + epoch * 2, k
assert result[k].cumulated_batch_size == torch.tensor(2.), k
for k in ('d0.a', 'd1.a'):
assert result[k].value == torch.tensor(3.) + epoch, k
assert result[k].cumulated_batch_size == torch.tensor(1.), k

View File

@ -20,7 +20,9 @@ import torch.distributed as dist
import torch.multiprocessing as mp
import tests.helpers.utils as tutils
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from tests.helpers import BoringDataModule, BoringModel
from tests.helpers.runif import RunIf
@ -37,7 +39,8 @@ def _setup_ddp(rank, worldsize):
def _ddp_test_fn(rank, worldsize):
_setup_ddp(rank, worldsize)
tensor = torch.tensor([1.0])
actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM)
sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM)
actual = sync(tensor)
assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

View File

@ -21,10 +21,11 @@ from torch.utils.data import DataLoader
import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
from pytorch_lightning.utilities import _TPU_AVAILABLE
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -424,12 +425,9 @@ def test_tpu_sync_dist():
"""Test tpu spawn sync dist operation """
def test_sync_dist(_):
value = LightningModule._LightningModule__sync(
torch.tensor([1.0]),
sync_fn=TPUSpawnPlugin().reduce,
sync_dist=True,
sync_dist_op=torch.distributed.ReduceOp.SUM
)
sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
value = torch.tensor([1.0])
value = sync(value),
assert value.item() == 8
xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')

View File

@ -29,6 +29,7 @@ from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@ -108,8 +109,8 @@ def test__logger_connector__epoch_result_store__train(tmpdir):
assert train_results.has_reduced is True
generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item()
excepted = torch.stack(model.train_losses).mean().item()
assert generated == excepted
expected = torch.stack(model.train_losses).mean().item()
assert generated == expected
def test__logger_connector__epoch_result_store__train__tbptt(tmpdir):
@ -453,7 +454,7 @@ def test_metrics_holder(to_float, tmpdir):
def is_float(value: Any) -> bool:
return isinstance(value, float)
excepted_function = is_float if to_float else torch.is_tensor
expected_function = is_float if to_float else torch.is_tensor
targets = torch.tensor([1], device=device)
acc = Accuracy().to(device)
metric_holder = MetricsHolder(to_float=to_float)
@ -464,9 +465,9 @@ def test_metrics_holder(to_float, tmpdir):
})
metric_holder.convert(device)
metrics = metric_holder.metrics
assert excepted_function(metrics["x"])
assert excepted_function(metrics["y"])
assert excepted_function(metrics["z"])
assert expected_function(metrics["x"])
assert expected_function(metrics["y"])
assert expected_function(metrics["z"])
def test_metric_holder_raises(tmpdir):
@ -686,3 +687,97 @@ def test_metrics_reset(tmpdir):
trainer.test(model)
_assert_called(model, 'test')
def test_result_collection_on_tensor_with_mean_reduction():
result_collection = ResultCollection(True, torch.device("cpu"))
product = [(True, True), (False, True), (True, False), (False, False)]
values = torch.arange(1, 10).float() # need to convert to float() due to precision issues using torch 1.4
batches = values * values
for i, v in enumerate(values):
for prog_bar in [False, True]:
for logger in [False, True]:
for on_step, on_epoch in product:
name = "loss"
if on_step:
name += "_on_step"
if on_epoch:
name += "_on_epoch"
if prog_bar:
name += "_prog_bar"
if logger:
name += "_logger"
result_collection.log(
"training_step",
name,
v,
on_step=on_step,
on_epoch=on_epoch,
batch_size=batches[i],
prog_bar=prog_bar,
logger=logger,
)
total_value = sum(values * batches)
total_batches = sum(batches)
assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value
assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches
batch_metrics = result_collection.metrics(True)
max_ = max(values)
assert batch_metrics[MetricSource.PBAR] == {
'loss_on_step_on_epoch_prog_bar_step': max_,
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
'loss_on_step_prog_bar': max_,
'loss_on_step_prog_bar_logger': max_,
}
assert batch_metrics[MetricSource.LOG] == {
'loss_on_step_on_epoch_logger_step': max_,
'loss_on_step_logger': max_,
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
'loss_on_step_prog_bar_logger': max_,
}
assert batch_metrics[MetricSource.CALLBACK] == {
'loss_on_step': max_,
'loss_on_step_logger': max_,
'loss_on_step_on_epoch': max_,
'loss_on_step_on_epoch_logger': max_,
'loss_on_step_on_epoch_logger_step': max_,
'loss_on_step_on_epoch_prog_bar': max_,
'loss_on_step_on_epoch_prog_bar_logger': max_,
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
'loss_on_step_on_epoch_prog_bar_step': max_,
'loss_on_step_on_epoch_step': max_,
'loss_on_step_prog_bar': max_,
'loss_on_step_prog_bar_logger': max_,
}
epoch_metrics = result_collection.metrics(False)
mean = total_value / total_batches
assert epoch_metrics[MetricSource.PBAR] == {
'loss_on_epoch_prog_bar': mean,
'loss_on_epoch_prog_bar_logger': mean,
'loss_on_step_on_epoch_prog_bar_epoch': mean,
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean,
}
assert epoch_metrics[MetricSource.LOG] == {
'loss_on_epoch_logger': mean,
'loss_on_epoch_prog_bar_logger': mean,
'loss_on_step_on_epoch_logger_epoch': mean,
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
}
assert epoch_metrics[MetricSource.CALLBACK] == {
'loss_on_epoch': mean,
'loss_on_epoch_logger': mean,
'loss_on_epoch_prog_bar': mean,
'loss_on_epoch_prog_bar_logger': mean,
'loss_on_step_on_epoch': mean,
'loss_on_step_on_epoch_epoch': mean,
'loss_on_step_on_epoch_logger': mean,
'loss_on_step_on_epoch_logger_epoch': mean,
'loss_on_step_on_epoch_prog_bar': mean,
'loss_on_step_on_epoch_prog_bar_epoch': mean,
'loss_on_step_on_epoch_prog_bar_logger': mean,
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
}