External callback registry through entry points for Fabric (#17756)

This commit is contained in:
M. Fox 2023-06-06 13:53:19 +02:00 committed by GitHub
parent d23c772f3c
commit e2986fab14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 127 additions and 47 deletions

View File

@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623)) - Added support for loading a full-state checkpoint file into a sharded model ([#17623](https://github.com/Lightning-AI/lightning/pull/17623))
- Added support for `Callback` registration through entry points ([#17756](https://github.com/Lightning-AI/lightning/pull/17756))
### Changed ### Changed
- Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331)) - Allow using iterable-style datasets with TPUs ([#17331](https://github.com/Lightning-AI/lightning/pull/17331))

View File

@ -45,6 +45,7 @@ from lightning.fabric.utilities.data import (
has_iterable_dataset, has_iterable_dataset,
) )
from lightning.fabric.utilities.distributed import DistributedSamplerWrapper from lightning.fabric.utilities.distributed import DistributedSamplerWrapper
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.fabric.utilities.seed import seed_everything from lightning.fabric.utilities.seed import seed_everything
from lightning.fabric.utilities.types import ReduceOp from lightning.fabric.utilities.types import ReduceOp
from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.fabric.utilities.warnings import PossibleUserWarning
@ -111,8 +112,7 @@ class Fabric:
self._strategy: Strategy = self._connector.strategy self._strategy: Strategy = self._connector.strategy
self._accelerator: Accelerator = self._connector.accelerator self._accelerator: Accelerator = self._connector.accelerator
self._precision: Precision = self._strategy.precision self._precision: Precision = self._strategy.precision
callbacks = callbacks if callbacks is not None else [] self._callbacks = self._configure_callbacks(callbacks)
self._callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
loggers = loggers if loggers is not None else [] loggers = loggers if loggers is not None else []
self._loggers = loggers if isinstance(loggers, list) else [loggers] self._loggers = loggers if isinstance(loggers, list) else [loggers]
self._models_setup: int = 0 self._models_setup: int = 0
@ -908,6 +908,13 @@ class Fabric:
if any(not isinstance(dl, DataLoader) for dl in dataloaders): if any(not isinstance(dl, DataLoader) for dl in dataloaders):
raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.") raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")
@staticmethod
def _configure_callbacks(callbacks: Optional[Union[List[Any], Any]]) -> List[Any]:
callbacks = callbacks if callbacks is not None else []
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
callbacks.extend(_load_external_callbacks("lightning.fabric.callbacks_factory"))
return callbacks
@contextmanager @contextmanager
def _old_sharded_model_context(strategy: Strategy) -> Generator: def _old_sharded_model_context(strategy: Strategy) -> Generator:

View File

@ -30,3 +30,6 @@ _TORCH_GREATER_EQUAL_1_13 = compare_version("torch", operator.ge, "1.13.0")
_TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0") _TORCH_GREATER_EQUAL_2_0 = compare_version("torch", operator.ge, "2.0.0")
_TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True) _TORCH_GREATER_EQUAL_2_1 = compare_version("torch", operator.ge, "2.1.0", use_base_version=True)
_TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1 _TORCH_EQUAL_2_0 = _TORCH_GREATER_EQUAL_2_0 and not _TORCH_GREATER_EQUAL_2_1
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

View File

@ -12,7 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Any import logging
from typing import Any, List, Union
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
_log = logging.getLogger(__name__)
def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool: def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> bool:
@ -25,3 +30,40 @@ def _is_register_method_overridden(mod: type, base_cls: Any, method: str) -> boo
return False return False
return mod_attr.__code__ is not super_attr.__code__ return mod_attr.__code__ is not super_attr.__code__
def _load_external_callbacks(group: str) -> List[Any]:
"""Collect external callbacks registered through entry points.
The entry points are expected to be functions returning a list of callbacks.
Args:
group: The entry point group name to load callbacks from.
Return:
A list of all callbacks collected from external factories.
"""
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points
factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points
factories = iter_entry_points(group) # type: ignore[assignment]
external_callbacks: List[Any] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Any], Any] = callback_factory()
callbacks_list = [callbacks_list] if not isinstance(callbacks_list, list) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks

View File

