External callback registry through entry points for Fabric (#17756)
This commit is contained in:
parent
d23c772f3c
commit
e2986fab14
|
@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
|
||||
|
||||
|
||||
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
|
||||
|
||||
|
||||
### Changed
|
||||
|
||||
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))
|
||||
|
|
|
@ -45,6 +45,7 @@ from lightning.fabric.utilities.data import (
|
|||
has_iterable_dataset,
|
||||
)
|
||||
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
|
||||
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||
from lightning.fabric.utilities.seed import seed_everything
|
||||
from lightning.fabric.utilities.types import ReduceOp
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
|
@ -111,8 +112,7 @@ class Fabric:
|
|||
self._strategy: Strategy = self._connector.strategy
|
||||
self._accelerator: Accelerator = self._connector.accelerator
|
||||
self._precision: Precision = self._strategy.precision
|
||||
callbacks = callbacks if callbacks is not None else []
|
||||
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
|
||||
self._callbacks = self._configure_callbacks(callbacks)
|
||||
loggers = loggers if loggers is not None else []
|
||||
self._loggers = loggers if isinstance(loggers, list) else [loggers]
|
||||
self._models_setup: int = 0
|
||||
|
@ -908,6 +908,13 @@ class Fabric:
|
|||
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
|
||||
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
|
||||
|
||||
@staticmethod
|
||||
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
|
||||
callbacks = callbacks if callbacks is not None else []
|
||||
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
|
||||
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
|
||||
return callbacks
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _old_sharded_model_context(strategy: Strategy) -> Generator:
|
||||
|
|
|
@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
|
|||
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
|
||||
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
|
||||
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
|
||||
|
||||
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||
|
|
|
@ -12,7 +12,12 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
from typing import Any
|
||||
import logging
|
||||
from typing import Any, List, Union
|
||||
|
||||
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
|
||||
|
@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
|
|||
return False
|
||||
|
||||
return mod_attr.__code__ is not super_attr.__code__
|
||||
|
||||
|
||||
def _load_external_callbacks(group: str) -> List[Any]:
|
||||
"""Collect external callbacks registered through entry points.
|
||||
|
||||
The entry points are expected to be functions returning a list of callbacks.
|
||||
|
||||
Args:
|
||||
group: The entry point group name to load callbacks from.
|
||||
|
||||
Return:
|
||||
A list of all callbacks collected from external factories.
|
||||
"""
|
||||
if _PYTHON_GREATER_EQUAL_3_8_0:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
factories = (
|
||||
entry_points(group=group)
|
||||
if _PYTHON_GREATER_EQUAL_3_10_0
|
||||
else entry_points().get(group, {}) # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
from pkg_resources import iter_entry_points
|
||||
|
||||
factories = iter_entry_points(group) # type: ignore[assignment]
|
||||
|
||||
external_callbacks: List[Any] = []
|
||||
for factory in factories:
|
||||
callback_factory = factory.load()
|
||||
callbacks_list: Union[List[Any], Any] = callback_factory()
|
||||
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
|
||||
_log.info(
|
||||
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
|
||||
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
|
||||
)
|
||||
external_callbacks.extend(callbacks_list)
|
||||
return external_callbacks
|
||||
|
|
|
@ -18,6 +18,7 @@ from datetime import timedelta
|
|||
from typing import Dict, List, Optional, Sequence, Union
|
||||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||
from lightning.pytorch.callbacks import (
|
||||
Callback,
|
||||
Checkpoint,
|
||||
|
@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
|
|||
from lightning.pytorch.callbacks.timer import Timer
|
||||
from lightning.pytorch.trainer import call
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||
from lightning.pytorch.utilities.model_helpers import is_overridden
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||
|
||||
|
@ -75,7 +75,7 @@ class _CallbackConnector:
|
|||
# configure the ModelSummary callback
|
||||
self._configure_model_summary_callback(enable_model_summary)
|
||||
|
||||
self.trainer.callbacks.extend(_configure_external_callbacks())
|
||||
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
|
||||
_validate_callbacks_list(self.trainer.callbacks)
|
||||
|
||||
# push all model checkpoint callbacks to the end
|
||||
|
@ -213,42 +213,6 @@ class _CallbackConnector:
|
|||
return tuner_callbacks + other_callbacks + checkpoint_callbacks
|
||||
|
||||
|
||||
def _configure_external_callbacks() -> List[Callback]:
|
||||
"""Collect external callbacks registered through entry points.
|
||||
|
||||
The entry points are expected to be functions returning a list of callbacks.
|
||||
|
||||
Return:
|
||||
A list of all callbacks collected from external factories.
|
||||
"""
|
||||
group = "lightning.pytorch.callbacks_factory"
|
||||
|
||||
if _PYTHON_GREATER_EQUAL_3_8_0:
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
factories = (
|
||||
entry_points(group=group)
|
||||
if _PYTHON_GREATER_EQUAL_3_10_0
|
||||
else entry_points().get(group, {}) # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
from pkg_resources import iter_entry_points
|
||||
|
||||
factories = iter_entry_points(group) # type: ignore[assignment]
|
||||
|
||||
external_callbacks: List[Callback] = []
|
||||
for factory in factories:
|
||||
callback_factory = factory.load()
|
||||
callbacks_list: Union[List[Callback], Callback] = callback_factory()
|
||||
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
|
||||
_log.info(
|
||||
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
|
||||
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
|
||||
)
|
||||
external_callbacks.extend(callbacks_list)
|
||||
return external_callbacks
|
||||
|
||||
|
||||
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
|
||||
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
|
||||
seen_callbacks = set()
|
||||
|
|
|
@ -11,8 +11,7 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message
|
|||
|
||||
import lightning.pytorch as pl
|
||||
from lightning.fabric.plugins.environments import SLURMEnvironment
|
||||
from lightning.fabric.utilities.imports import _IS_WINDOWS
|
||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
||||
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
|
||||
from lightning.pytorch.utilities.rank_zero import rank_zero_info
|
||||
|
||||
# copied from signal.pyi
|
||||
|
|
|
@ -17,8 +17,6 @@ import sys
|
|||
import torch
|
||||
from lightning_utilities.core.imports import package_available, RequirementCache
|
||||
|
||||
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
|
||||
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
|
||||
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
|
||||
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
|
||||
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
import contextlib
|
||||
from unittest import mock
|
||||
from unittest.mock import Mock
|
||||
|
||||
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||
from lightning.fabric.utilities.registry import _load_external_callbacks
|
||||
|
||||
|
||||
class ExternalCallback:
|
||||
"""A callback in another library that gets registered through entry points."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def test_load_external_callbacks():
|
||||
"""Test that the connector collects Callback instances from factories registered through entry points."""
|
||||
|
||||
def factory_no_callback():
|
||||
return []
|
||||
|
||||
def factory_one_callback():
|
||||
return ExternalCallback()
|
||||
|
||||
def factory_one_callback_list():
|
||||
return [ExternalCallback()]
|
||||
|
||||
def factory_multiple_callbacks_list():
|
||||
return [ExternalCallback(), ExternalCallback()]
|
||||
|
||||
with _make_entry_point_query_mock(factory_no_callback):
|
||||
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||
assert callbacks == []
|
||||
|
||||
with _make_entry_point_query_mock(factory_one_callback):
|
||||
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||
assert isinstance(callbacks[0], ExternalCallback)
|
||||
|
||||
with _make_entry_point_query_mock(factory_one_callback_list):
|
||||
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||
assert isinstance(callbacks[0], ExternalCallback)
|
||||
|
||||
with _make_entry_point_query_mock(factory_multiple_callbacks_list):
|
||||
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
|
||||
assert isinstance(callbacks[0], ExternalCallback)
|
||||
assert isinstance(callbacks[1], ExternalCallback)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _make_entry_point_query_mock(callback_factory):
|
||||
query_mock = Mock()
|
||||
entry_point = Mock()
|
||||
entry_point.name = "mocked"
|
||||
entry_point.load.return_value = callback_factory
|
||||
if _PYTHON_GREATER_EQUAL_3_10_0:
|
||||
query_mock.return_value = [entry_point]
|
||||
import_path = "importlib.metadata.entry_points"
|
||||
elif _PYTHON_GREATER_EQUAL_3_8_0:
|
||||
query_mock().get.return_value = [entry_point]
|
||||
import_path = "importlib.metadata.entry_points"
|
||||
else:
|
||||
query_mock.return_value = [entry_point]
|
||||
import_path = "pkg_resources.iter_entry_points"
|
||||
with mock.patch(import_path, query_mock):
|
||||
yield
|
|
@ -19,6 +19,7 @@ from unittest.mock import Mock
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||
from lightning.pytorch import Callback, LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import (
|
||||
EarlyStopping,
|
||||
|
@ -32,7 +33,6 @@ from lightning.pytorch.callbacks import (
|
|||
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
|
||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
|
||||
|
||||
|
||||
def test_checkpoint_callbacks_are_last(tmpdir):
|
||||
|
|
|
@ -25,6 +25,7 @@ import pytest
|
|||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
||||
from lightning.pytorch import callbacks, Trainer
|
||||
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
|
||||
|
@ -32,7 +33,6 @@ from lightning.pytorch.loggers import TensorBoardLogger
|
|||
from lightning.pytorch.loops import _EvaluationLoop
|
||||
from lightning.pytorch.trainer.states import RunningStage
|
||||
from lightning.pytorch.utilities.exceptions import MisconfigurationException
|
||||
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
|
||||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
||||
if _RICH_AVAILABLE:
|
||||
|
|
Loading…
Reference in New Issue