External callback registry through entry points for Fabric (#17756)

This commit is contained in:
M. Fox 2023-06-06 13:53:19 +02:00 committed by GitHub
parent d23c772f3c
commit e2986fab14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 47 deletions

View File

@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
### Changed
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

View File

@ -45,6 +45,7 @@ from lightning.fabric.utilities.data import (
has_iterable_dataset,
)
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning
@ -111,8 +112,7 @@ class Fabric:
self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator
self._precision: Precision = self._strategy.precision
callbacks = callbacks if callbacks is not None else []
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
self._callbacks = self._configure_callbacks(callbacks)
loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0
@ -908,6 +908,13 @@ class Fabric:
if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
@staticmethod
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
    """Normalize the user-supplied callbacks and append those registered through entry points.

    Args:
        callbacks: ``None``, a single callback object, or a list of callback objects.

    Return:
        A new list holding the user callbacks followed by everything returned by
        ``_load_external_callbacks("lightning.fabric.callbacks_factory")``.
    """
    if callbacks is None:
        configured: List[Any] = []
    elif isinstance(callbacks, list):
        # Copy so that the ``extend`` below never mutates the list object owned by the caller;
        # otherwise reusing one list for several instances would accumulate external callbacks.
        configured = list(callbacks)
    else:
        configured = [callbacks]
    configured.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
    return configured
@contextmanager
def _old_sharded_model_context(strategy: Strategy) -> Generator:

View File

@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

View File

@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any
import logging
from typing import Any, List, Union
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
_log = logging.getLogger(__name__)
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
return False
return mod_attr.__code__ is not super_attr.__code__
def _load_external_callbacks(group: str) -> List[Any]:
    """Gather the callbacks that external packages expose through an entry point group.

    Every entry point in the group is loaded and called without arguments. A result that is
    not a ``list`` is treated as a single callback and wrapped into a one-element list.

    Args:
        group: The entry point group name to load callbacks from.

    Return:
        A flat list with the callbacks produced by all factories found in the group.
    """
    if not _PYTHON_GREATER_EQUAL_3_8_0:
        # Legacy interpreters: fall back to the setuptools API.
        from pkg_resources import iter_entry_points

        factories = iter_entry_points(group)  # type: ignore[assignment]
    else:
        from importlib.metadata import entry_points

        if _PYTHON_GREATER_EQUAL_3_10_0:
            # The selection keyword is used on 3.10+ ...
            factories = entry_points(group=group)
        else:
            # ... while on 3.8/3.9 the result is queried like a mapping keyed by group.
            factories = entry_points().get(group, {})  # type: ignore[arg-type]

    collected: List[Any] = []
    for entry in factories:
        make_callbacks = entry.load()
        produced = make_callbacks()
        if not isinstance(produced, list):
            produced = [produced]
        _log.info(
            f"Adding {len(produced)} callbacks from entry point '{entry.name}':"
            f" {', '.join(type(cb).__name__ for cb in produced)}"
        )
        collected.extend(produced)
    return collected

View File

@ -18,6 +18,7 @@ from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union
import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import (
Callback,
Checkpoint,
@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info
@ -75,7 +75,7 @@ class _CallbackConnector:
# configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary)
self.trainer.callbacks.extend(_configure_external_callbacks())
self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks)
# push all model checkpoint callbacks to the end
@ -213,42 +213,6 @@ class _CallbackConnector:
return tuner_callbacks + other_callbacks + checkpoint_callbacks
def _configure_external_callbacks() -> List[Callback]:
    """Gather the callbacks that external packages expose through entry points.

    Every entry point in the ``lightning.pytorch.callbacks_factory`` group is loaded and called
    without arguments. A result that is a single ``Callback`` is wrapped into a one-element list;
    anything else is used as returned.

    Return:
        A flat list with the callbacks produced by all factories found in the group.
    """
    group = "lightning.pytorch.callbacks_factory"

    if not _PYTHON_GREATER_EQUAL_3_8_0:
        # Legacy interpreters: fall back to the setuptools API.
        from pkg_resources import iter_entry_points

        factories = iter_entry_points(group)  # type: ignore[assignment]
    else:
        from importlib.metadata import entry_points

        if _PYTHON_GREATER_EQUAL_3_10_0:
            factories = entry_points(group=group)
        else:
            factories = entry_points().get(group, {})  # type: ignore[arg-type]

    collected: List[Callback] = []
    for entry in factories:
        make_callbacks = entry.load()
        produced: Union[List[Callback], Callback] = make_callbacks()
        if isinstance(produced, Callback):
            produced = [produced]
        _log.info(
            f"Adding {len(produced)} callbacks from entry point '{entry.name}':"
            f" {', '.join(type(cb).__name__ for cb in produced)}"
        )
        collected.extend(produced)
    return collected
def _validate_callbacks_list(callbacks: List[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
seen_callbacks = set()

View File

@ -11,8 +11,7 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message
import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.rank_zero import rank_zero_info
# copied from signal.pyi

View File

@ -17,8 +17,6 @@ import sys
import torch
from lightning_utilities.core.imports import package_available, RequirementCache
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task

View File

@ -0,0 +1,64 @@
import contextlib
from unittest import mock
from unittest.mock import Mock
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks
class ExternalCallback:
    """Stand-in for a callback that a third-party library registers through entry points."""
def test_load_external_callbacks():
    """Check that callbacks returned by entry point factories are collected.

    Covers a factory returning an empty list, a bare callback, a one-element list and a
    two-element list.
    """
    group = "lightning.pytorch.callbacks_factory"

    def load_with(factory):
        # Run the loader while the entry point lookup is patched to expose only ``factory``.
        with _make_entry_point_query_mock(factory):
            return _load_external_callbacks(group)

    nothing = load_with(lambda: [])
    assert nothing == []

    from_bare_instance = load_with(lambda: ExternalCallback())
    assert isinstance(from_bare_instance[0], ExternalCallback)

    from_single_list = load_with(lambda: [ExternalCallback()])
    assert isinstance(from_single_list[0], ExternalCallback)

    from_pair = load_with(lambda: [ExternalCallback(), ExternalCallback()])
    assert isinstance(from_pair[0], ExternalCallback)
    assert isinstance(from_pair[1], ExternalCallback)
@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
    """Patch the entry point lookup so it yields one fake entry point named ``"mocked"``.

    The fake entry point's ``load()`` returns ``callback_factory``. Which function is patched,
    and how the fake is exposed, depends on the Python version flags.
    """
    fake_entry = Mock()
    fake_entry.name = "mocked"
    fake_entry.load.return_value = callback_factory

    lookup = Mock()
    uses_importlib = _PYTHON_GREATER_EQUAL_3_10_0 or _PYTHON_GREATER_EQUAL_3_8_0
    if uses_importlib and not _PYTHON_GREATER_EQUAL_3_10_0:
        # 3.8/3.9: the loader calls ``entry_points().get(group, {})``.
        lookup().get.return_value = [fake_entry]
    else:
        # 3.10+ (``entry_points(group=...)``) and the pkg_resources fallback both use the
        # direct return value of the patched callable.
        lookup.return_value = [fake_entry]

    if uses_importlib:
        target = "importlib.metadata.entry_points"
    else:
        target = "pkg_resources.iter_entry_points"

    with mock.patch(target, lookup):
        yield

View File

@ -19,6 +19,7 @@ from unittest.mock import Mock
import pytest
import torch
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import (
EarlyStopping,
@ -32,7 +33,6 @@ from lightning.pytorch.callbacks import (
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
def test_checkpoint_callbacks_are_last(tmpdir):

View File

@ -25,6 +25,7 @@ import pytest
import torch
from torch import Tensor
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import callbacks, Trainer
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
@ -32,7 +33,6 @@ from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loops import _EvaluationLoop
from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests_pytorch.helpers.runif import RunIf
if _RICH_AVAILABLE: