Mark all result classes as protected (#11130)

This commit is contained in:
Carlos Mocholí 2021-12-17 20:35:17 +01:00 committed by GitHub
parent 860959fb3f
commit 8508cce37d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 96 additions and 93 deletions

View File

@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))
- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))

View File

@ -351,7 +351,7 @@ class LightningModule(
results = self.trainer._results
if results is None:
raise MisconfigurationException(
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
"You are trying to `self.log()` but the loop's result collection is not registered"
" yet. This is most likely because you are trying to log in a `predict` hook,"
" but it doesn't support logging"
)

View File

@ -19,7 +19,7 @@ from deprecate import void
from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@ -282,7 +282,7 @@ class Loop(ABC, Generic[T]):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
elif isinstance(v, ResultCollection):
elif isinstance(v, _ResultCollection):
# sync / unsync metrics
v.sync()
destination[key] = v.state_dict()
@ -312,7 +312,7 @@ class Loop(ABC, Generic[T]):
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
elif (
isinstance(v, ResultCollection)
isinstance(v, _ResultCollection)
and self.trainer is not None
and self.trainer.lightning_module is not None
):
@ -324,10 +324,10 @@ class Loop(ABC, Generic[T]):
if metrics:
metric_attributes.update(metrics)
# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# The `_ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce

View File

@ -19,8 +19,8 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@ -32,7 +32,7 @@ class EvaluationLoop(DataLoaderLoop):
self.epoch_loop = EvaluationEpochLoop()
self.verbose = verbose
self._results = ResultCollection(training=False)
self._results = _ResultCollection(training=False)
self._outputs: List[EPOCH_OUTPUT] = []
self._logged_outputs: List[_OUT_DICT] = []
self._max_batches: List[int] = []

View File

@ -21,7 +21,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
@ -65,7 +65,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
self.batch_loop = TrainingBatchLoop()
self.val_loop = loops.EvaluationLoop(verbose=False)
self._results = ResultCollection(training=True)
self._results = _ResultCollection(training=True)
self._outputs: _OUTPUTS_TYPE = []
self._warning_cache = WarningCache()
self._dataloader_iter: Optional[Iterator] = None

View File

@ -17,7 +17,7 @@ from typing import Optional
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.utilities import _is_max_limit_reached
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities import rank_zero_deprecation
@ -136,7 +136,7 @@ class FitLoop(Loop):
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
@property
def _results(self) -> ResultCollection:
def _results(self) -> _ResultCollection:
if self.trainer.training:
return self.epoch_loop._results
if self.trainer.validating:

View File

@ -199,7 +199,7 @@ class _Metadata:
return meta
class ResultMetric(Metric, DeviceDtypeModuleMixin):
class _ResultMetric(Metric, DeviceDtypeModuleMixin):
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
@ -316,25 +316,25 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
super().__setstate__(d)
@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetric":
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetric":
# need to reconstruct twice because `meta` is used in `__init__`
meta = _Metadata._reconstruct(state["meta"])
result_metric = cls(meta, state["is_tensor"])
result_metric.__setstate__(state, sync_fn=sync_fn)
return result_metric
def to(self, *args: Any, **kwargs: Any) -> "ResultMetric":
def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
self.__dict__.update(
apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
)
return self
class ResultMetricCollection(dict):
class _ResultMetricCollection(dict):
"""Dict wrapper for easy access to metadata.
All of the leaf items should be instances of
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
with the same metadata.
"""
@ -347,36 +347,36 @@ class ResultMetricCollection(dict):
return any(v.is_tensor for v in self.values())
def __getstate__(self, drop_value: bool = False) -> dict:
def getstate(item: ResultMetric) -> dict:
def getstate(item: _ResultMetric) -> dict:
return item.__getstate__(drop_value=drop_value)
items = apply_to_collection(dict(self), ResultMetric, getstate)
items = apply_to_collection(dict(self), _ResultMetric, getstate)
return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}
def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
# can't use `apply_to_collection` as it does not recurse items of the same type
items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
items = {k: _ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
self.update(items)
@classmethod
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection":
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetricCollection":
rmc = cls()
rmc.__setstate__(state, sync_fn=sync_fn)
return rmc
_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection]
_METRIC_COLLECTION = Union[_IN_METRIC, _ResultMetricCollection]
class ResultCollection(dict):
class _ResultCollection(dict):
"""
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection`
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection`
Example:
# `device` needs to be provided before logging
result = ResultCollection(training=True, torch.device("cpu"))
result = _ResultCollection(training=True, torch.device("cpu"))
# you can log to a specific collection.
# arguments: fx, key, value, metadata
@ -395,14 +395,14 @@ class ResultCollection(dict):
self.dataloader_idx: Optional[int] = None
@property
def result_metrics(self) -> List[ResultMetric]:
def result_metrics(self) -> List[_ResultMetric]:
o = []
def append_fn(v: ResultMetric) -> None:
def append_fn(v: _ResultMetric) -> None:
nonlocal o
o.append(v)
apply_to_collection(list(self.values()), ResultMetric, append_fn)
apply_to_collection(list(self.values()), _ResultMetric, append_fn)
return o
def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
@ -414,7 +414,7 @@ class ResultCollection(dict):
return batch_size
batch_size = 1
is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
batch_size = extract_batch_size(self.batch)
self.batch_size = batch_size
@ -485,30 +485,30 @@ class ResultCollection(dict):
self.update_metrics(key, value, batch_size)
def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
"""Create one ResultMetric object per value.
"""Create one _ResultMetric object per value.
Value can be provided as a nested collection
"""
def fn(v: _IN_METRIC) -> ResultMetric:
metric = ResultMetric(meta, isinstance(v, torch.Tensor))
def fn(v: _IN_METRIC) -> _ResultMetric:
metric = _ResultMetric(meta, isinstance(v, torch.Tensor))
return metric.to(self.device)
value = apply_to_collection(value, (torch.Tensor, Metric), fn)
if isinstance(value, dict):
value = ResultMetricCollection(value)
value = _ResultMetricCollection(value)
self[key] = value
def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
def fn(result_metric: _ResultMetric, v: torch.Tensor) -> None:
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
result_metric.forward(v.to(self.device), batch_size)
result_metric.has_reset = False
apply_to_collections(self[key], value, ResultMetric, fn)
apply_to_collections(self[key], value, _ResultMetric, fn)
@staticmethod
def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
cache = None
if on_step and result_metric.meta.on_step:
cache = result_metric._forward_cache
@ -529,11 +529,11 @@ class ResultCollection(dict):
return (
(k, v)
for k, v in self.items()
if not (isinstance(v, ResultMetric) and v.has_reset)
if not (isinstance(v, _ResultMetric) and v.has_reset)
and self.dataloader_idx in (None, v.meta.dataloader_idx)
)
def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]:
name = result_metric.meta.name
forked_name = result_metric.meta.forked_name(on_step)
add_dataloader_idx = result_metric.meta.add_dataloader_idx
@ -549,11 +549,11 @@ class ResultCollection(dict):
for _, result_metric in self.valid_items():
# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
# extract forward_cache or computed from the _ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False)
# convert metric collection to dict container.
if isinstance(value, ResultMetricCollection):
if isinstance(value, _ResultMetricCollection):
value = dict(value.items())
# check if the collection is empty
@ -594,15 +594,15 @@ class ResultCollection(dict):
fx: Function to reset
"""
def fn(item: ResultMetric) -> None:
def fn(item: _ResultMetric) -> None:
requested_type = metrics is None or metrics ^ item.is_tensor
same_fx = fx is None or fx == item.meta.fx
if requested_type and same_fx:
item.reset()
apply_to_collection(self, ResultMetric, fn)
apply_to_collection(self, _ResultMetric, fn)
def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
"""Move all data to the given device."""
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))
@ -610,7 +610,7 @@ class ResultCollection(dict):
self.device = kwargs["device"]
return self
def cpu(self) -> "ResultCollection":
def cpu(self) -> "_ResultCollection":
"""Move all data to CPU."""
return self.to(device="cpu")
@ -634,7 +634,7 @@ class ResultCollection(dict):
def __getstate__(self, drop_value: bool = True) -> dict:
d = self.__dict__.copy()
# all the items should be either `ResultMetric`s or `ResultMetricCollection`s
# all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s
items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()}
return {**d, "items": items}
@ -643,14 +643,14 @@ class ResultCollection(dict):
) -> None:
self.__dict__.update({k: v for k, v in state.items() if k != "items"})
def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]:
def setstate(k: str, item: dict) -> Union[_ResultMetric, _ResultMetricCollection]:
if not isinstance(item, dict):
raise ValueError(f"Unexpected value: {item}")
cls = item["_class"]
if cls == ResultMetric.__name__:
cls = ResultMetric
elif cls == ResultMetricCollection.__name__:
cls = ResultMetricCollection
if cls == _ResultMetric.__name__:
cls = _ResultMetric
elif cls == _ResultMetricCollection.__name__:
cls = _ResultMetricCollection
else:
raise ValueError(f"Unexpected class name: {cls}")
_sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None)

View File

@ -65,7 +65,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
@ -2242,7 +2242,7 @@ class Trainer(
return self.logger_connector.progress_bar_metrics
@property
def _results(self) -> Optional[ResultCollection]:
def _results(self) -> Optional[_ResultCollection]:
active_loop = self._active_loop
if active_loop is not None:
return active_loop._results

View File

@ -29,9 +29,9 @@ from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.connectors.logger_connector.result import (
_Metadata,
_ResultCollection,
_ResultMetric,
_Sync,
ResultCollection,
ResultMetric,
)
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from tests.helpers import BoringModel
@ -71,7 +71,7 @@ def _ddp_test_fn(rank, worldsize):
metric_b = metric_b.to(f"cuda:{rank}")
metric_c = metric_c.to(f"cuda:{rank}")
result = ResultCollection(True, torch.device(f"cuda:{rank}"))
result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
for _ in range(3):
cumulative_sum = 0
@ -114,7 +114,7 @@ def test_result_metric_integration():
metric_b = DummyMetric()
metric_c = DummyMetric()
result = ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True, torch.device("cpu"))
for _ in range(3):
cumulative_sum = 0
@ -145,26 +145,26 @@ def test_result_metric_integration():
result.minimize = torch.tensor(1.0)
result.extra = {}
assert str(result) == (
"ResultCollection("
"_ResultCollection("
"{"
"'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric())"
"'h.a': _ResultMetric('a', value=DummyMetric()), "
"'h.b': _ResultMetric('b', value=DummyMetric()), "
"'h.c': _ResultMetric('c', value=DummyMetric())"
"})"
)
assert repr(result) == (
"{"
"True, "
"device(type='cpu'), "
"{'h.a': ResultMetric('a', value=DummyMetric()), "
"'h.b': ResultMetric('b', value=DummyMetric()), "
"'h.c': ResultMetric('c', value=DummyMetric())"
"{'h.a': _ResultMetric('a', value=DummyMetric()), "
"'h.b': _ResultMetric('b', value=DummyMetric()), "
"'h.c': _ResultMetric('c', value=DummyMetric())"
"}}"
)
def test_result_collection_simple_loop():
result = ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True, torch.device("cpu"))
current_fx_name = None
batch_idx = None
@ -212,7 +212,7 @@ def my_sync_dist(x, *_, **__):
def test_result_collection_restoration(tmpdir):
"""This test make sure metrics are properly reloaded on failure."""
result = ResultCollection(True, torch.device("cpu"))
result = _ResultCollection(True, torch.device("cpu"))
metric_a = DummyMetric()
metric_b = DummyMetric()
metric_c = DummyMetric()
@ -253,7 +253,7 @@ def test_result_collection_restoration(tmpdir):
assert set(batch_log["c_1"]) == {"1", "2"}
result_copy = deepcopy(result)
new_result = ResultCollection(True, torch.device("cpu"))
new_result = _ResultCollection(True, torch.device("cpu"))
state_dict = result.state_dict()
# check the sync fn was dropped
assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]
@ -334,7 +334,7 @@ def test_lightning_module_logging_result_collection(tmpdir, device):
assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce
# default sync fn
new_results = ResultCollection(False, device)
new_results = _ResultCollection(False, device)
new_results.load_state_dict(state_dict, map_location="cpu")
assert new_results["validation_step.v"].meta.sync.fn is None
@ -373,7 +373,7 @@ class DummyMeanMetric(Metric):
def result_collection_reload(**kwargs):
"""This test is going to validate ResultCollection is properly being reload and final accumulation with Fault
"""This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault
Tolerant Training is correct."""
if not _fault_tolerant_training():
@ -551,10 +551,10 @@ def test_metric_result_computed_check():
"""Unittest ``_get_cache`` with multielement tensors."""
metadata = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
metadata.sync = _Sync()
rm = ResultMetric(metadata, is_tensor=True)
rm = _ResultMetric(metadata, is_tensor=True)
computed_value = torch.tensor([1, 2, 3])
rm._computed = computed_value
cache = ResultCollection._get_cache(rm, on_step=False)
cache = _ResultCollection._get_cache(rm, on_step=False)
# `enable_graph=True` so no detach, identity works
assert cache is computed_value
@ -566,7 +566,7 @@ def test_metric_result_respects_dtype(floating_dtype):
metadata = _Metadata("foo", "bar")
metadata.sync = _Sync()
rm = ResultMetric(metadata, is_tensor=True)
rm = _ResultMetric(metadata, is_tensor=True)
assert rm.value.dtype == floating_dtype
assert rm.cumulated_batch_size.dtype == fixed_dtype

