From 8508cce37d4f18266715d6bcebed4b382b704380 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 17 Dec 2021 20:35:17 +0100 Subject: [PATCH] Mark all result classes as protected (#11130) --- CHANGELOG.md | 3 + pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loops/base.py | 10 +-- .../loops/dataloader/evaluation_loop.py | 6 +- .../loops/epoch/training_epoch_loop.py | 4 +- pytorch_lightning/loops/fit_loop.py | 4 +- .../connectors/logger_connector/result.py | 80 +++++++++---------- pytorch_lightning/trainer/trainer.py | 4 +- tests/core/test_metric_result_integration.py | 38 ++++----- .../trainer/logging_/test_logger_connector.py | 32 ++++---- tests/trainer/logging_/test_loop_logging.py | 6 +- 11 files changed, 96 insertions(+), 93 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d769406c54..9990bff5f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022)) +- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130)) + + - DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index c7b6d1ced3..8c516fc060 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -351,7 +351,7 @@ class LightningModule( results = self.trainer._results if results is None: raise MisconfigurationException( - "You are trying to `self.log()` but the loop `ResultCollection` is not registered" + "You are trying to `self.log()` but the loop's result collection is not registered" " yet. This is most likely because you are trying to log in a `predict` hook," " but it doesn't support logging" ) diff --git a/pytorch_lightning/loops/base.py b/pytorch_lightning/loops/base.py index 6114320bfc..77139186dc 100644 --- a/pytorch_lightning/loops/base.py +++ b/pytorch_lightning/loops/base.py @@ -19,7 +19,7 @@ from deprecate import void from torchmetrics import Metric import pytorch_lightning as pl -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BaseProgress from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -282,7 +282,7 @@ class Loop(ABC, Generic[T]): destination[key] = v.state_dict() elif isinstance(v, Loop): v.state_dict(destination, key + ".") - elif isinstance(v, ResultCollection): + elif isinstance(v, _ResultCollection): # sync / unsync metrics v.sync() destination[key] = v.state_dict() @@ -312,7 +312,7 @@ class Loop(ABC, Generic[T]): if isinstance(v, BaseProgress): v.load_state_dict(state_dict[key]) elif ( - isinstance(v, ResultCollection) + isinstance(v, _ResultCollection) and self.trainer is not None and self.trainer.lightning_module is not None ): @@ -324,10 +324,10 @@ class Loop(ABC, Generic[T]): if metrics: metric_attributes.update(metrics) - # The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`. + # The `_ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`. # When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only # Python primitives. However, their states are saved with the model's `state_dict`. - # On reload, we need to re-attach the `Metric`s back to the `ResultCollection`. + # On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`. # The references are provided through the `metric_attributes` dictionary. v.load_state_dict( state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py index 844c994e68..2954927196 100644 --- a/pytorch_lightning/loops/dataloader/evaluation_loop.py +++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py @@ -19,8 +19,8 @@ from torch.utils.data.dataloader import DataLoader from pytorch_lightning.loops.dataloader import DataLoaderLoop from pytorch_lightning.loops.epoch import EvaluationEpochLoop -from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection -from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection +from pytorch_lightning.trainer.states import RunningStage, TrainerFn from pytorch_lightning.utilities.types import EPOCH_OUTPUT @@ -32,7 +32,7 @@ class EvaluationLoop(DataLoaderLoop): self.epoch_loop = EvaluationEpochLoop() self.verbose = verbose - self._results = ResultCollection(training=False) + self._results = _ResultCollection(training=False) self._outputs: List[EPOCH_OUTPUT] = [] self._logged_outputs: List[_OUT_DICT] = [] self._max_batches: List[int] = [] diff --git a/pytorch_lightning/loops/epoch/training_epoch_loop.py b/pytorch_lightning/loops/epoch/training_epoch_loop.py index c25278988c..2689f12088 100644 --- a/pytorch_lightning/loops/epoch/training_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/training_epoch_loop.py @@ -21,7 +21,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports from pytorch_lightning.loops.batch import TrainingBatchLoop from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection @@ -65,7 +65,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]): self.batch_loop = TrainingBatchLoop() self.val_loop = loops.EvaluationLoop(verbose=False) - self._results = ResultCollection(training=True) + self._results = _ResultCollection(training=True) self._outputs: _OUTPUTS_TYPE = [] self._warning_cache = WarningCache() self._dataloader_iter: Optional[Iterator] = None diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 0bcd10f916..8f48696926 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -17,7 +17,7 @@ from typing import Optional from pytorch_lightning.loops import Loop from pytorch_lightning.loops.epoch import TrainingEpochLoop from pytorch_lightning.loops.utilities import _is_max_limit_reached -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.progress import Progress from pytorch_lightning.trainer.supporters import TensorRunningAccum from pytorch_lightning.utilities import rank_zero_deprecation @@ -136,7 +136,7 @@ class FitLoop(Loop): self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value @property - def _results(self) -> ResultCollection: + def _results(self) -> _ResultCollection: if self.trainer.training: return self.epoch_loop._results if self.trainer.validating: diff --git a/pytorch_lightning/trainer/connectors/logger_connector/result.py b/pytorch_lightning/trainer/connectors/logger_connector/result.py index 7dfc4622ce..b92234da19 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/result.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/result.py @@ -199,7 +199,7 @@ class _Metadata: return meta -class ResultMetric(Metric, DeviceDtypeModuleMixin): +class _ResultMetric(Metric, DeviceDtypeModuleMixin): """Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`""" def __init__(self, metadata: _Metadata, is_tensor: bool) -> None: @@ -316,25 +316,25 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin): super().__setstate__(d) @classmethod - def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetric": + def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetric": # need to reconstruct twice because `meta` is used in `__init__` meta = _Metadata._reconstruct(state["meta"]) result_metric = cls(meta, state["is_tensor"]) result_metric.__setstate__(state, sync_fn=sync_fn) return result_metric - def to(self, *args: Any, **kwargs: Any) -> "ResultMetric": + def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric": self.__dict__.update( apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs) ) return self -class ResultMetricCollection(dict): +class _ResultMetricCollection(dict): """Dict wrapper for easy access to metadata. All of the leaf items should be instances of - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` with the same metadata. """ @@ -347,36 +347,36 @@ class ResultMetricCollection(dict): return any(v.is_tensor for v in self.values()) def __getstate__(self, drop_value: bool = False) -> dict: - def getstate(item: ResultMetric) -> dict: + def getstate(item: _ResultMetric) -> dict: return item.__getstate__(drop_value=drop_value) - items = apply_to_collection(dict(self), ResultMetric, getstate) + items = apply_to_collection(dict(self), _ResultMetric, getstate) return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__} def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None: # can't use `apply_to_collection` as it does not recurse items of the same type - items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()} + items = {k: _ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()} self.update(items) @classmethod - def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection": + def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetricCollection": rmc = cls() rmc.__setstate__(state, sync_fn=sync_fn) return rmc -_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection] +_METRIC_COLLECTION = Union[_IN_METRIC, _ResultMetricCollection] -class ResultCollection(dict): +class _ResultCollection(dict): """ - Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or - :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection` + Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or + :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection` Example: # `device` needs to be provided before logging - result = ResultCollection(training=True, torch.device("cpu")) + result = _ResultCollection(training=True, torch.device("cpu")) # you can log to a specific collection. # arguments: fx, key, value, metadata @@ -395,14 +395,14 @@ class ResultCollection(dict): self.dataloader_idx: Optional[int] = None @property - def result_metrics(self) -> List[ResultMetric]: + def result_metrics(self) -> List[_ResultMetric]: o = [] - def append_fn(v: ResultMetric) -> None: + def append_fn(v: _ResultMetric) -> None: nonlocal o o.append(v) - apply_to_collection(list(self.values()), ResultMetric, append_fn) + apply_to_collection(list(self.values()), _ResultMetric, append_fn) return o def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int: @@ -414,7 +414,7 @@ class ResultCollection(dict): return batch_size batch_size = 1 - is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor + is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction: batch_size = extract_batch_size(self.batch) self.batch_size = batch_size @@ -485,30 +485,30 @@ class ResultCollection(dict): self.update_metrics(key, value, batch_size) def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None: - """Create one ResultMetric object per value. + """Create one _ResultMetric object per value. Value can be provided as a nested collection """ - def fn(v: _IN_METRIC) -> ResultMetric: - metric = ResultMetric(meta, isinstance(v, torch.Tensor)) + def fn(v: _IN_METRIC) -> _ResultMetric: + metric = _ResultMetric(meta, isinstance(v, torch.Tensor)) return metric.to(self.device) value = apply_to_collection(value, (torch.Tensor, Metric), fn) if isinstance(value, dict): - value = ResultMetricCollection(value) + value = _ResultMetricCollection(value) self[key] = value def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None: - def fn(result_metric: ResultMetric, v: torch.Tensor) -> None: + def fn(result_metric: _ResultMetric, v: torch.Tensor) -> None: # performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl` result_metric.forward(v.to(self.device), batch_size) result_metric.has_reset = False - apply_to_collections(self[key], value, ResultMetric, fn) + apply_to_collections(self[key], value, _ResultMetric, fn) @staticmethod - def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]: + def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[torch.Tensor]: cache = None if on_step and result_metric.meta.on_step: cache = result_metric._forward_cache @@ -529,11 +529,11 @@ class ResultCollection(dict): return ( (k, v) for k, v in self.items() - if not (isinstance(v, ResultMetric) and v.has_reset) + if not (isinstance(v, _ResultMetric) and v.has_reset) and self.dataloader_idx in (None, v.meta.dataloader_idx) ) - def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]: + def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]: name = result_metric.meta.name forked_name = result_metric.meta.forked_name(on_step) add_dataloader_idx = result_metric.meta.add_dataloader_idx @@ -549,11 +549,11 @@ class ResultCollection(dict): for _, result_metric in self.valid_items(): - # extract forward_cache or computed from the ResultMetric. ignore when the output is None - value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False) + # extract forward_cache or computed from the _ResultMetric. ignore when the output is None + value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False) # convert metric collection to dict container. - if isinstance(value, ResultMetricCollection): + if isinstance(value, _ResultMetricCollection): value = dict(value.items()) # check if the collection is empty @@ -594,15 +594,15 @@ class ResultCollection(dict): fx: Function to reset """ - def fn(item: ResultMetric) -> None: + def fn(item: _ResultMetric) -> None: requested_type = metrics is None or metrics ^ item.is_tensor same_fx = fx is None or fx == item.meta.fx if requested_type and same_fx: item.reset() - apply_to_collection(self, ResultMetric, fn) + apply_to_collection(self, _ResultMetric, fn) - def to(self, *args: Any, **kwargs: Any) -> "ResultCollection": + def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection": """Move all data to the given device.""" self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)) @@ -610,7 +610,7 @@ class ResultCollection(dict): self.device = kwargs["device"] return self - def cpu(self) -> "ResultCollection": + def cpu(self) -> "_ResultCollection": """Move all data to CPU.""" return self.to(device="cpu") @@ -634,7 +634,7 @@ class ResultCollection(dict): def __getstate__(self, drop_value: bool = True) -> dict: d = self.__dict__.copy() - # all the items should be either `ResultMetric`s or `ResultMetricCollection`s + # all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()} return {**d, "items": items} @@ -643,14 +643,14 @@ class ResultCollection(dict): ) -> None: self.__dict__.update({k: v for k, v in state.items() if k != "items"}) - def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]: + def setstate(k: str, item: dict) -> Union[_ResultMetric, _ResultMetricCollection]: if not isinstance(item, dict): raise ValueError(f"Unexpected value: {item}") cls = item["_class"] - if cls == ResultMetric.__name__: - cls = ResultMetric - elif cls == ResultMetricCollection.__name__: - cls = ResultMetricCollection + if cls == _ResultMetric.__name__: + cls = _ResultMetric + elif cls == _ResultMetricCollection.__name__: + cls = _ResultMetricCollection else: raise ValueError(f"Unexpected class name: {cls}") _sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 47459f5ba6..3740be974b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -65,7 +65,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin @@ -2242,7 +2242,7 @@ class Trainer( return self.logger_connector.progress_bar_metrics @property - def _results(self) -> Optional[ResultCollection]: + def _results(self) -> Optional[_ResultCollection]: active_loop = self._active_loop if active_loop is not None: return active_loop._results diff --git a/tests/core/test_metric_result_integration.py b/tests/core/test_metric_result_integration.py index e506fc2927..b80ec8c88d 100644 --- a/tests/core/test_metric_result_integration.py +++ b/tests/core/test_metric_result_integration.py @@ -29,9 +29,9 @@ from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.connectors.logger_connector.result import ( _Metadata, + _ResultCollection, + _ResultMetric, _Sync, - ResultCollection, - ResultMetric, ) from pytorch_lightning.utilities.imports import _fault_tolerant_training from tests.helpers import BoringModel @@ -71,7 +71,7 @@ def _ddp_test_fn(rank, worldsize): metric_b = metric_b.to(f"cuda:{rank}") metric_c = metric_c.to(f"cuda:{rank}") - result = ResultCollection(True, torch.device(f"cuda:{rank}")) + result = _ResultCollection(True, torch.device(f"cuda:{rank}")) for _ in range(3): cumulative_sum = 0 @@ -114,7 +114,7 @@ def test_result_metric_integration(): metric_b = DummyMetric() metric_c = DummyMetric() - result = ResultCollection(True, torch.device("cpu")) + result = _ResultCollection(True, torch.device("cpu")) for _ in range(3): cumulative_sum = 0 @@ -145,26 +145,26 @@ def test_result_metric_integration(): result.minimize = torch.tensor(1.0) result.extra = {} assert str(result) == ( - "ResultCollection(" + "_ResultCollection(" "{" - "'h.a': ResultMetric('a', value=DummyMetric()), " - "'h.b': ResultMetric('b', value=DummyMetric()), " - "'h.c': ResultMetric('c', value=DummyMetric())" + "'h.a': _ResultMetric('a', value=DummyMetric()), " + "'h.b': _ResultMetric('b', value=DummyMetric()), " + "'h.c': _ResultMetric('c', value=DummyMetric())" "})" ) assert repr(result) == ( "{" "True, " "device(type='cpu'), " - "{'h.a': ResultMetric('a', value=DummyMetric()), " - "'h.b': ResultMetric('b', value=DummyMetric()), " - "'h.c': ResultMetric('c', value=DummyMetric())" + "{'h.a': _ResultMetric('a', value=DummyMetric()), " + "'h.b': _ResultMetric('b', value=DummyMetric()), " + "'h.c': _ResultMetric('c', value=DummyMetric())" "}}" ) def test_result_collection_simple_loop(): - result = ResultCollection(True, torch.device("cpu")) + result = _ResultCollection(True, torch.device("cpu")) current_fx_name = None batch_idx = None @@ -212,7 +212,7 @@ def my_sync_dist(x, *_, **__): def test_result_collection_restoration(tmpdir): """This test make sure metrics are properly reloaded on failure.""" - result = ResultCollection(True, torch.device("cpu")) + result = _ResultCollection(True, torch.device("cpu")) metric_a = DummyMetric() metric_b = DummyMetric() metric_c = DummyMetric() @@ -253,7 +253,7 @@ def test_result_collection_restoration(tmpdir): assert set(batch_log["c_1"]) == {"1", "2"} result_copy = deepcopy(result) - new_result = ResultCollection(True, torch.device("cpu")) + new_result = _ResultCollection(True, torch.device("cpu")) state_dict = result.state_dict() # check the sync fn was dropped assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"] @@ -334,7 +334,7 @@ def test_lightning_module_logging_result_collection(tmpdir, device): assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce # default sync fn - new_results = ResultCollection(False, device) + new_results = _ResultCollection(False, device) new_results.load_state_dict(state_dict, map_location="cpu") assert new_results["validation_step.v"].meta.sync.fn is None @@ -373,7 +373,7 @@ class DummyMeanMetric(Metric): def result_collection_reload(**kwargs): - """This test is going to validate ResultCollection is properly being reload and final accumulation with Fault + """This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault Tolerant Training is correct.""" if not _fault_tolerant_training(): @@ -551,10 +551,10 @@ def test_metric_result_computed_check(): """Unittest ``_get_cache`` with multielement tensors.""" metadata = _Metadata("foo", "bar", on_epoch=True, enable_graph=True) metadata.sync = _Sync() - rm = ResultMetric(metadata, is_tensor=True) + rm = _ResultMetric(metadata, is_tensor=True) computed_value = torch.tensor([1, 2, 3]) rm._computed = computed_value - cache = ResultCollection._get_cache(rm, on_step=False) + cache = _ResultCollection._get_cache(rm, on_step=False) # `enable_graph=True` so no detach, identity works assert cache is computed_value @@ -566,7 +566,7 @@ def test_metric_result_respects_dtype(floating_dtype): metadata = _Metadata("foo", "bar") metadata.sync = _Sync() - rm = ResultMetric(metadata, is_tensor=True) + rm = _ResultMetric(metadata, is_tensor=True) assert rm.value.dtype == floating_dtype assert rm.cumulated_batch_size.dtype == fixed_dtype diff --git a/tests/trainer/logging_/test_logger_connector.py b/tests/trainer/logging_/test_logger_connector.py index cebf4cb6f9..b965478684 100644 --- a/tests/trainer/logging_/test_logger_connector.py +++ b/tests/trainer/logging_/test_logger_connector.py @@ -23,7 +23,7 @@ from pytorch_lightning import LightningModule from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers.boring_model import BoringModel, RandomDataset from tests.helpers.runif import RunIf @@ -267,19 +267,19 @@ def test_fx_validator_integration(tmpdir): with pytest.deprecated_call(match="on_test_dataloader` is deprecated in v1.5"): trainer.test(model, verbose=False) - not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported}) + not_supported.update({k: "result collection is not registered yet" for k in not_supported}) not_supported.update( { - "on_predict_dataloader": "ResultCollection` is not registered yet", - "predict_dataloader": "ResultCollection` is not registered yet", - "on_predict_model_eval": "ResultCollection` is not registered yet", - "on_predict_start": "ResultCollection` is not registered yet", - "on_predict_epoch_start": "ResultCollection` is not registered yet", - "on_predict_batch_start": "ResultCollection` is not registered yet", - "predict_step": "ResultCollection` is not registered yet", - "on_predict_batch_end": "ResultCollection` is not registered yet", - "on_predict_epoch_end": "ResultCollection` is not registered yet", - "on_predict_end": "ResultCollection` is not registered yet", + "on_predict_dataloader": "result collection is not registered yet", + "predict_dataloader": "result collection is not registered yet", + "on_predict_model_eval": "result collection is not registered yet", + "on_predict_start": "result collection is not registered yet", + "on_predict_epoch_start": "result collection is not registered yet", + "on_predict_batch_start": "result collection is not registered yet", + "predict_step": "result collection is not registered yet", + "on_predict_batch_end": "result collection is not registered yet", + "on_predict_epoch_end": "result collection is not registered yet", + "on_predict_end": "result collection is not registered yet", } ) with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"): @@ -531,7 +531,7 @@ def test_metrics_reset(tmpdir): def test_result_collection_on_tensor_with_mean_reduction(): - result_collection = ResultCollection(True) + result_collection = _ResultCollection(True) product = [(True, True), (False, True), (True, False), (False, False)] values = torch.arange(1, 10) batches = values * values @@ -647,7 +647,7 @@ def test_result_collection_batch_size_extraction(): fx_name = "training_step" log_val = torch.tensor(7.0) - results = ResultCollection(training=True, device="cpu") + results = _ResultCollection(training=True, device="cpu") results.batch = torch.randn(1, 4) train_mse = MeanSquaredError() train_mse(torch.randn(4, 5), torch.randn(4, 5)) @@ -656,7 +656,7 @@ def test_result_collection_batch_size_extraction(): assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError) assert results["training_step.train_logs"]["log_val"].value == log_val - results = ResultCollection(training=True, device="cpu") + results = _ResultCollection(training=True, device="cpu") results.batch = torch.randn(1, 4) results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True) assert results.batch_size == 1 @@ -665,7 +665,7 @@ def test_result_collection_batch_size_extraction(): def test_result_collection_no_batch_size_extraction(): - results = ResultCollection(training=True, device="cpu") + results = _ResultCollection(training=True, device="cpu") results.batch = torch.randn(1, 4) fx_name = "training_step" batch_size = 10 diff --git a/tests/trainer/logging_/test_loop_logging.py b/tests/trainer/logging_/test_loop_logging.py index 40ad1bc48d..fca7a8546c 100644 --- a/tests/trainer/logging_/test_loop_logging.py +++ b/tests/trainer/logging_/test_loop_logging.py @@ -20,7 +20,7 @@ import torch from pytorch_lightning import Trainer from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator -from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection +from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection from pytorch_lightning.trainer.states import RunningStage, TrainerFn from tests.helpers.boring_model import BoringModel @@ -39,13 +39,13 @@ def test_default_level_for_hooks_that_support_logging(): model.trainer = trainer extra_kwargs = { k: ANY - for k in inspect.signature(ResultCollection.log).parameters + for k in inspect.signature(_ResultCollection.log).parameters if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"] } all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]} with mock.patch( - "pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None + "pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log", return_value=None ) as result_mock: trainer.state.stage = RunningStage.TRAINING hooks = [