@ -18,6 +18,7 @@ from datetime import timedelta
from typing import Dict, List, Optional, Sequence, Union from typing import Dict, List, Optional, Sequence, Union
import lightning.pytorch as pl import lightning.pytorch as pl
from lightning.fabric.utilities.registry import _load_external_callbacks
from lightning.pytorch.callbacks import ( from lightning.pytorch.callbacks import (
Callback, Callback,
Checkpoint, Checkpoint,
@ -33,7 +34,6 @@ from lightning.pytorch.callbacks.rich_model_summary import RichModelSummary
from lightning.pytorch.callbacks.timer import Timer from lightning.pytorch.callbacks.timer import Timer
from lightning.pytorch.trainer import call from lightning.pytorch.trainer import call
from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch.utilities.model_helpers import is_overridden from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_info from lightning.pytorch.utilities.rank_zero import rank_zero_info
@ -75,7 +75,7 @@ class _CallbackConnector:
# configure the ModelSummary callback # configure the ModelSummary callback
self._configure_model_summary_callback(enable_model_summary) self._configure_model_summary_callback(enable_model_summary)
self.trainer.callbacks.extend(_configure_external_callbacks()) self.trainer.callbacks.extend(_load_external_callbacks("lightning.pytorch.callbacks_factory"))
_validate_callbacks_list(self.trainer.callbacks) _validate_callbacks_list(self.trainer.callbacks)
# push all model checkpoint callbacks to the end # push all model checkpoint callbacks to the end
@ -213,42 +213,6 @@ class _CallbackConnector:
return tuner_callbacks + other_callbacks + checkpoint_callbacks return tuner_callbacks + other_callbacks + checkpoint_callbacks
def _configure_external_callbacks() -> List[Callback]:
"""Collect external callbacks registered through entry points.
The entry points are expected to be functions returning a list of callbacks.
Return:
A list of all callbacks collected from external factories.
"""
group = "lightning.pytorch.callbacks_factory"
if _PYTHON_GREATER_EQUAL_3_8_0:
from importlib.metadata import entry_points
factories = (
entry_points(group=group)
if _PYTHON_GREATER_EQUAL_3_10_0
else entry_points().get(group, {}) # type: ignore[arg-type]
)
else:
from pkg_resources import iter_entry_points
factories = iter_entry_points(group) # type: ignore[assignment]
external_callbacks: List[Callback] = []
for factory in factories:
callback_factory = factory.load()
callbacks_list: Union[List[Callback], Callback] = callback_factory()
callbacks_list = [callbacks_list] if isinstance(callbacks_list, Callback) else callbacks_list
_log.info(
f"Adding {len(callbacks_list)} callbacks from entry point '{factory.name}':"
f" {', '.join(type(cb).__name__ for cb in callbacks_list)}"
)
external_callbacks.extend(callbacks_list)
return external_callbacks
def _validate_callbacks_list(callbacks: List[Callback]) -> None: def _validate_callbacks_list(callbacks: List[Callback]) -> None:
stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)] stateful_callbacks = [cb for cb in callbacks if is_overridden("state_dict", instance=cb)]
seen_callbacks = set() seen_callbacks = set()

View File

@ -11,8 +11,7 @@ from lightning_utilities.core.rank_zero import rank_prefixed_message
import lightning.pytorch as pl import lightning.pytorch as pl
from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.fabric.plugins.environments import SLURMEnvironment
from lightning.fabric.utilities.imports import _IS_WINDOWS from lightning.fabric.utilities.imports import _IS_WINDOWS, _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch.utilities.rank_zero import rank_zero_info from lightning.pytorch.utilities.rank_zero import rank_zero_info
# copied from signal.pyi # copied from signal.pyi

View File

@ -17,8 +17,6 @@ import sys
import torch import torch
from lightning_utilities.core.imports import package_available, RequirementCache from lightning_utilities.core.imports import package_available, RequirementCache
_PYTHON_GREATER_EQUAL_3_8_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 8)
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11) _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1") _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task

View File

@ -0,0 +1,64 @@
import contextlib
from unittest import mock
from unittest.mock import Mock
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks
class ExternalCallback:
"""A callback in another library that gets registered through entry points."""
pass
def test_load_external_callbacks():
"""Test that the connector collects Callback instances from factories registered through entry points."""
def factory_no_callback():
return []
def factory_one_callback():
return ExternalCallback()
def factory_one_callback_list():
return [ExternalCallback()]
def factory_multiple_callbacks_list():
return [ExternalCallback(), ExternalCallback()]
with _make_entry_point_query_mock(factory_no_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert callbacks == []
with _make_entry_point_query_mock(factory_one_callback):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)
with _make_entry_point_query_mock(factory_one_callback_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)
with _make_entry_point_query_mock(factory_multiple_callbacks_list):
callbacks = _load_external_callbacks("lightning.pytorch.callbacks_factory")
assert isinstance(callbacks[0], ExternalCallback)
assert isinstance(callbacks[1], ExternalCallback)
@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
query_mock = Mock()
entry_point = Mock()
entry_point.name = "mocked"
entry_point.load.return_value = callback_factory
if _PYTHON_GREATER_EQUAL_3_10_0:
query_mock.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
elif _PYTHON_GREATER_EQUAL_3_8_0:
query_mock().get.return_value = [entry_point]
import_path = "importlib.metadata.entry_points"
else:
query_mock.return_value = [entry_point]
import_path = "pkg_resources.iter_entry_points"
with mock.patch(import_path, query_mock):
yield

View File

@ -19,6 +19,7 @@ from unittest.mock import Mock
import pytest import pytest
import torch import torch
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
from lightning.pytorch import Callback, LightningModule, Trainer from lightning.pytorch import Callback, LightningModule, Trainer
from lightning.pytorch.callbacks import ( from lightning.pytorch.callbacks import (
EarlyStopping, EarlyStopping,
@ -32,7 +33,6 @@ from lightning.pytorch.callbacks import (
from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder from lightning.pytorch.callbacks.batch_size_finder import BatchSizeFinder
from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector from lightning.pytorch.trainer.connectors.callback_connector import _CallbackConnector
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0, _PYTHON_GREATER_EQUAL_3_10_0
def test_checkpoint_callbacks_are_last(tmpdir): def test_checkpoint_callbacks_are_last(tmpdir):

View File

@ -25,6 +25,7 @@ import pytest
import torch import torch
from torch import Tensor from torch import Tensor
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from lightning.pytorch import callbacks, Trainer from lightning.pytorch import callbacks, Trainer
from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE from lightning.pytorch.callbacks.progress.rich_progress import _RICH_AVAILABLE
from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset
@ -32,7 +33,6 @@ from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.loops import _EvaluationLoop from lightning.pytorch.loops import _EvaluationLoop
from lightning.pytorch.trainer.states import RunningStage from lightning.pytorch.trainer.states import RunningStage
from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.imports import _PYTHON_GREATER_EQUAL_3_8_0
from tests_pytorch.helpers.runif import RunIf from tests_pytorch.helpers.runif import RunIf
if _RICH_AVAILABLE: if _RICH_AVAILABLE: