Mark all result classes as protected (#11130)
This commit is contained in:
parent
860959fb3f
commit
8508cce37d
|
@ -129,6 +129,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Moved ownership of the `Accelerator` instance to the `TrainingTypePlugin`; all training-type plugins now take an optional parameter `accelerator` ([#11022](https://github.com/PyTorchLightning/pytorch-lightning/pull/11022))
|
||||
|
||||
|
||||
- Marked the `ResultCollection`, `ResultMetric`, and `ResultMetricCollection` classes as protected ([#11130](https://github.com/PyTorchLightning/pytorch-lightning/pull/11130))
|
||||
|
||||
|
||||
- DeepSpeed does not require lightning module zero 3 partitioning ([#10655](https://github.com/PyTorchLightning/pytorch-lightning/pull/10655))
|
||||
|
||||
|
||||
|
|
|
@ -351,7 +351,7 @@ class LightningModule(
|
|||
results = self.trainer._results
|
||||
if results is None:
|
||||
raise MisconfigurationException(
|
||||
"You are trying to `self.log()` but the loop `ResultCollection` is not registered"
|
||||
"You are trying to `self.log()` but the loop's result collection is not registered"
|
||||
" yet. This is most likely because you are trying to log in a `predict` hook,"
|
||||
" but it doesn't support logging"
|
||||
)
|
||||
|
|
|
@ -19,7 +19,7 @@ from deprecate import void
|
|||
from torchmetrics import Metric
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import BaseProgress
|
||||
from pytorch_lightning.utilities.enums import _FaultTolerantMode
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
@ -282,7 +282,7 @@ class Loop(ABC, Generic[T]):
|
|||
destination[key] = v.state_dict()
|
||||
elif isinstance(v, Loop):
|
||||
v.state_dict(destination, key + ".")
|
||||
elif isinstance(v, ResultCollection):
|
||||
elif isinstance(v, _ResultCollection):
|
||||
# sync / unsync metrics
|
||||
v.sync()
|
||||
destination[key] = v.state_dict()
|
||||
|
@ -312,7 +312,7 @@ class Loop(ABC, Generic[T]):
|
|||
if isinstance(v, BaseProgress):
|
||||
v.load_state_dict(state_dict[key])
|
||||
elif (
|
||||
isinstance(v, ResultCollection)
|
||||
isinstance(v, _ResultCollection)
|
||||
and self.trainer is not None
|
||||
and self.trainer.lightning_module is not None
|
||||
):
|
||||
|
@ -324,10 +324,10 @@ class Loop(ABC, Generic[T]):
|
|||
if metrics:
|
||||
metric_attributes.update(metrics)
|
||||
|
||||
# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
|
||||
# The `_ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
|
||||
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
|
||||
# Python primitives. However, their states are saved with the model's `state_dict`.
|
||||
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
|
||||
# On reload, we need to re-attach the `Metric`s back to the `_ResultCollection`.
|
||||
# The references are provided through the `metric_attributes` dictionary.
|
||||
v.load_state_dict(
|
||||
state_dict[key], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
|
||||
|
|
|
@ -19,8 +19,8 @@ from torch.utils.data.dataloader import DataLoader
|
|||
|
||||
from pytorch_lightning.loops.dataloader import DataLoaderLoop
|
||||
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
|
||||
from pytorch_lightning.trainer.states import RunningStage
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, _ResultCollection
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
|
||||
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
|
||||
|
||||
|
||||
|
@ -32,7 +32,7 @@ class EvaluationLoop(DataLoaderLoop):
|
|||
self.epoch_loop = EvaluationEpochLoop()
|
||||
self.verbose = verbose
|
||||
|
||||
self._results = ResultCollection(training=False)
|
||||
self._results = _ResultCollection(training=False)
|
||||
self._outputs: List[EPOCH_OUTPUT] = []
|
||||
self._logged_outputs: List[_OUT_DICT] = []
|
||||
self._max_batches: List[int] = []
|
||||
|
|
|
@ -21,7 +21,7 @@ from pytorch_lightning import loops # import as loops to avoid circular imports
|
|||
from pytorch_lightning.loops.batch import TrainingBatchLoop
|
||||
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
|
||||
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _update_dataloader_iter
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.apply_func import apply_to_collection
|
||||
|
@ -65,7 +65,7 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
|
|||
self.batch_loop = TrainingBatchLoop()
|
||||
self.val_loop = loops.EvaluationLoop(verbose=False)
|
||||
|
||||
self._results = ResultCollection(training=True)
|
||||
self._results = _ResultCollection(training=True)
|
||||
self._outputs: _OUTPUTS_TYPE = []
|
||||
self._warning_cache = WarningCache()
|
||||
self._dataloader_iter: Optional[Iterator] = None
|
||||
|
|
|
@ -17,7 +17,7 @@ from typing import Optional
|
|||
from pytorch_lightning.loops import Loop
|
||||
from pytorch_lightning.loops.epoch import TrainingEpochLoop
|
||||
from pytorch_lightning.loops.utilities import _is_max_limit_reached
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.progress import Progress
|
||||
from pytorch_lightning.trainer.supporters import TensorRunningAccum
|
||||
from pytorch_lightning.utilities import rank_zero_deprecation
|
||||
|
@ -136,7 +136,7 @@ class FitLoop(Loop):
|
|||
self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value
|
||||
|
||||
@property
|
||||
def _results(self) -> ResultCollection:
|
||||
def _results(self) -> _ResultCollection:
|
||||
if self.trainer.training:
|
||||
return self.epoch_loop._results
|
||||
if self.trainer.validating:
|
||||
|
|
|
@ -199,7 +199,7 @@ class _Metadata:
|
|||
return meta
|
||||
|
||||
|
||||
class ResultMetric(Metric, DeviceDtypeModuleMixin):
|
||||
class _ResultMetric(Metric, DeviceDtypeModuleMixin):
|
||||
"""Wraps the value provided to `:meth:`~pytorch_lightning.core.lightning.LightningModule.log`"""
|
||||
|
||||
def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
|
||||
|
@ -316,25 +316,25 @@ class ResultMetric(Metric, DeviceDtypeModuleMixin):
|
|||
super().__setstate__(d)
|
||||
|
||||
@classmethod
|
||||
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetric":
|
||||
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetric":
|
||||
# need to reconstruct twice because `meta` is used in `__init__`
|
||||
meta = _Metadata._reconstruct(state["meta"])
|
||||
result_metric = cls(meta, state["is_tensor"])
|
||||
result_metric.__setstate__(state, sync_fn=sync_fn)
|
||||
return result_metric
|
||||
|
||||
def to(self, *args: Any, **kwargs: Any) -> "ResultMetric":
|
||||
def to(self, *args: Any, **kwargs: Any) -> "_ResultMetric":
|
||||
self.__dict__.update(
|
||||
apply_to_collection(self.__dict__, (torch.Tensor, Metric), move_data_to_device, *args, **kwargs)
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class ResultMetricCollection(dict):
|
||||
class _ResultMetricCollection(dict):
|
||||
"""Dict wrapper for easy access to metadata.
|
||||
|
||||
All of the leaf items should be instances of
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric`
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric`
|
||||
with the same metadata.
|
||||
"""
|
||||
|
||||
|
@ -347,36 +347,36 @@ class ResultMetricCollection(dict):
|
|||
return any(v.is_tensor for v in self.values())
|
||||
|
||||
def __getstate__(self, drop_value: bool = False) -> dict:
|
||||
def getstate(item: ResultMetric) -> dict:
|
||||
def getstate(item: _ResultMetric) -> dict:
|
||||
return item.__getstate__(drop_value=drop_value)
|
||||
|
||||
items = apply_to_collection(dict(self), ResultMetric, getstate)
|
||||
items = apply_to_collection(dict(self), _ResultMetric, getstate)
|
||||
return {"items": items, "meta": self.meta.__getstate__(), "_class": self.__class__.__name__}
|
||||
|
||||
def __setstate__(self, state: dict, sync_fn: Optional[Callable] = None) -> None:
|
||||
# can't use `apply_to_collection` as it does not recurse items of the same type
|
||||
items = {k: ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
|
||||
items = {k: _ResultMetric._reconstruct(v, sync_fn=sync_fn) for k, v in state["items"].items()}
|
||||
self.update(items)
|
||||
|
||||
@classmethod
|
||||
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "ResultMetricCollection":
|
||||
def _reconstruct(cls, state: dict, sync_fn: Optional[Callable] = None) -> "_ResultMetricCollection":
|
||||
rmc = cls()
|
||||
rmc.__setstate__(state, sync_fn=sync_fn)
|
||||
return rmc
|
||||
|
||||
|
||||
_METRIC_COLLECTION = Union[_IN_METRIC, ResultMetricCollection]
|
||||
_METRIC_COLLECTION = Union[_IN_METRIC, _ResultMetricCollection]
|
||||
|
||||
|
||||
class ResultCollection(dict):
|
||||
class _ResultCollection(dict):
|
||||
"""
|
||||
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetric` or
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result.ResultMetricCollection`
|
||||
Collection (dictionary) of :class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetric` or
|
||||
:class:`~pytorch_lightning.trainer.connectors.logger_connector.result._ResultMetricCollection`
|
||||
|
||||
Example:
|
||||
|
||||
# `device` needs to be provided before logging
|
||||
result = ResultCollection(training=True, torch.device("cpu"))
|
||||
result = _ResultCollection(training=True, torch.device("cpu"))
|
||||
|
||||
# you can log to a specific collection.
|
||||
# arguments: fx, key, value, metadata
|
||||
|
@ -395,14 +395,14 @@ class ResultCollection(dict):
|
|||
self.dataloader_idx: Optional[int] = None
|
||||
|
||||
@property
|
||||
def result_metrics(self) -> List[ResultMetric]:
|
||||
def result_metrics(self) -> List[_ResultMetric]:
|
||||
o = []
|
||||
|
||||
def append_fn(v: ResultMetric) -> None:
|
||||
def append_fn(v: _ResultMetric) -> None:
|
||||
nonlocal o
|
||||
o.append(v)
|
||||
|
||||
apply_to_collection(list(self.values()), ResultMetric, append_fn)
|
||||
apply_to_collection(list(self.values()), _ResultMetric, append_fn)
|
||||
return o
|
||||
|
||||
def _extract_batch_size(self, value: _METRIC_COLLECTION, batch_size: Optional[int], meta: _Metadata) -> int:
|
||||
|
@ -414,7 +414,7 @@ class ResultCollection(dict):
|
|||
return batch_size
|
||||
|
||||
batch_size = 1
|
||||
is_tensor = value.is_tensor if isinstance(value, ResultMetric) else value.has_tensor
|
||||
is_tensor = value.is_tensor if isinstance(value, _ResultMetric) else value.has_tensor
|
||||
if self.batch is not None and is_tensor and meta.on_epoch and meta.is_mean_reduction:
|
||||
batch_size = extract_batch_size(self.batch)
|
||||
self.batch_size = batch_size
|
||||
|
@ -485,30 +485,30 @@ class ResultCollection(dict):
|
|||
self.update_metrics(key, value, batch_size)
|
||||
|
||||
def register_key(self, key: str, meta: _Metadata, value: _METRIC_COLLECTION) -> None:
|
||||
"""Create one ResultMetric object per value.
|
||||
"""Create one _ResultMetric object per value.
|
||||
|
||||
Value can be provided as a nested collection
|
||||
"""
|
||||
|
||||
def fn(v: _IN_METRIC) -> ResultMetric:
|
||||
metric = ResultMetric(meta, isinstance(v, torch.Tensor))
|
||||
def fn(v: _IN_METRIC) -> _ResultMetric:
|
||||
metric = _ResultMetric(meta, isinstance(v, torch.Tensor))
|
||||
return metric.to(self.device)
|
||||
|
||||
value = apply_to_collection(value, (torch.Tensor, Metric), fn)
|
||||
if isinstance(value, dict):
|
||||
value = ResultMetricCollection(value)
|
||||
value = _ResultMetricCollection(value)
|
||||
self[key] = value
|
||||
|
||||
def update_metrics(self, key: str, value: _METRIC_COLLECTION, batch_size: int) -> None:
|
||||
def fn(result_metric: ResultMetric, v: torch.Tensor) -> None:
|
||||
def fn(result_metric: _ResultMetric, v: torch.Tensor) -> None:
|
||||
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
|
||||
result_metric.forward(v.to(self.device), batch_size)
|
||||
result_metric.has_reset = False
|
||||
|
||||
apply_to_collections(self[key], value, ResultMetric, fn)
|
||||
apply_to_collections(self[key], value, _ResultMetric, fn)
|
||||
|
||||
@staticmethod
|
||||
def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
|
||||
def _get_cache(result_metric: _ResultMetric, on_step: bool) -> Optional[torch.Tensor]:
|
||||
cache = None
|
||||
if on_step and result_metric.meta.on_step:
|
||||
cache = result_metric._forward_cache
|
||||
|
@ -529,11 +529,11 @@ class ResultCollection(dict):
|
|||
return (
|
||||
(k, v)
|
||||
for k, v in self.items()
|
||||
if not (isinstance(v, ResultMetric) and v.has_reset)
|
||||
if not (isinstance(v, _ResultMetric) and v.has_reset)
|
||||
and self.dataloader_idx in (None, v.meta.dataloader_idx)
|
||||
)
|
||||
|
||||
def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
|
||||
def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> Tuple[str, str]:
|
||||
name = result_metric.meta.name
|
||||
forked_name = result_metric.meta.forked_name(on_step)
|
||||
add_dataloader_idx = result_metric.meta.add_dataloader_idx
|
||||
|
@ -549,11 +549,11 @@ class ResultCollection(dict):
|
|||
|
||||
for _, result_metric in self.valid_items():
|
||||
|
||||
# extract forward_cache or computed from the ResultMetric. ignore when the output is None
|
||||
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
|
||||
# extract forward_cache or computed from the _ResultMetric. ignore when the output is None
|
||||
value = apply_to_collection(result_metric, _ResultMetric, self._get_cache, on_step, include_none=False)
|
||||
|
||||
# convert metric collection to dict container.
|
||||
if isinstance(value, ResultMetricCollection):
|
||||
if isinstance(value, _ResultMetricCollection):
|
||||
value = dict(value.items())
|
||||
|
||||
# check if the collection is empty
|
||||
|
@ -594,15 +594,15 @@ class ResultCollection(dict):
|
|||
fx: Function to reset
|
||||
"""
|
||||
|
||||
def fn(item: ResultMetric) -> None:
|
||||
def fn(item: _ResultMetric) -> None:
|
||||
requested_type = metrics is None or metrics ^ item.is_tensor
|
||||
same_fx = fx is None or fx == item.meta.fx
|
||||
if requested_type and same_fx:
|
||||
item.reset()
|
||||
|
||||
apply_to_collection(self, ResultMetric, fn)
|
||||
apply_to_collection(self, _ResultMetric, fn)
|
||||
|
||||
def to(self, *args: Any, **kwargs: Any) -> "ResultCollection":
|
||||
def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
|
||||
"""Move all data to the given device."""
|
||||
self.update(apply_to_collection(dict(self), (torch.Tensor, Metric), move_data_to_device, *args, **kwargs))
|
||||
|
||||
|
@ -610,7 +610,7 @@ class ResultCollection(dict):
|
|||
self.device = kwargs["device"]
|
||||
return self
|
||||
|
||||
def cpu(self) -> "ResultCollection":
|
||||
def cpu(self) -> "_ResultCollection":
|
||||
"""Move all data to CPU."""
|
||||
return self.to(device="cpu")
|
||||
|
||||
|
@ -634,7 +634,7 @@ class ResultCollection(dict):
|
|||
|
||||
def __getstate__(self, drop_value: bool = True) -> dict:
|
||||
d = self.__dict__.copy()
|
||||
# all the items should be either `ResultMetric`s or `ResultMetricCollection`s
|
||||
# all the items should be either `_ResultMetric`s or `_ResultMetricCollection`s
|
||||
items = {k: v.__getstate__(drop_value=drop_value) for k, v in self.items()}
|
||||
return {**d, "items": items}
|
||||
|
||||
|
@ -643,14 +643,14 @@ class ResultCollection(dict):
|
|||
) -> None:
|
||||
self.__dict__.update({k: v for k, v in state.items() if k != "items"})
|
||||
|
||||
def setstate(k: str, item: dict) -> Union[ResultMetric, ResultMetricCollection]:
|
||||
def setstate(k: str, item: dict) -> Union[_ResultMetric, _ResultMetricCollection]:
|
||||
if not isinstance(item, dict):
|
||||
raise ValueError(f"Unexpected value: {item}")
|
||||
cls = item["_class"]
|
||||
if cls == ResultMetric.__name__:
|
||||
cls = ResultMetric
|
||||
elif cls == ResultMetricCollection.__name__:
|
||||
cls = ResultMetricCollection
|
||||
if cls == _ResultMetric.__name__:
|
||||
cls = _ResultMetric
|
||||
elif cls == _ResultMetricCollection.__name__:
|
||||
cls = _ResultMetricCollection
|
||||
else:
|
||||
raise ValueError(f"Unexpected class name: {cls}")
|
||||
_sync_fn = sync_fn or (self[k].meta.sync.fn if k in self else None)
|
||||
|
|
|
@ -65,7 +65,7 @@ from pytorch_lightning.trainer.connectors.callback_connector import CallbackConn
|
|||
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
|
||||
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
|
||||
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.signal_connector import SignalConnector
|
||||
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
||||
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
|
||||
|
@ -2242,7 +2242,7 @@ class Trainer(
|
|||
return self.logger_connector.progress_bar_metrics
|
||||
|
||||
@property
|
||||
def _results(self) -> Optional[ResultCollection]:
|
||||
def _results(self) -> Optional[_ResultCollection]:
|
||||
active_loop = self._active_loop
|
||||
if active_loop is not None:
|
||||
return active_loop._results
|
||||
|
|
|
@ -29,9 +29,9 @@ from pytorch_lightning import Trainer
|
|||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import (
|
||||
_Metadata,
|
||||
_ResultCollection,
|
||||
_ResultMetric,
|
||||
_Sync,
|
||||
ResultCollection,
|
||||
ResultMetric,
|
||||
)
|
||||
from pytorch_lightning.utilities.imports import _fault_tolerant_training
|
||||
from tests.helpers import BoringModel
|
||||
|
@ -71,7 +71,7 @@ def _ddp_test_fn(rank, worldsize):
|
|||
metric_b = metric_b.to(f"cuda:{rank}")
|
||||
metric_c = metric_c.to(f"cuda:{rank}")
|
||||
|
||||
result = ResultCollection(True, torch.device(f"cuda:{rank}"))
|
||||
result = _ResultCollection(True, torch.device(f"cuda:{rank}"))
|
||||
|
||||
for _ in range(3):
|
||||
cumulative_sum = 0
|
||||
|
@ -114,7 +114,7 @@ def test_result_metric_integration():
|
|||
metric_b = DummyMetric()
|
||||
metric_c = DummyMetric()
|
||||
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
result = _ResultCollection(True, torch.device("cpu"))
|
||||
|
||||
for _ in range(3):
|
||||
cumulative_sum = 0
|
||||
|
@ -145,26 +145,26 @@ def test_result_metric_integration():
|
|||
result.minimize = torch.tensor(1.0)
|
||||
result.extra = {}
|
||||
assert str(result) == (
|
||||
"ResultCollection("
|
||||
"_ResultCollection("
|
||||
"{"
|
||||
"'h.a': ResultMetric('a', value=DummyMetric()), "
|
||||
"'h.b': ResultMetric('b', value=DummyMetric()), "
|
||||
"'h.c': ResultMetric('c', value=DummyMetric())"
|
||||
"'h.a': _ResultMetric('a', value=DummyMetric()), "
|
||||
"'h.b': _ResultMetric('b', value=DummyMetric()), "
|
||||
"'h.c': _ResultMetric('c', value=DummyMetric())"
|
||||
"})"
|
||||
)
|
||||
assert repr(result) == (
|
||||
"{"
|
||||
"True, "
|
||||
"device(type='cpu'), "
|
||||
"{'h.a': ResultMetric('a', value=DummyMetric()), "
|
||||
"'h.b': ResultMetric('b', value=DummyMetric()), "
|
||||
"'h.c': ResultMetric('c', value=DummyMetric())"
|
||||
"{'h.a': _ResultMetric('a', value=DummyMetric()), "
|
||||
"'h.b': _ResultMetric('b', value=DummyMetric()), "
|
||||
"'h.c': _ResultMetric('c', value=DummyMetric())"
|
||||
"}}"
|
||||
)
|
||||
|
||||
|
||||
def test_result_collection_simple_loop():
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
result = _ResultCollection(True, torch.device("cpu"))
|
||||
current_fx_name = None
|
||||
batch_idx = None
|
||||
|
||||
|
@ -212,7 +212,7 @@ def my_sync_dist(x, *_, **__):
|
|||
def test_result_collection_restoration(tmpdir):
|
||||
"""This test make sure metrics are properly reloaded on failure."""
|
||||
|
||||
result = ResultCollection(True, torch.device("cpu"))
|
||||
result = _ResultCollection(True, torch.device("cpu"))
|
||||
metric_a = DummyMetric()
|
||||
metric_b = DummyMetric()
|
||||
metric_c = DummyMetric()
|
||||
|
@ -253,7 +253,7 @@ def test_result_collection_restoration(tmpdir):
|
|||
assert set(batch_log["c_1"]) == {"1", "2"}
|
||||
|
||||
result_copy = deepcopy(result)
|
||||
new_result = ResultCollection(True, torch.device("cpu"))
|
||||
new_result = _ResultCollection(True, torch.device("cpu"))
|
||||
state_dict = result.state_dict()
|
||||
# check the sync fn was dropped
|
||||
assert "fn" not in state_dict["items"]["training_step.a"]["meta"]["_sync"]
|
||||
|
@ -334,7 +334,7 @@ def test_lightning_module_logging_result_collection(tmpdir, device):
|
|||
assert results["validation_step.v"].meta.sync.fn == self.trainer.training_type_plugin.reduce
|
||||
|
||||
# default sync fn
|
||||
new_results = ResultCollection(False, device)
|
||||
new_results = _ResultCollection(False, device)
|
||||
new_results.load_state_dict(state_dict, map_location="cpu")
|
||||
assert new_results["validation_step.v"].meta.sync.fn is None
|
||||
|
||||
|
@ -373,7 +373,7 @@ class DummyMeanMetric(Metric):
|
|||
|
||||
def result_collection_reload(**kwargs):
|
||||
|
||||
"""This test is going to validate ResultCollection is properly being reload and final accumulation with Fault
|
||||
"""This test is going to validate _ResultCollection is properly being reload and final accumulation with Fault
|
||||
Tolerant Training is correct."""
|
||||
|
||||
if not _fault_tolerant_training():
|
||||
|
@ -551,10 +551,10 @@ def test_metric_result_computed_check():
|
|||
"""Unittest ``_get_cache`` with multielement tensors."""
|
||||
metadata = _Metadata("foo", "bar", on_epoch=True, enable_graph=True)
|
||||
metadata.sync = _Sync()
|
||||
rm = ResultMetric(metadata, is_tensor=True)
|
||||
rm = _ResultMetric(metadata, is_tensor=True)
|
||||
computed_value = torch.tensor([1, 2, 3])
|
||||
rm._computed = computed_value
|
||||
cache = ResultCollection._get_cache(rm, on_step=False)
|
||||
cache = _ResultCollection._get_cache(rm, on_step=False)
|
||||
# `enable_graph=True` so no detach, identity works
|
||||
assert cache is computed_value
|
||||
|
||||
|
@ -566,7 +566,7 @@ def test_metric_result_respects_dtype(floating_dtype):
|
|||
|
||||
metadata = _Metadata("foo", "bar")
|
||||
metadata.sync = _Sync()
|
||||
rm = ResultMetric(metadata, is_tensor=True)
|
||||
rm = _ResultMetric(metadata, is_tensor=True)
|
||||
|
||||
assert rm.value.dtype == floating_dtype
|
||||
assert rm.cumulated_batch_size.dtype == fixed_dtype
|
||||
|
|
|
@ -23,7 +23,7 @@ from pytorch_lightning import LightningModule
|
|||
from pytorch_lightning.callbacks.base import Callback
|
||||
from pytorch_lightning.trainer import Trainer
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers.boring_model import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
@ -267,19 +267,19 @@ def test_fx_validator_integration(tmpdir):
|
|||
with pytest.deprecated_call(match="on_test_dataloader` is deprecated in v1.5"):
|
||||
trainer.test(model, verbose=False)
|
||||
|
||||
not_supported.update({k: "ResultCollection` is not registered yet" for k in not_supported})
|
||||
not_supported.update({k: "result collection is not registered yet" for k in not_supported})
|
||||
not_supported.update(
|
||||
{
|
||||
"on_predict_dataloader": "ResultCollection` is not registered yet",
|
||||
"predict_dataloader": "ResultCollection` is not registered yet",
|
||||
"on_predict_model_eval": "ResultCollection` is not registered yet",
|
||||
"on_predict_start": "ResultCollection` is not registered yet",
|
||||
"on_predict_epoch_start": "ResultCollection` is not registered yet",
|
||||
"on_predict_batch_start": "ResultCollection` is not registered yet",
|
||||
"predict_step": "ResultCollection` is not registered yet",
|
||||
"on_predict_batch_end": "ResultCollection` is not registered yet",
|
||||
"on_predict_epoch_end": "ResultCollection` is not registered yet",
|
||||
"on_predict_end": "ResultCollection` is not registered yet",
|
||||
"on_predict_dataloader": "result collection is not registered yet",
|
||||
"predict_dataloader": "result collection is not registered yet",
|
||||
"on_predict_model_eval": "result collection is not registered yet",
|
||||
"on_predict_start": "result collection is not registered yet",
|
||||
"on_predict_epoch_start": "result collection is not registered yet",
|
||||
"on_predict_batch_start": "result collection is not registered yet",
|
||||
"predict_step": "result collection is not registered yet",
|
||||
"on_predict_batch_end": "result collection is not registered yet",
|
||||
"on_predict_epoch_end": "result collection is not registered yet",
|
||||
"on_predict_end": "result collection is not registered yet",
|
||||
}
|
||||
)
|
||||
with pytest.deprecated_call(match="on_predict_dataloader` is deprecated in v1.5"):
|
||||
|
@ -531,7 +531,7 @@ def test_metrics_reset(tmpdir):
|
|||
|
||||
|
||||
def test_result_collection_on_tensor_with_mean_reduction():
|
||||
result_collection = ResultCollection(True)
|
||||
result_collection = _ResultCollection(True)
|
||||
product = [(True, True), (False, True), (True, False), (False, False)]
|
||||
values = torch.arange(1, 10)
|
||||
batches = values * values
|
||||
|
@ -647,7 +647,7 @@ def test_result_collection_batch_size_extraction():
|
|||
fx_name = "training_step"
|
||||
log_val = torch.tensor(7.0)
|
||||
|
||||
results = ResultCollection(training=True, device="cpu")
|
||||
results = _ResultCollection(training=True, device="cpu")
|
||||
results.batch = torch.randn(1, 4)
|
||||
train_mse = MeanSquaredError()
|
||||
train_mse(torch.randn(4, 5), torch.randn(4, 5))
|
||||
|
@ -656,7 +656,7 @@ def test_result_collection_batch_size_extraction():
|
|||
assert isinstance(results["training_step.train_logs"]["mse"].value, MeanSquaredError)
|
||||
assert results["training_step.train_logs"]["log_val"].value == log_val
|
||||
|
||||
results = ResultCollection(training=True, device="cpu")
|
||||
results = _ResultCollection(training=True, device="cpu")
|
||||
results.batch = torch.randn(1, 4)
|
||||
results.log(fx_name, "train_log", log_val, on_step=False, on_epoch=True)
|
||||
assert results.batch_size == 1
|
||||
|
@ -665,7 +665,7 @@ def test_result_collection_batch_size_extraction():
|
|||
|
||||
|
||||
def test_result_collection_no_batch_size_extraction():
|
||||
results = ResultCollection(training=True, device="cpu")
|
||||
results = _ResultCollection(training=True, device="cpu")
|
||||
results.batch = torch.randn(1, 4)
|
||||
fx_name = "training_step"
|
||||
batch_size = 10
|
||||
|
|
|
@ -20,7 +20,7 @@ import torch
|
|||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.fx_validator import _FxValidator
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
|
||||
from pytorch_lightning.trainer.states import RunningStage, TrainerFn
|
||||
from tests.helpers.boring_model import BoringModel
|
||||
|
||||
|
@ -39,13 +39,13 @@ def test_default_level_for_hooks_that_support_logging():
|
|||
model.trainer = trainer
|
||||
extra_kwargs = {
|
||||
k: ANY
|
||||
for k in inspect.signature(ResultCollection.log).parameters
|
||||
for k in inspect.signature(_ResultCollection.log).parameters
|
||||
if k not in ["self", "fx", "name", "value", "on_step", "on_epoch"]
|
||||
}
|
||||
all_logging_hooks = {k for k in _FxValidator.functions if _FxValidator.functions[k]}
|
||||
|
||||
with mock.patch(
|
||||
"pytorch_lightning.trainer.connectors.logger_connector.result.ResultCollection.log", return_value=None
|
||||
"pytorch_lightning.trainer.connectors.logger_connector.result._ResultCollection.log", return_value=None
|
||||
) as result_mock:
|
||||
trainer.state.stage = RunningStage.TRAINING
|
||||
hooks = [
|
||||
|
|
Loading…
Reference in New Issue