lightning/tests/tests_fabric/utilities/test_registry.py

61 lines
2.1 KiB
Python

import contextlib
from unittest import mock
from unittest.mock import MagicMock, Mock
from lightning.fabric.utilities.imports import _PYTHON_GREATER_EQUAL_3_10_0
from lightning.fabric.utilities.registry import _load_external_callbacks
class ExternalCallback:
    """A callback in another library that gets registered through entry points.

    The class is intentionally empty: the tests in this module only need a
    distinct type to instantiate and to check against with ``isinstance``.
    """
def test_load_external_callbacks():
    """Check that ``_load_external_callbacks`` gathers callbacks from entry-point factories.

    Four factory shapes are exercised: one returning an empty list, one
    returning a bare callback instance, one returning a single-element list,
    and one returning a list with two callbacks.
    """
    group = "lightning.pytorch.callbacks_factory"

    def factory_no_callback():
        return []

    def factory_one_callback():
        return ExternalCallback()

    def factory_one_callback_list():
        return [ExternalCallback()]

    def factory_multiple_callbacks_list():
        return [ExternalCallback(), ExternalCallback()]

    # A factory that produces nothing must result in an empty collection.
    with _make_entry_point_query_mock(factory_no_callback):
        loaded = _load_external_callbacks(group)
    assert loaded == []

    # Each remaining scenario pairs a factory with the number of leading
    # entries of the result that are verified to be ExternalCallback instances.
    scenarios = (
        (factory_one_callback, 1),
        (factory_one_callback_list, 1),
        (factory_multiple_callbacks_list, 2),
    )
    for factory, checked_entries in scenarios:
        with _make_entry_point_query_mock(factory):
            loaded = _load_external_callbacks(group)
        for position in range(checked_entries):
            assert isinstance(loaded[position], ExternalCallback)
@contextlib.contextmanager
def _make_entry_point_query_mock(callback_factory):
    """Temporarily replace the registry's ``entry_points`` with a mock.

    While the context is active, ``lightning.fabric.utilities.registry.entry_points``
    is patched so that it exposes exactly one fake entry point named
    ``"mocked"`` whose ``load()`` returns ``callback_factory``.

    Args:
        callback_factory: The object the fake entry point's ``load()`` returns.
    """
    fake_entry_point = Mock()
    # ``name`` cannot be given to the Mock constructor as a plain attribute,
    # so both attributes are set through ``configure_mock``.
    fake_entry_point.configure_mock(name="mocked", **{"load.return_value": callback_factory})
    discovered = [fake_entry_point]

    entry_points_mock = MagicMock()
    if _PYTHON_GREATER_EQUAL_3_10_0:
        # Calling the mock returns the list of entry points directly.
        entry_points_mock.return_value = discovered
    else:
        # Calling the mock returns an object whose ``get(...)`` yields the list.
        legacy_result = entry_points_mock()
        legacy_result.get.return_value = discovered

    with mock.patch("lightning.fabric.utilities.registry.entry_points", new=entry_points_mock):
        yield