External callback registry through entry points for Fabric (#17756)
This commit is contained in:
parent
d23c772f3c
commit
e2986fab14
|
@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
||||||
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
|
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
|
||||||
|
|
||||||
|
|
||||||
|
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
|
||||||
|
|
||||||
|
|
||||||
### Changed
|
### Changed
|
||||||
|
|
||||||
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
|
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
|
||||||
|
|
|
@ -45,6 +45,7 @@ from lightning.fabric.utilities.data import (
|
||||||
has_iterable_dataset,
|
has_iterable_dataset,
|
||||||
)
|
)
|
||||||
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
|
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
|
||||||
|
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||||
from lightning.fabric.utilities.seed import seed_everything
|
from lightning.fabric.utilities.seed import seed_everything
|
||||||
from lightning.fabric.utilities.types import ReduceOp
|
from lightning.fabric.utilities.types import ReduceOp
|
||||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||||
|
@ -111,8 +112,7 @@ class Fabric:
|
||||||
self._strategy: Strategy = self._connector.strategy
|
self._strategy: Strategy = self._connector.strategy
|
||||||
self._accelerator: Accelerator = self._connector.accelerator
|
self._accelerator: Accelerator = self._connector.accelerator
|
||||||
self._precision: Precision = self._strategy.precision
|
self._precision: Precision = self._strategy.precision
|
||||||
callbacks = callbacks if callbacks is not None else []
|
self._callbacks = self._configure_callbacks(callbacks)
|
||||||
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
|
|
||||||
loggers = loggers if loggers is not None else []
|
loggers = loggers if loggers is not None else []
|
||||||
self._loggers = loggers if isinstance(loggers, list) else [loggers]
|
self._loggers = loggers if isinstance(loggers, list) else [loggers]
|
||||||
self._models_setup: int = 0
|
self._models_setup: int = 0
|
||||||
|
@ -908,6 +908,13 @@ class Fabric:
|
||||||
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
|
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
|
||||||
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
|
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
|
||||||
|
callbacks = callbacks if callbacks is not None else []
|
||||||
|
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
|
||||||
|
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
|
||||||
|
return callbacks
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _old_sharded_model_context(strategy: Strategy) -> Generator:
|
def _old_sharded_model_context(strategy: Strategy) -> Generator:
|
||||||
|
|
|
@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
|
||||||
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
|
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
|
||||||
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
|
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
|
||||||
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
|
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
|
||||||
|
|
||||||
|
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
||||||
|
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||||
|
|
|
@ -12,7 +12,12 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import inspect
|
import inspect
|
||||||
from typing import Any
|
import logging
|
||||||
|
from typing import Any, List, Union
|
||||||
|
|
||||||
|
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||||
|
|
||||||
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
|
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
|
||||||
|
@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return mod_attr.__code__ is not super_attr.__code__
|
return mod_attr.__code__ is not super_attr.__code__
|
||||||
|
|
||||||
|
|
||||||
|
def _load_external_callbacks(group: str) -> List[Any]:
|
||||||
|
"""Collect external callbacks registered through entry points.
|
||||||
|
|
||||||
|
The entry points are expected to be functions returning a list of callbacks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
group: The entry point group name to load callbacks from.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A list of all callbacks collected from external factories.
|
||||||
|
"""
|
||||||
|
if _PYTHON_GREATER_EQUAL_3_8_0:
|
||||||
|
from importlib.metadata import entry_points
|
||||||
|
|
||||||
|
factories = (
|
||||||
|
entry_points(group=group)
|
||||||
|
if _PYTHON_GREATER_EQUAL_3_10_0
|
||||||
|
else entry_points().get(group, {}) # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
from pkg_resources import iter_entry_points
|
||||||
|
|
||||||
|
factories = iter_entry_points(group) # type: ignore[assignment]
|
||||||
|
|
||||||
|
external_callbacks: List[Any] = []
|
||||||
|
for factory in factories:
|
||||||
|
callback_factory = factory.load()
|
||||||
|
callbacks_list: Union[List[Any], Any] = callback_factory()
|
||||||
|
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
|
||||||
|
_log.info(
|
||||||
|
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
|
||||||
|
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
|
||||||
|
)
|
||||||
|
external_callbacks.extend(callbacks_list)
|
||||||
|
return external_callbacks
|
||||||
|
|
|
@ -18,6 +18,7 @@ from datetime import timedelta
|
||||||
from typing import Dict, List, Optional, Sequence, Union
|
from typing import Dict, List, Optional, Sequence, Union
|
||||||
|
|
||||||
import lightning.pytorch as pl
|
import lightning.pytorch as pl
|
||||||
|
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||||
from lightning.pytorch.callbacks import (
|
from lightning.pytorch.callbacks import (
|
||||||
Callback,
|
Callback,
|
||||||
Checkpoint,
|
Checkpoint,
|
||||||
|
@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
|
||||||
from lightning.pytorch.callbacks.timer import Timer
|
from lightning.pytorch.callbacks.timer import Timer
|
||||||
from lightning.pytorch.trainer import call
|
from lightning.pytorch.trainer import call
|
||||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
|
||||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||||
|
|
||||||
|
@ -75,7 +75,7 @@ class _CallbackConnector:
|
||||||
# configure the ModelSummary callback
|
# configure the ModelSummary callback
|
||||||
self._configure_model_summary_callback(enable_model_summary)
|
self._configure_model_summary_callback(enable_model_summary)
|
||||||
|
|
||||||
self.trainer.callbacks.extend(_configure_external_callbacks())
|
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
|
||||||
_validate_callbacks_list(self.trainer.callbacks)
|
_validate_callbacks_list(self.trainer.callbacks)
|
||||||
|
|
||||||
# push all model checkpoint callbacks to the end
|
# push all model checkpoint callbacks to the end
|
||||||
|
@ -213,42 +213,6 @@ class _CallbackConnector:
|
||||||
return tuner_callbacks + other_callbacks + checkpoint_callbacks
|
return tuner_callbacks + other_callbacks + checkpoint_callbacks
|
||||||
|
|
||||||
|
|
||||||
def _configure_external_callbacks() -> List[Callback]:
|
|
||||||
"""Collect external callbacks registered through entry points.
|
|
||||||
|
|
||||||
The entry points are expected to be functions returning a list of callbacks.
|
|
||||||
|
|
||||||
Return:
|
|
||||||
A list of all callbacks collected from external factories.
|
|
||||||
"""
|
|
||||||
group = "lightning.pytorch.callbacks_factory"
|
|
||||||
|
|
||||||
if _PYTHON_GREATER_EQUAL_3_8_0:
|
|
||||||
from importlib.metadata import entry_points
|
|
||||||
|
|
||||||
factories = (
|
|
||||||
entry_points(group=group)
|
|
||||||
if _PYTHON_GREATER_EQUAL_3_10_0
|
|
||||||
else entry_points().get(group, {}) # type: ignore[arg-type]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from pkg_resources import iter_entry_points
|
|
||||||
|
|
||||||
factories = iter_entry_points(group) # type: ignore[assignment]
|
|
||||||
|
|
||||||
external_callbacks: List[Callback] = []
|
|
||||||
for factory in factories:
|
|
||||||
callback_factory = factory.load()
|
|
||||||
callbacks_list: Union[List[Callback], Callback] = callback_factory()
|
|
||||||
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
|
|
||||||
_log.info(
|
|
||||||
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
|
|
||||||
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
|
|
||||||
)
|
|
||||||
external_callbacks.extend(callbacks_list)
|
|
||||||
return external_callbacks
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
|
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
|
||||||
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
|
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
|
||||||
seen_callbacks = set()
|
seen_callbacks = set()
|
||||||
|
|
|
@ -11,8 +11,7 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message
|
||||||
|
|
||||||
import lightning.pytorch as pl
|
import lightning.pytorch as pl
|
||||||
from lightning.fabric.plugins.environments import SLURMEnvironment
|
from lightning.fabric.plugins.environments import SLURMEnvironment
|
||||||
from lightning.fabric.utilities.imports import _IS_WINDOWS
|
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
|
||||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
|
||||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||||
|
|
||||||
# copied from signal.pyi
|
# copied from signal.pyi
|
||||||
|
|
|
@ -17,8 +17,6 @@ import sys
|
||||||
import torch
|
import torch
|
||||||
from lightning_utilities.core.imports import package_available, RequirementCache
|
from lightning_utilities.core.imports import package_available, RequirementCache
|
||||||
|
|
||||||
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
|
||||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
|
||||||
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
|
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
|
||||||
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
|
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
|
||||||
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
|
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
import contextlib
|
||||||
|
from unittest import mock
|
||||||
|
from unittest.mock import Mock
|
||||||
|
|
||||||
|
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||||
|
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalCallback:
|
||||||
|
"""A callback in another library that gets registered through entry points."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_external_callbacks():
|
||||||
|
"""Test that the connector collects Callback instances from factories registered through entry points."""
|
||||||
|
|
||||||
|
def factory_no_callback():
|
||||||
|
return []
|
||||||
|
|
||||||
|
def factory_one_callback():
|
||||||
|
return ExternalCallback()
|
||||||
|
|
||||||
|
def factory_one_callback_list():
|
||||||
|
return [ExternalCallback()]
|
||||||
|
|
||||||
|
def factory_multiple_callbacks_list():
|
||||||
|
return [ExternalCallback(), ExternalCallback()]
|
||||||
|
|
||||||
|
with _make_entry_point_query_mock(factory_no_callback):
|
||||||
|
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||||
|
assert callbacks == []
|
||||||
|
|
||||||
|
with _make_entry_point_query_mock(factory_one_callback):
|
||||||
|
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||||
|
assert isinstance(callbacks[0], ExternalCallback)
|
||||||
|
|
||||||
|
with _make_entry_point_query_mock(factory_one_callback_list):
|
||||||
|
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||||
|
assert isinstance(callbacks[0], ExternalCallback)
|
||||||
|
|
||||||
|
with _make_entry_point_query_mock(factory_multiple_callbacks_list):
|
||||||
|
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||||
|
assert isinstance(callbacks[0], ExternalCallback)
|
||||||
|
assert isinstance(callbacks[1], ExternalCallback)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
|
||||||
|
def _make_entry_point_query_mock(callback_factory):
|
||||||
|
query_mock = Mock()
|
||||||
|
entry_point = Mock()
|
||||||
|
entry_point.name = "mocked"
|
||||||
|
entry_point.load.return_value = callback_factory
|
||||||
|
if _PYTHON_GREATER_EQUAL_3_10_0:
|
||||||
|
query_mock.return_value = [entry_point]
|
||||||
|
import_path = "importlib.metadata.entry_points"
|
||||||
|
elif _PYTHON_GREATER_EQUAL_3_8_0:
|
||||||
|
query_mock().get.return_value = [entry_point]
|
||||||
|
import_path = "importlib.metadata.entry_points"
|
||||||
|
else:
|
||||||
|
query_mock.return_value = [entry_point]
|
||||||
|
import_path = "pkg_resources.iter_entry_points"
|
||||||
|
with mock.patch(import_path, query_mock):
|
||||||
|
yield
|
|
@ -19,6 +19,7 @@ from unittest.mock import Mock
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||||
from lightning.pytorch import Callback, LightningModule, Trainer
|
from lightning.pytorch import Callback, LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import (
|
from lightning.pytorch.callbacks import (
|
||||||
EarlyStopping,
|
EarlyStopping,
|
||||||
|
@ -32,7 +33,6 @@ from lightning.pytorch.callbacks import (
|
||||||
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
|
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
|
||||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||||
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
|
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
|
||||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_callbacks_are_last(tmpdir):
|
def test_checkpoint_callbacks_are_last(tmpdir):
|
||||||
|
|
|
@ -25,6 +25,7 @@ import pytest
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
||||||
from lightning.pytorch import callbacks, Trainer
|
from lightning.pytorch import callbacks, Trainer
|
||||||
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
|
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
|
||||||
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
||||||
|
@ -32,7 +33,6 @@ from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
from lightning.pytorch.loops import _EvaluationLoop
|
from lightning.pytorch.loops import _EvaluationLoop
|
||||||
from lightning.pytorch.trainer.states import RunningStage
|
from lightning.pytorch.trainer.states import RunningStage
|
||||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
|
||||||
from tests_pytorch.helpers.runif import RunIf
|
from tests_pytorch.helpers.runif import RunIf
|
||||||
|
|
||||||
if _RICH_AVAILABLE:
|
if _RICH_AVAILABLE:
|
||||||
|
|
Loading…
Reference in New Issue