View File

@ -23,7 +23,7 @@ from pytorch_lightning import LightningModule
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
@ -267,19 +267,19 @@ def test_fx_validator_integration(tmpdir):
with pytest.deprecated_call(match="on_test_dataloader` is deprecated in v1.5"):
trainer.test(model, verbose=False)
not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported})
not_supported.update({k: "result collection is not registered yet" for k in not_supported})
not_supported.update(
{
"on_predict_dataloader": "ResultCollection` is not registered yet",
"predict_dataloader": "ResultCollection` is not registered yet",
"on_predict_model_eval": "ResultCollection` is not registered yet",
"on_predict_start": "ResultCollection` is not registered yet",
"on_predict_epoch_start": "ResultCollection` is not registered yet",
"on_predict_batch_start": "ResultCollection` is not registered yet",
"predict_step": "ResultCollection` is not registered yet",
"on_predict_batch_end": "ResultCollection` is not registered yet",
"on_predict_epoch_end": "ResultCollection` is not registered yet",
"on_predict_end": "ResultCollection` is not registered yet",
"on_predict_dataloader": "result collection is not registered yet",
"predict_dataloader": "result collection is not registered yet",
"on_predict_model_eval": "result collection is not registered yet",
"on_predict_start": "result collection is not registered yet",
"on_predict_epoch_start": "result collection is not registered yet",
"on_predict_batch_start": "result collection is not registered yet",
"predict_step": "result collection is not registered yet",
"on_predict_batch_end": "result collection is not registered yet",
"on_predict_epoch_end": "result collection is not registered yet",
"on_predict_end": "result collection is not registered yet",
}
)
with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"):
@ -531,7 +531,7 @@ def test_metrics_reset(tmpdir):
def test_result_collection_on_tensor_with_mean_reduction():
result_collection = ResultCollection(True)
result_collection = _ResultCollection(True)
product = [(True, True), (False, True), (True, False), (False, False)]
values = torch.arange(1, 10)
batches = values * values
@ -647,7 +647,7 @@ def test_result_collection_batch_size_extraction():
fx_name = "training_step"
log_val = torch.tensor(7.0)
results = ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
train_mse = MeanSquaredError()
train_mse(torch.randn(4, 5), torch.randn(4, 5))
@ -656,7 +656,7 @@ def test_result_collection_batch_size_extraction():
assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
assert results["training_step.train_logs"]["log_val"].value == log_val
results = ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
assert results.batch_size == 1
@ -665,7 +665,7 @@ def test_result_collection_batch_size_extraction():
def test_result_collection_no_batch_size_extraction():
results = ResultCollection(training=True, device="cpu")
results = _ResultCollection(training=True, device="cpu")
results.batch = torch.randn(1, 4)
fx_name = "training_step"
batch_size = 10

View File

@ -20,7 +20,7 @@ import torch
from pytorch_lightning import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
from tests.helpers.boring_model import BoringModel
@ -39,13 +39,13 @@ def test_default_level_for_hooks_that_support_logging():
model.trainer = trainer
extra_kwargs = {
k: ANY
for k in inspect.signature(ResultCollection.log).parameters
for k in inspect.signature(_ResultCollection.log).parameters
if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"]
}
all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]}
with mock.patch(
"pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None
"pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log", return_value=None
) as result_mock:
trainer.state.stage = RunningStage.TRAINING
hooks = [