New logger connector code (#7882)
* New logger connector code * Update CHANGELOG * Update requirements * Fix import path * Add new suffix * Tests * Minor changes * Rename and reorder * Fix import * Formatting * Fix with seed_everything? * Fix test? * Fix test? * Minor change * Minor changes * Minor changes * Force float * Fix minimal bug * Fix minimal bug * Update with latest changes * Fix import * bad merge * update typing Co-authored-by: tchaton <thomas@grid.ai>
This commit is contained in:
parent
b74f8ac149
commit
b214442e74
|
@ -107,7 +107,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Refactored logging
|
||||
* Renamed and moved `core/step_result.py` to `trainer/connectors/logger_connector/result.py` ([#7736](https://github.com/PyTorchLightning/pytorch-lightning/pull/7736))
|
||||
|
||||
* Dramatically simplify the `LoggerConnector` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
|
||||
* `trainer.{logged,progress_bar,callback}_metrics` are now updated on-demand ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
|
||||
* Completely overhaul the `Result` object in favor of `ResultMetric` ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
|
||||
* Improve epoch-level reduction time and overall memory usage ([#7882](https://github.com/PyTorchLightning/pytorch-lightning/pull/7882))
|
||||
|
||||
- Moved `ignore_scalar_return_in_dp` warning suppression to the DataParallelPlugin class ([#7421](https://github.com/PyTorchLightning/pytorch-lightning/pull/7421/))
|
||||
|
||||
|
|
|
@ -141,7 +141,7 @@ class LoggerConnector:
|
|||
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
|
||||
return should_log_every_n_steps or self.trainer.should_stop
|
||||
|
||||
def configure_logger(self, logger: LightningLoggerBase) -> None:
|
||||
def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
|
||||
if logger is True:
|
||||
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
|
||||
|
||||
|
|
|
@ -0,0 +1,311 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from pprint import pprint
|
||||
from typing import Any, Dict, Iterable, Mapping, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _METRIC, MetricSource
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
|
||||
from pytorch_lightning.utilities import DeviceType
|
||||
from pytorch_lightning.utilities.metrics import metrics_to_scalars
|
||||
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT
|
||||
|
||||
|
||||
# TODO(@carmocca): Remove `New` suffix
|
||||
class LoggerConnectorNew:
|
||||
|
||||
def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None) -> None:
|
||||
self.trainer = trainer
|
||||
self.log_gpu_memory = log_gpu_memory
|
||||
self.eval_loop_results = []
|
||||
self._val_log_step: int = 0
|
||||
self._test_log_step: int = 0
|
||||
self._progress_bar_metrics: Dict[str, float] = {}
|
||||
self._logged_metrics: Dict[str, _METRIC] = {}
|
||||
self._callback_metrics: Dict[str, _METRIC] = {}
|
||||
self._epoch_end_reached = False
|
||||
self._current_fx: Optional[str] = None
|
||||
self._batch_idx: Optional[int] = None
|
||||
self._split_idx: Optional[int] = None
|
||||
|
||||
def on_trainer_init(
|
||||
self,
|
||||
logger: LightningLoggerBase,
|
||||
flush_logs_every_n_steps: int,
|
||||
log_every_n_steps: int,
|
||||
move_metrics_to_cpu: bool,
|
||||
) -> None:
|
||||
self.configure_logger(logger)
|
||||
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
|
||||
self.trainer.log_every_n_steps = log_every_n_steps
|
||||
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
|
||||
|
||||
@property
|
||||
def should_flush_logs(self) -> bool:
|
||||
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
|
||||
return should_flush or self.trainer.should_stop
|
||||
|
||||
@property
|
||||
def should_update_logs(self) -> bool:
|
||||
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
|
||||
return should_log_every_n_steps or self.trainer.should_stop
|
||||
|
||||
def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
|
||||
if logger is True:
|
||||
version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
|
||||
|
||||
# default logger
|
||||
self.trainer.logger = TensorBoardLogger(
|
||||
save_dir=self.trainer.default_root_dir, version=version, name='lightning_logs'
|
||||
)
|
||||
elif logger is False:
|
||||
self.trainer.logger = None
|
||||
else:
|
||||
if isinstance(logger, Iterable):
|
||||
self.trainer.logger = LoggerCollection(logger)
|
||||
else:
|
||||
self.trainer.logger = logger
|
||||
|
||||
def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
|
||||
"""Logs the metric dict passed in.
|
||||
If `step` parameter is None and `step` key is presented is metrics,
|
||||
uses metrics["step"] as a step
|
||||
|
||||
Args:
|
||||
metrics: Metric values
|
||||
step: Step for which metrics should be logged. Default value is `self.global_step` during training or
|
||||
the total validation / test log step count during validation and testing.
|
||||
"""
|
||||
# add gpu memory
|
||||
if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
|
||||
mem_map = memory.get_memory_profile(self.log_gpu_memory)
|
||||
metrics.update(mem_map)
|
||||
|
||||
# turn all tensors to scalars
|
||||
scalar_metrics = metrics_to_scalars(metrics)
|
||||
|
||||
if "step" in scalar_metrics and step is None:
|
||||
step = scalar_metrics.pop("step")
|
||||
|
||||
elif step is None:
|
||||
# added metrics by Lightning for convenience
|
||||
scalar_metrics['epoch'] = self.trainer.current_epoch
|
||||
step = self.trainer.global_step
|
||||
|
||||
# log actual metrics
|
||||
if self.trainer.logger is not None:
|
||||
if self.trainer.is_global_zero:
|
||||
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
|
||||
self.trainer.logger.save()
|
||||
|
||||
self._logged_metrics.update(scalar_metrics)
|
||||
|
||||
"""
|
||||
Evaluation metric updates
|
||||
"""
|
||||
|
||||
@property
|
||||
def _eval_log_step(self) -> Optional[int]:
|
||||
if self.trainer.state.stage is RunningStage.VALIDATING:
|
||||
return self._val_log_step
|
||||
elif self.trainer.state.stage is RunningStage.TESTING:
|
||||
return self._test_log_step
|
||||
else:
|
||||
return None
|
||||
|
||||
def _increment_eval_log_step(self) -> None:
|
||||
if self.trainer.state.stage is RunningStage.VALIDATING:
|
||||
self._val_log_step += 1
|
||||
elif self.trainer.state.stage is RunningStage.TESTING:
|
||||
self._test_log_step += 1
|
||||
|
||||
def on_evaluation_batch_start(self, batch: Any, batch_idx: int, dataloader_idx: int, num_dataloaders: int) -> None:
|
||||
model = self.trainer.lightning_module
|
||||
# set dataloader_idx only if multiple ones
|
||||
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
|
||||
|
||||
# track batch_size
|
||||
self.trainer.result_collection.extract_batch_size(batch)
|
||||
self._batch_idx = batch_idx
|
||||
|
||||
def update_eval_step_metrics(self) -> None:
|
||||
if self.trainer.sanity_checking:
|
||||
return
|
||||
|
||||
# logs user requested information to logger
|
||||
assert not self._epoch_end_reached
|
||||
metrics = self.metrics[MetricSource.LOG]
|
||||
if metrics:
|
||||
self.log_metrics(metrics, step=self._eval_log_step)
|
||||
|
||||
# increment the step even if nothing was logged
|
||||
self._increment_eval_log_step()
|
||||
|
||||
def _prepare_eval_loop_results(self, metrics: Mapping[str, _METRIC]) -> None:
|
||||
if self.trainer.sanity_checking:
|
||||
return
|
||||
|
||||
num_dataloaders = self.trainer.evaluation_loop.num_dataloaders
|
||||
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
|
||||
for dl_idx in range(self.trainer.evaluation_loop.num_dataloaders):
|
||||
# remove callback metrics that don't belong to this dataloader
|
||||
callback_metrics = {
|
||||
k: v
|
||||
for k, v in metrics.items() if "dataloader_idx" not in k or f"dataloader_idx_{dl_idx}" in k
|
||||
}
|
||||
if has_been_initialized:
|
||||
self.eval_loop_results[dl_idx].update(callback_metrics)
|
||||
else:
|
||||
self.eval_loop_results.append(callback_metrics)
|
||||
|
||||
def update_eval_epoch_metrics(self) -> _EVALUATE_OUTPUT:
|
||||
assert self._epoch_end_reached
|
||||
metrics = self.metrics
|
||||
|
||||
if not self.trainer.sanity_checking:
|
||||
# log all the metrics as a single dict
|
||||
log_metrics = metrics[MetricSource.LOG]
|
||||
if log_metrics:
|
||||
self.log_metrics(log_metrics)
|
||||
|
||||
self._prepare_eval_loop_results(metrics[MetricSource.CALLBACK])
|
||||
|
||||
# log results of evaluation
|
||||
if (
|
||||
self.trainer.state.fn != TrainerFn.FITTING and self.trainer.evaluating and self.trainer.is_global_zero
|
||||
and self.trainer.verbose_evaluate
|
||||
):
|
||||
print('-' * 80)
|
||||
for result_idx, results in enumerate(self.eval_loop_results):
|
||||
print(f'DATALOADER:{result_idx} {self.trainer.state.stage.upper()} RESULTS')
|
||||
pprint({
|
||||
k: (v.item() if v.numel() == 1 else v.tolist()) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in results.items()
|
||||
})
|
||||
print('-' * 80)
|
||||
|
||||
results = self.eval_loop_results
|
||||
# clear mem
|
||||
self.eval_loop_results = []
|
||||
return results
|
||||
|
||||
"""
|
||||
Train metric updates
|
||||
"""
|
||||
|
||||
def on_train_split_start(self, batch_idx: int, split_idx: int, split_batch: Any) -> None:
|
||||
self.trainer.results.extract_batch_size(split_batch)
|
||||
self._batch_idx = batch_idx
|
||||
self._split_idx = split_idx
|
||||
|
||||
def update_train_step_metrics(self) -> None:
|
||||
if self.trainer.train_loop.should_accumulate() and self.trainer.lightning_module.automatic_optimization:
|
||||
return
|
||||
|
||||
# when metrics should be logged
|
||||
assert not self._epoch_end_reached
|
||||
metrics = self.metrics[MetricSource.LOG]
|
||||
if self.should_update_logs or self.trainer.fast_dev_run is True and metrics:
|
||||
self.log_metrics(metrics)
|
||||
|
||||
def update_train_epoch_metrics(self) -> None:
|
||||
# add the metrics to the loggers
|
||||
assert self._epoch_end_reached
|
||||
metrics = self.metrics[MetricSource.LOG]
|
||||
if metrics:
|
||||
self.log_metrics(metrics)
|
||||
|
||||
# reset result collection for next epoch
|
||||
self.trainer.results.reset(metrics=True)
|
||||
|
||||
"""
|
||||
Utilities and properties
|
||||
"""
|
||||
|
||||
def on_epoch_start(self) -> None:
|
||||
self._epoch_end_reached = False
|
||||
|
||||
def on_batch_start(self) -> None:
|
||||
self._epoch_end_reached = False
|
||||
|
||||
def epoch_end_reached(self):
|
||||
self.trainer.logger_connector._epoch_end_reached = True
|
||||
self.trainer.logger_connector._batch_idx = None
|
||||
self.trainer.logger_connector._split_idx = None
|
||||
|
||||
def on_epoch_end(self) -> None:
|
||||
assert self._epoch_end_reached
|
||||
metrics = self.metrics
|
||||
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
|
||||
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
|
||||
self._logged_metrics.update(metrics[MetricSource.LOG])
|
||||
self._current_fx = None
|
||||
|
||||
def on_batch_end(self) -> None:
|
||||
assert not self._epoch_end_reached
|
||||
metrics = self.metrics
|
||||
self._progress_bar_metrics.update(metrics[MetricSource.PBAR])
|
||||
self._callback_metrics.update(metrics[MetricSource.CALLBACK])
|
||||
self._logged_metrics.update(metrics[MetricSource.LOG])
|
||||
|
||||
def should_reset_tensors(self, fx: str) -> bool:
|
||||
is_different_fx = self._current_fx != fx
|
||||
if self._split_idx is None:
|
||||
is_first_batch = self._batch_idx in (None, 0)
|
||||
else:
|
||||
is_first_batch = self._batch_idx + self._split_idx == 0
|
||||
return is_different_fx and is_first_batch
|
||||
|
||||
def reset(self, metrics: Optional[bool] = None) -> None:
|
||||
self.trainer.results.reset(metrics=metrics)
|
||||
self._batch_idx = None
|
||||
self._split_idx = None
|
||||
self._current_fx = None
|
||||
|
||||
@property
|
||||
def metrics(self) -> Dict[MetricSource, Dict[str, _METRIC]]:
|
||||
"""This function returns either batch or epoch metrics depending on ``_epoch_end_reached``."""
|
||||
on_step = not self._epoch_end_reached
|
||||
return self.trainer.results.metrics(on_step)
|
||||
|
||||
@property
|
||||
def callback_metrics(self) -> Dict[str, _METRIC]:
|
||||
if self.trainer.results:
|
||||
metrics = self.metrics[MetricSource.CALLBACK]
|
||||
self._callback_metrics.update(metrics)
|
||||
return self._callback_metrics
|
||||
|
||||
@property
|
||||
def logged_metrics(self) -> Dict[str, _METRIC]:
|
||||
if self.trainer.results:
|
||||
metrics = self.metrics[MetricSource.LOG]
|
||||
self._logged_metrics.update(metrics)
|
||||
return self._logged_metrics
|
||||
|
||||
@property
|
||||
def progress_bar_metrics(self) -> Dict[str, float]:
|
||||
if self.trainer.results:
|
||||
metrics = self.metrics[MetricSource.PBAR]
|
||||
self._progress_bar_metrics.update(metrics)
|
||||
return self._progress_bar_metrics
|
||||
|
||||
def teardown(self):
|
||||
self.trainer.train_loop.results.cpu()
|
||||
self.trainer.evaluation_loop._val_results.cpu()
|
||||
self.trainer.evaluation_loop._test_results.cpu()
|
|
@ -0,0 +1,499 @@
|
|||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torchmetrics import Metric
|
||||
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities.enums import LightningEnum
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.metrics import metrics_to_scalars
|
||||
|
||||
# re-define the ones from pytorch_lightning.utilities.types without the `Number` type
|
||||
_METRIC = Union[Metric, torch.Tensor]
|
||||
_METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
|
||||
|
||||
|
||||
class MetricSource(LightningEnum):
|
||||
CALLBACK = "callback"
|
||||
PBAR = "pbar"
|
||||
LOG = "log"
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Sync:
|
||||
fn: Callable
|
||||
should: bool = False
|
||||
op: Union[Any, str] = 'mean'
|
||||
group: Optional[Any] = None
|
||||
|
||||
@property
|
||||
def __call__(self) -> Any:
|
||||
return partial(self.fn, reduce_op=self.op, group=self.group) if self.should else self.no_op
|
||||
|
||||
@staticmethod
|
||||
def no_op(value: Any, *_, **__) -> Any:
|
||||
return value
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Metadata:
|
||||
fx: str
|
||||
name: str
|
||||
prog_bar: bool = False
|
||||
logger: bool = True
|
||||
on_step: bool = False
|
||||
on_epoch: bool = True
|
||||
reduce_fx: Callable = torch.mean
|
||||
enable_graph: bool = False
|
||||
dataloader_idx: Optional[int] = None
|
||||
metric_attribute: Optional[str] = None
|
||||
sync: _Sync = field(default_factory=_Sync)
|
||||
|
||||
@property
|
||||
def forked(self) -> bool:
|
||||
return self.on_step and self.on_epoch
|
||||
|
||||
def forked_name(self, on_step: bool) -> str:
|
||||
if self.forked:
|
||||
return f'{self.name}_{"step" if on_step else "epoch"}'
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def is_mean_reduction(self) -> bool:
|
||||
return self.reduce_fx == torch.mean
|
||||
|
||||
@property
|
||||
def is_max_reduction(self) -> bool:
|
||||
return self.reduce_fx in (torch.max, max)
|
||||
|
||||
@property
|
||||
def is_min_reduction(self) -> bool:
|
||||
return self.reduce_fx in (torch.min, min)
|
||||
|
||||
@property
|
||||
def is_custom_reduction(self) -> bool:
|
||||
return not (self.is_mean_reduction or self.is_max_reduction or self.is_min_reduction)
|
||||
|
||||
|
||||
class ResultMetric(Metric, DeviceDtypeModuleMixin):
|
||||
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
|
||||
|
||||
def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
|
||||
super().__init__()
|
||||
self.is_tensor = is_tensor
|
||||
self.meta = metadata
|
||||
self.has_reset = False
|
||||
if is_tensor:
|
||||
self.add_state("value", torch.tensor(0, dtype=torch.float))
|
||||
if self.meta.is_mean_reduction:
|
||||
self.add_state("cumulated_batch_size", torch.tensor(0, dtype=torch.float))
|
||||
|
||||
def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
|
||||
if self.is_tensor:
|
||||
value = value.float()
|
||||
self._forward_cache = value
|
||||
# performance: no need to accumulate on values only logged on_step
|
||||
if self.meta.on_step and not self.meta.on_epoch:
|
||||
self.value = self.meta.sync(value)
|
||||
return
|
||||
# perform accumulation with reduction
|
||||
if self.meta.is_mean_reduction:
|
||||
self.value += value.mean() * batch_size
|
||||
self.cumulated_batch_size += batch_size
|
||||
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
|
||||
self.value = self.meta.reduce_fx(self.value, value.mean())
|
||||
else:
|
||||
self.value = value # noqa: attribute-defined-outside-init
|
||||
self._forward_cache = value._forward_cache
|
||||
|
||||
def compute(self) -> torch.Tensor:
|
||||
if self.is_tensor:
|
||||
value = self.meta.sync(self.value)
|
||||
if self.meta.is_mean_reduction:
|
||||
cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
|
||||
return value / cumulated_batch_size
|
||||
elif self.meta.is_max_reduction or self.meta.is_min_reduction:
|
||||
return value
|
||||
raise MisconfigurationException(
|
||||
f"Only [min, max, mean] reductions are supported. Found {self.meta.reduce_fx}"
|
||||
)
|
||||
return self.value.compute()
|
||||
|
||||
def reset(self) -> None:
|
||||
if self.is_tensor:
|
||||
super().reset()
|
||||
else:
|
||||
self.value.reset()
|
||||
self.has_reset = True
|
||||
|
||||
def forward(self, value: _METRIC, batch_size: torch.Tensor) -> None:
|
||||
if self.meta.enable_graph:
|
||||
with torch.no_grad():
|
||||
self.update(value, batch_size)
|
||||
else:
|
||||
# performance: skip the `torch.no_grad` context manager by calling `update` directly
|
||||
self.update(value, batch_size)
|
||||
|
||||
def _wrap_compute(self, compute: Any) -> Any:
|
||||
# Override to avoid syncing - we handle it ourselves.
|
||||
@wraps(compute)
|
||||
def wrapped_func(*args, **kwargs):
|
||||
if not self._update_called:
|
||||
rank_zero_warn(
|
||||
f"The ``compute`` method of metric {self.__class__.__name__}"
|
||||
" was called before the ``update`` method which may lead to errors,"
|
||||
" as metric states have not yet been updated.", UserWarning
|
||||
)
|
||||
|
||||
# return cached value
|
||||
if self._computed is not None:
|
||||
return self._computed
|
||||
self._computed = compute(*args, **kwargs)
|
||||
return self._computed
|
||||
|
||||
return wrapped_func
|
||||
|
||||
def __setattr__(self, key: str, value: Any) -> None:
|
||||
# performance: skip the `torch.nn.Module.__setattr__` checks
|
||||
object.__setattr__(self, key, value)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
state = f"value={self.value}"
|
||||
if self.is_tensor and self.meta.is_mean_reduction:
|
||||
state += f", cumulated_batch_size={self.cumulated_batch_size}"
|
||||
return f"{self.__class__.__name__}({state})"
|
||||
|
||||
|
||||
class ResultMetricCollection(dict):
|
||||
"""
|
||||
Dict wrapper for easy access to metadata.
|
||||
|
||||
All of the leaf items should be instances of
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
|
||||
with the same metadata.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, metadata: Optional[_Metadata] = None) -> None:
|
||||
super().__init__(*args)
|
||||
self.meta = metadata
|
||||
|
||||
|
||||
class ResultCollection(dict):
|
||||
"""
|
||||
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection`
|
||||
|
||||
Example:
|
||||
|
||||
# `device` needs to be provided before logging
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
|
||||
# you can log to a specific collection.
|
||||
# arguments: fx, key, value, metadata
|
||||
result.log('training_step', 'acc', torch.tensor(...), on_step=True, on_epoch=True)
|
||||
result.log('validation_step', 'recall', torch.tensor(...), on_step=True, on_epoch=True)
|
||||
"""
|
||||
|
||||
DATALOADER_SUFFIX = "/dataloader_idx_{}"
|
||||
|
||||
def __init__(self, training: bool, device: Optional[torch.device] = None) -> None:
|
||||
super().__init__()
|
||||
self.training = training
|
||||
self._minimize = None
|
||||
self._batch_size = torch.tensor(1, device=device)
|
||||
self.device: Optional[Union[str, torch.device]] = device
|
||||
self.fx_validator = FxValidator()
|
||||
|
||||
@property
|
||||
def batch_size(self) -> torch.Tensor:
|
||||
# performance: cache the `batch_size` tensor instead of re-creating it
|
||||
return self._batch_size
|
||||
|
||||
@batch_size.setter
|
||||
def batch_size(self, value: int) -> None:
|
||||
self._batch_size = torch.tensor(value, device=self.device)
|
||||
|
||||
@property
|
||||
def minimize(self) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
The :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` loss
|
||||
will be saved as the ``minimize`` attribute.
|
||||
"""
|
||||
return self._minimize
|
||||
|
||||
@minimize.setter
|
||||
def minimize(self, loss: Optional[torch.Tensor]) -> None:
|
||||
if loss is not None:
|
||||
if not isinstance(loss, torch.Tensor):
|
||||
raise ValueError(f"`Result.minimize` must be a `torch.Tensor`, found: {loss}")
|
||||
if loss.grad_fn is None:
|
||||
raise RuntimeError("`Result.minimize` must have a `grad_fn`")
|
||||
self._minimize = loss
|
||||
|
||||
@property
|
||||
def extra(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Extras are any keys other than the loss returned by
|
||||
:meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`
|
||||
"""
|
||||
return self.get('_extra', {})
|
||||
|
||||
@extra.setter
|
||||
def extra(self, extra: Mapping[str, Any]) -> None:
|
||||
|
||||
def check_fn(v):
|
||||
if v.grad_fn is not None:
|
||||
raise MisconfigurationException(f'You returned a tensor with `grad_fn`. The extra values are {extra}')
|
||||
|
||||
apply_to_collection(extra, torch.Tensor, check_fn)
|
||||
self['_extra'] = extra
|
||||
|
||||
def log(
|
||||
self,
|
||||
fx: str,
|
||||
name: str,
|
||||
value: _METRIC_COLLECTION,
|
||||
prog_bar: bool = False,
|
||||
logger: bool = True,
|
||||
on_step: bool = False,
|
||||
on_epoch: bool = True,
|
||||
reduce_fx: Callable = torch.mean,
|
||||
enable_graph: bool = False,
|
||||
sync_dist: bool = False,
|
||||
sync_dist_fn: Callable = _Sync.no_op,
|
||||
sync_dist_op: Union[Any, str] = 'mean',
|
||||
sync_dist_group: Optional[Any] = None,
|
||||
dataloader_idx: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
metric_attribute: Optional[str] = None,
|
||||
) -> None:
|
||||
"""See :meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
|
||||
# no metrics should be logged with graphs
|
||||
if not enable_graph and isinstance(value, torch.Tensor):
|
||||
value = value.detach()
|
||||
|
||||
# move metrics to cpu on TPU.
|
||||
if isinstance(value, torch.Tensor) and value.device.type == "xla":
|
||||
value = value.cpu()
|
||||
|
||||
# storage key
|
||||
key = f"{fx}.{name}"
|
||||
# add dataloader_suffix to both key and fx
|
||||
if dataloader_idx is not None:
|
||||
key += f'.{dataloader_idx}'
|
||||
fx += f'.{dataloader_idx}'
|
||||
|
||||
meta = _Metadata(
|
||||
fx=fx,
|
||||
name=name,
|
||||
prog_bar=prog_bar,
|
||||
logger=logger,
|
||||
on_step=on_step,
|
||||
on_epoch=on_epoch,
|
||||
reduce_fx=reduce_fx,
|
||||
enable_graph=enable_graph,
|
||||
dataloader_idx=dataloader_idx,
|
||||
metric_attribute=metric_attribute,
|
||||
sync=_Sync(
|
||||
should=sync_dist,
|
||||
fn=sync_dist_fn,
|
||||
op=sync_dist_op,
|
||||
group=sync_dist_group,
|
||||
)
|
||||
)
|
||||
if key not in self:
|
||||
if meta.is_custom_reduction:
|
||||
raise MisconfigurationException(
|
||||
'Only `self.log(..., reduce_fx={min,max,mean})` are currently supported.'
|
||||
' Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`'
|
||||
)
|
||||
self.register_key(key, meta, value)
|
||||
elif meta != self[key].meta:
|
||||
raise MisconfigurationException(
|
||||
f'You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed'
|
||||
)
|
||||
|
||||
if batch_size is not None:
|
||||
self.batch_size = batch_size
|
||||
|
||||
self.update_metrics(key, value)
|
||||
|
||||
def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
|
||||
"""Create one ResultMetric object per value. Value can be provided as a nested collection"""
|
||||
|
||||
def fn(v: _METRIC) -> ResultMetric:
|
||||
metric = ResultMetric(meta, isinstance(v, torch.Tensor))
|
||||
return metric.to(self.device)
|
||||
|
||||
value = apply_to_collection(value, (torch.Tensor, Metric), fn)
|
||||
if isinstance(value, dict):
|
||||
value = ResultMetricCollection(value, metadata=meta)
|
||||
self[key] = value
|
||||
|
||||
def update_metrics(self, key: str, value: _METRIC_COLLECTION) -> None:
|
||||
|
||||
def fn(result_metric, v):
|
||||
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
|
||||
result_metric.forward(v.to(self.device), self.batch_size)
|
||||
result_metric.has_reset = False
|
||||
|
||||
apply_to_collections(self[key], value, ResultMetric, fn)
|
||||
|
||||
@staticmethod
|
||||
def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
|
||||
cache = None
|
||||
if on_step and result_metric.meta.on_step:
|
||||
cache = result_metric._forward_cache
|
||||
elif not on_step and result_metric.meta.on_epoch:
|
||||
if not result_metric._computed:
|
||||
result_metric.compute()
|
||||
cache = result_metric._computed
|
||||
if cache is not None and not result_metric.meta.enable_graph:
|
||||
return cache.detach()
|
||||
return cache
|
||||
|
||||
def valid_items(self) -> Generator:
|
||||
"""This function is used to iterate over current valid metrics."""
|
||||
return ((k, v) for k, v in self.items()
|
||||
if not k == "_extra" and not (isinstance(v, ResultMetric) and v.has_reset))
|
||||
|
||||
def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
|
||||
name = result_metric.meta.name
|
||||
forked_name = result_metric.meta.forked_name(on_step)
|
||||
dl_idx = result_metric.meta.dataloader_idx
|
||||
if dl_idx is not None:
|
||||
dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
|
||||
name += dataloader_suffix
|
||||
forked_name += dataloader_suffix
|
||||
return name, forked_name
|
||||
|
||||
def metrics(self, on_step: bool) -> Dict[MetricSource, Dict[str, _METRIC]]:
|
||||
metrics = {k: {} for k in MetricSource}
|
||||
|
||||
for key, result_metric in self.valid_items():
|
||||
|
||||
# extract forward_cache or computed from the ResultMetric. ignore when the output is None
|
||||
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
|
||||
|
||||
# check if the collection is empty
|
||||
has_tensor = False
|
||||
|
||||
def any_tensor(_):
|
||||
nonlocal has_tensor
|
||||
has_tensor = True
|
||||
|
||||
apply_to_collection(value, torch.Tensor, any_tensor)
|
||||
if not has_tensor:
|
||||
continue
|
||||
|
||||
name, forked_name = self._forked_name(result_metric, on_step)
|
||||
|
||||
# populate logging metrics
|
||||
if result_metric.meta.logger:
|
||||
metrics[MetricSource.LOG][forked_name] = value
|
||||
|
||||
# populate callback metrics. callback metrics don't take `_step` forked metrics
|
||||
if self.training or result_metric.meta.on_epoch and not on_step:
|
||||
metrics[MetricSource.CALLBACK][name] = value
|
||||
metrics[MetricSource.CALLBACK][forked_name] = value
|
||||
|
||||
# populate progress_bar metrics. convert tensors to numbers
|
||||
if result_metric.meta.prog_bar:
|
||||
metrics[MetricSource.PBAR][forked_name] = metrics_to_scalars(value)
|
||||
|
||||
return metrics
|
||||
|
||||
def reset(self, metrics: Optional[bool] = None, fx: Optional[str] = None) -> None:
|
||||
"""
|
||||
Reset the result collection
|
||||
|
||||
Args:
|
||||
metrics: If True, only ``torchmetrics.Metric`` results are reset,
|
||||
if False, only ``torch.Tensors`` are reset,
|
||||
if ``None``, both are.
|
||||
fx: Function to reset
|
||||
"""
|
||||
|
||||
def fn(item: ResultMetric) -> None:
|
||||
requested_type = metrics is None or metrics ^ item.is_tensor
|
||||
same_fx = fx is None or fx == item.meta.fx
|
||||
if requested_type and same_fx:
|
||||
item.reset()
|
||||
|
||||
apply_to_collection(self, ResultMetric, fn)
|
||||
|
||||
def extract_batch_size(self, batch: Any) -> None:
|
||||
try:
|
||||
self.batch_size = self._extract_batch_size(batch)
|
||||
except RecursionError:
|
||||
self.batch_size = 1
|
||||
|
||||
def _extract_batch_size(self, batch: Any) -> int:
|
||||
"""
|
||||
Recursively unpack a batch to find a torch.Tensor.
|
||||
|
||||
Returns:
|
||||
``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
|
||||
"""
|
||||
if isinstance(batch, torch.Tensor):
|
||||
size = batch.size(0)
|
||||
elif isinstance(batch, str):
|
||||
return len(batch)
|
||||
elif isinstance(batch, dict):
|
||||
sample = next(iter(batch.values()), 1)
|
||||
size = self._extract_batch_size(sample)
|
||||
elif isinstance(batch, Iterable):
|
||||
sample = next(iter(batch), 1)
|
||||
size = self._extract_batch_size(sample)
|
||||
else:
|
||||
size = 1
|
||||
return size
|
||||
|
||||
def to(self, *args, **kwargs) -> 'ResultCollection':
|
||||
"""Move all data to the given device."""
|
||||
|
||||
def to_(item: Union[torch.Tensor, Metric], *args: Any, **kwargs: Any) -> Union[torch.Tensor, Metric]:
|
||||
return item.to(*args, **kwargs)
|
||||
|
||||
apply_to_collection(self, (torch.Tensor, Metric), to_, *args, **kwargs)
|
||||
|
||||
if self.minimize is not None:
|
||||
self.minimize = self.minimize.to(*args, **kwargs)
|
||||
self._batch_size = self._batch_size.to(*args, **kwargs)
|
||||
if 'device' in kwargs:
|
||||
self.device = kwargs['device']
|
||||
return self
|
||||
|
||||
def cpu(self) -> 'ResultCollection':
|
||||
"""Move all data to CPU."""
|
||||
return self.to(device="cpu")
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.__class__.__name__}({self.training}, {self.device}, {repr(self)})'
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
d = self.__dict__.copy()
|
||||
# can't deepcopy tensors with grad_fn
|
||||
minimize = d.get('_minimize')
|
||||
if minimize is not None:
|
||||
d['_minimize'] = minimize.detach()
|
||||
return d
|
|
@ -7,7 +7,7 @@ tqdm>=4.41.0
|
|||
PyYAML>=5.1,<=5.4.1
|
||||
fsspec[http]>=2021.05.0, !=2021.06.0
|
||||
tensorboard>=2.2.0, !=2.5.0 # 2.5.0 GPU CI error: 'Couldn't build proto file into descriptor pool!'
|
||||
torchmetrics>=0.2.0
|
||||
torchmetrics>=0.3.2
|
||||
pyDeprecate==0.3.1
|
||||
packaging
|
||||
typing-extensions # TypedDict support for python<3.8
|
||||
|
|
|
@ -11,14 +11,13 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torchmetrics import Metric
|
||||
|
||||
import tests.helpers.utils as tutils
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
||||
|
@ -52,12 +51,14 @@ def _ddp_test_fn(rank, worldsize):
|
|||
metric_b = DummyMetric()
|
||||
metric_c = DummyMetric()
|
||||
|
||||
# dist_sync_on_step is False by default
|
||||
result = Result()
|
||||
metric_a = metric_a.to(f"cuda:{rank}")
|
||||
metric_b = metric_b.to(f"cuda:{rank}")
|
||||
metric_c = metric_c.to(f"cuda:{rank}")
|
||||
|
||||
for epoch in range(3):
|
||||
result = ResultCollection(True, torch.device(f"cuda:{rank}"))
|
||||
|
||||
for _ in range(3):
|
||||
cumulative_sum = 0
|
||||
|
||||
for i in range(5):
|
||||
metric_a(i)
|
||||
metric_b(i)
|
||||
|
@ -65,32 +66,25 @@ def _ddp_test_fn(rank, worldsize):
|
|||
|
||||
cumulative_sum += i
|
||||
|
||||
result.log('a', metric_a, on_step=True, on_epoch=True)
|
||||
result.log('b', metric_b, on_step=False, on_epoch=True)
|
||||
result.log('c', metric_c, on_step=True, on_epoch=False)
|
||||
result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
|
||||
result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
|
||||
result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
|
||||
|
||||
batch_log = result.get_batch_log_metrics()
|
||||
batch_expected = {"a_step": i, "a": i, "c": i}
|
||||
assert set(batch_log.keys()) == set(batch_expected.keys())
|
||||
for k in batch_expected.keys():
|
||||
assert batch_expected[k] == batch_log[k]
|
||||
batch_log = result.metrics(True)[MetricSource.LOG]
|
||||
assert batch_log == {"a_step": i, "c": i}
|
||||
|
||||
epoch_log = result.get_epoch_log_metrics()
|
||||
epoch_log = result.metrics(False)[MetricSource.LOG]
|
||||
result.reset()
|
||||
|
||||
# assert metric state reset to default values
|
||||
assert metric_a.x == metric_a._defaults['x']
|
||||
assert metric_a.x == metric_a._defaults['x'], (metric_a.x, metric_a._defaults['x'])
|
||||
assert metric_b.x == metric_b._defaults['x']
|
||||
assert metric_c.x == metric_c._defaults['x']
|
||||
|
||||
epoch_expected = {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
|
||||
|
||||
assert set(epoch_log.keys()) == set(epoch_expected.keys())
|
||||
for k in epoch_expected.keys():
|
||||
assert epoch_expected[k] == epoch_log[k]
|
||||
assert epoch_log == {"b": cumulative_sum * worldsize, "a_epoch": cumulative_sum * worldsize}
|
||||
|
||||
|
||||
@RunIf(skip_windows=True)
|
||||
@RunIf(skip_windows=True, min_gpus=2)
|
||||
def test_result_reduce_ddp():
|
||||
"""Make sure result logging works with DDP"""
|
||||
tutils.set_random_master_port()
|
||||
|
@ -104,11 +98,10 @@ def test_result_metric_integration():
|
|||
metric_b = DummyMetric()
|
||||
metric_c = DummyMetric()
|
||||
|
||||
result = Result()
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
|
||||
for epoch in range(3):
|
||||
for _ in range(3):
|
||||
cumulative_sum = 0
|
||||
|
||||
for i in range(5):
|
||||
metric_a(i)
|
||||
metric_b(i)
|
||||
|
@ -116,17 +109,14 @@ def test_result_metric_integration():
|
|||
|
||||
cumulative_sum += i
|
||||
|
||||
result.log('a', metric_a, on_step=True, on_epoch=True)
|
||||
result.log('b', metric_b, on_step=False, on_epoch=True)
|
||||
result.log('c', metric_c, on_step=True, on_epoch=False)
|
||||
result.log('h', 'a', metric_a, on_step=True, on_epoch=True, metric_attribute="metric_a")
|
||||
result.log('h', 'b', metric_b, on_step=False, on_epoch=True, metric_attribute="metric_b")
|
||||
result.log('h', 'c', metric_c, on_step=True, on_epoch=False, metric_attribute="metric_c")
|
||||
|
||||
batch_log = result.get_batch_log_metrics()
|
||||
batch_expected = {"a_step": i, "a": i, "c": i}
|
||||
assert set(batch_log.keys()) == set(batch_expected.keys())
|
||||
for k in batch_expected.keys():
|
||||
assert batch_expected[k] == batch_log[k]
|
||||
batch_log = result.metrics(True)[MetricSource.LOG]
|
||||
assert batch_log == {"a_step": i, "c": i}
|
||||
|
||||
epoch_log = result.get_epoch_log_metrics()
|
||||
epoch_log = result.metrics(False)[MetricSource.LOG]
|
||||
result.reset()
|
||||
|
||||
# assert metric state reset to default values
|
||||
|
@ -134,8 +124,54 @@ def test_result_metric_integration():
|
|||
assert metric_b.x == metric_b._defaults['x']
|
||||
assert metric_c.x == metric_c._defaults['x']
|
||||
|
||||
epoch_expected = {"b": cumulative_sum, "a_epoch": cumulative_sum}
|
||||
assert epoch_log == {"b": cumulative_sum, "a_epoch": cumulative_sum}
|
||||
|
||||
assert set(epoch_log.keys()) == set(epoch_expected.keys())
|
||||
for k in epoch_expected.keys():
|
||||
assert epoch_expected[k] == epoch_log[k]
|
||||
assert str(result) == (
|
||||
"ResultCollection(True, cpu, {"
|
||||
"'h.a': ResultMetric(value=DummyMetric()), "
|
||||
"'h.b': ResultMetric(value=DummyMetric()), "
|
||||
"'h.c': ResultMetric(value=DummyMetric())"
|
||||
"})"
|
||||
)
|
||||
|
||||
|
||||
def test_result_collection_simple_loop():
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
current_fx_name = None
|
||||
batch_idx = None
|
||||
|
||||
def lightning_log(fx, *args, **kwargs):
|
||||
nonlocal current_fx_name
|
||||
if current_fx_name != fx and batch_idx in (None, 0):
|
||||
result.reset(metrics=False, fx=fx)
|
||||
result.log(fx, *args, **kwargs)
|
||||
current_fx_name = fx
|
||||
|
||||
lightning_log('a0', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
|
||||
lightning_log('a1', 'a', torch.tensor(0.), on_step=True, on_epoch=True)
|
||||
for epoch in range(2):
|
||||
lightning_log('b0', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
|
||||
lightning_log('b1', 'a', torch.tensor(1.) + epoch, on_step=True, on_epoch=True)
|
||||
for batch_idx in range(2):
|
||||
lightning_log('c0', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
|
||||
lightning_log('c1', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
|
||||
lightning_log('c2', 'a', torch.tensor(2.) + epoch, on_step=True, on_epoch=True)
|
||||
batch_idx = None
|
||||
lightning_log('d0', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)
|
||||
lightning_log('d1', 'a', torch.tensor(3.) + epoch, on_step=False, on_epoch=True)
|
||||
|
||||
for k in ('a0.a', 'a1.a'):
|
||||
assert result[k].value == torch.tensor(0.), k
|
||||
assert result[k].cumulated_batch_size == torch.tensor(1.), k
|
||||
|
||||
for k in ('b0.a', 'b1.a'):
|
||||
assert result[k].value == torch.tensor(1.) + epoch, k
|
||||
assert result[k].cumulated_batch_size == torch.tensor(1.), k
|
||||
|
||||
for k in ('c0.a', 'c1.a', 'c2.a'):
|
||||
assert result[k].value == torch.tensor(4.) + epoch * 2, k
|
||||
assert result[k].cumulated_batch_size == torch.tensor(2.), k
|
||||
|
||||
for k in ('d0.a', 'd1.a'):
|
||||
assert result[k].value == torch.tensor(3.) + epoch, k
|
||||
assert result[k].cumulated_batch_size == torch.tensor(1.), k
|
||||
|
|
|
@ -20,7 +20,9 @@ import torch.distributed as dist
|
|||
import torch.multiprocessing as mp
|
||||
|
||||
import tests.helpers.utils as tutils
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
|
||||
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
|
@ -37,7 +39,8 @@ def _setup_ddp(rank, worldsize):
|
|||
def _ddp_test_fn(rank, worldsize):
|
||||
_setup_ddp(rank, worldsize)
|
||||
tensor = torch.tensor([1.0])
|
||||
actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM)
|
||||
sync = _Sync(sync_ddp_if_available, should=True, op=torch.distributed.ReduceOp.SUM)
|
||||
actual = sync(tensor)
|
||||
assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
|
||||
|
||||
|
||||
|
|
|
@ -21,10 +21,11 @@ from torch.utils.data import DataLoader
|
|||
|
||||
import tests.helpers.pipelines as tpipes
|
||||
import tests.helpers.utils as tutils
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.accelerators import TPUAccelerator
|
||||
from pytorch_lightning.callbacks import EarlyStopping
|
||||
from pytorch_lightning.plugins import TPUSpawnPlugin
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result_new import _Sync
|
||||
from pytorch_lightning.utilities import _TPU_AVAILABLE
|
||||
from pytorch_lightning.utilities.distributed import ReduceOp
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -424,12 +425,9 @@ def test_tpu_sync_dist():
|
|||
"""Test tpu spawn sync dist operation """
|
||||
|
||||
def test_sync_dist(_):
|
||||
value = LightningModule._LightningModule__sync(
|
||||
torch.tensor([1.0]),
|
||||
sync_fn=TPUSpawnPlugin().reduce,
|
||||
sync_dist=True,
|
||||
sync_dist_op=torch.distributed.ReduceOp.SUM
|
||||
)
|
||||
sync = _Sync(TPUSpawnPlugin().reduce, should=True, op=torch.distributed.ReduceOp.SUM)
|
||||
value = torch.tensor([1.0])
|
||||
value = sync(value),
|
||||
assert value.item() == 8
|
||||
|
||||
xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
|
||||
|
|
|
@ -29,6 +29,7 @@ from pytorch_lightning.trainer import Trainer
|
|||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import FxValidator
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import Result
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result_new import MetricSource, ResultCollection
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -108,8 +109,8 @@ def test__logger_connector__epoch_result_store__train(tmpdir):
|
|||
assert train_results.has_reduced is True
|
||||
|
||||
generated = train_results(fx_name="training_step", dl_idx=0, opt_idx=0, reduced=True)['train_loss_epoch'].item()
|
||||
excepted = torch.stack(model.train_losses).mean().item()
|
||||
assert generated == excepted
|
||||
expected = torch.stack(model.train_losses).mean().item()
|
||||
assert generated == expected
|
||||
|
||||
|
||||
def test__logger_connector__epoch_result_store__train__tbptt(tmpdir):
|
||||
|
@ -453,7 +454,7 @@ def test_metrics_holder(to_float, tmpdir):
|
|||
def is_float(value: Any) -> bool:
|
||||
return isinstance(value, float)
|
||||
|
||||
excepted_function = is_float if to_float else torch.is_tensor
|
||||
expected_function = is_float if to_float else torch.is_tensor
|
||||
targets = torch.tensor([1], device=device)
|
||||
acc = Accuracy().to(device)
|
||||
metric_holder = MetricsHolder(to_float=to_float)
|
||||
|
@ -464,9 +465,9 @@ def test_metrics_holder(to_float, tmpdir):
|
|||
})
|
||||
metric_holder.convert(device)
|
||||
metrics = metric_holder.metrics
|
||||
assert excepted_function(metrics["x"])
|
||||
assert excepted_function(metrics["y"])
|
||||
assert excepted_function(metrics["z"])
|
||||
assert expected_function(metrics["x"])
|
||||
assert expected_function(metrics["y"])
|
||||
assert expected_function(metrics["z"])
|
||||
|
||||
|
||||
def test_metric_holder_raises(tmpdir):
|
||||
|
@ -686,3 +687,97 @@ def test_metrics_reset(tmpdir):
|
|||
|
||||
trainer.test(model)
|
||||
_assert_called(model, 'test')
|
||||
|
||||
|
||||
def test_result_collection_on_tensor_with_mean_reduction():
|
||||
result_collection = ResultCollection(True, torch.device("cpu"))
|
||||
product = [(True, True), (False, True), (True, False), (False, False)]
|
||||
values = torch.arange(1, 10).float() # need to convert to float() due to precision issues using torch 1.4
|
||||
batches = values * values
|
||||
|
||||
for i, v in enumerate(values):
|
||||
for prog_bar in [False, True]:
|
||||
for logger in [False, True]:
|
||||
for on_step, on_epoch in product:
|
||||
name = "loss"
|
||||
if on_step:
|
||||
name += "_on_step"
|
||||
if on_epoch:
|
||||
name += "_on_epoch"
|
||||
if prog_bar:
|
||||
name += "_prog_bar"
|
||||
if logger:
|
||||
name += "_logger"
|
||||
result_collection.log(
|
||||
"training_step",
|
||||
name,
|
||||
v,
|
||||
on_step=on_step,
|
||||
on_epoch=on_epoch,
|
||||
batch_size=batches[i],
|
||||
prog_bar=prog_bar,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
total_value = sum(values * batches)
|
||||
total_batches = sum(batches)
|
||||
assert result_collection["training_step.loss_on_step_on_epoch"].value == total_value
|
||||
assert result_collection["training_step.loss_on_step_on_epoch"].cumulated_batch_size == total_batches
|
||||
|
||||
batch_metrics = result_collection.metrics(True)
|
||||
max_ = max(values)
|
||||
assert batch_metrics[MetricSource.PBAR] == {
|
||||
'loss_on_step_on_epoch_prog_bar_step': max_,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
|
||||
'loss_on_step_prog_bar': max_,
|
||||
'loss_on_step_prog_bar_logger': max_,
|
||||
}
|
||||
assert batch_metrics[MetricSource.LOG] == {
|
||||
'loss_on_step_on_epoch_logger_step': max_,
|
||||
'loss_on_step_logger': max_,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
|
||||
'loss_on_step_prog_bar_logger': max_,
|
||||
}
|
||||
assert batch_metrics[MetricSource.CALLBACK] == {
|
||||
'loss_on_step': max_,
|
||||
'loss_on_step_logger': max_,
|
||||
'loss_on_step_on_epoch': max_,
|
||||
'loss_on_step_on_epoch_logger': max_,
|
||||
'loss_on_step_on_epoch_logger_step': max_,
|
||||
'loss_on_step_on_epoch_prog_bar': max_,
|
||||
'loss_on_step_on_epoch_prog_bar_logger': max_,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_step': max_,
|
||||
'loss_on_step_on_epoch_prog_bar_step': max_,
|
||||
'loss_on_step_on_epoch_step': max_,
|
||||
'loss_on_step_prog_bar': max_,
|
||||
'loss_on_step_prog_bar_logger': max_,
|
||||
}
|
||||
|
||||
epoch_metrics = result_collection.metrics(False)
|
||||
mean = total_value / total_batches
|
||||
assert epoch_metrics[MetricSource.PBAR] == {
|
||||
'loss_on_epoch_prog_bar': mean,
|
||||
'loss_on_epoch_prog_bar_logger': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_epoch': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean,
|
||||
}
|
||||
assert epoch_metrics[MetricSource.LOG] == {
|
||||
'loss_on_epoch_logger': mean,
|
||||
'loss_on_epoch_prog_bar_logger': mean,
|
||||
'loss_on_step_on_epoch_logger_epoch': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
|
||||
}
|
||||
assert epoch_metrics[MetricSource.CALLBACK] == {
|
||||
'loss_on_epoch': mean,
|
||||
'loss_on_epoch_logger': mean,
|
||||
'loss_on_epoch_prog_bar': mean,
|
||||
'loss_on_epoch_prog_bar_logger': mean,
|
||||
'loss_on_step_on_epoch': mean,
|
||||
'loss_on_step_on_epoch_epoch': mean,
|
||||
'loss_on_step_on_epoch_logger': mean,
|
||||
'loss_on_step_on_epoch_logger_epoch': mean,
|
||||
'loss_on_step_on_epoch_prog_bar': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_epoch': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_logger': mean,
|
||||
'loss_on_step_on_epoch_prog_bar_logger_epoch': mean
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue