# lightning/tests/tests_pytorch/utilities/test_imports.py

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import subprocess
import sys
from textwrap import dedent
from unittest import mock

import pytest
from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE
from lightning_utilities.core.imports import RequirementCache
from torch.distributed import is_available

from tests_pytorch.helpers.runif import RunIf

def test_imports():
    try:
        import omegaconf  # noqa
    except ModuleNotFoundError:
        assert not _OMEGACONF_AVAILABLE
    else:
        assert _OMEGACONF_AVAILABLE


def _shortcut_patch(orig_fn, shortcut_case, attr_names=None):
    """Patch a function so that it returns False in the shortcut case and the original value otherwise."""

    def new_fn(*args, **kwargs):
        if attr_names is not None:
            # We assume that the first argument is the object to check, so the values are stored in its attributes.
            self = args[0]
            values = tuple(getattr(self, attr_name) for attr_name in attr_names)
        else:
            values = args
        match = True
        for value, case in zip(values, shortcut_case):
            # Go through the values passed to the original function and compare them to the shortcut case.
            # We iterate pairwise because `shortcut_case` and the values may have different lengths; `zip`
            # stops at the shorter one.
            if value != case:
                match = False
                break
        if match:  # If all values match the shortcut case, return False to simulate the module not being available.
            return False
        # otherwise return the original value
        return orig_fn(*args, **kwargs)

    return new_fn


@pytest.fixture()
def clean_import():
    """This fixture allows a test to import {pytorch_}lightning* modules completely cleanly, regardless of the
    current state of the imported modules.

    Afterwards, it restores the original state of the modules.
    """
    # keep a reference to the original module registry so that it can be restored afterwards
    old_sys_modules = sys.modules
    # build a copy that excludes all *lightning* modules
    new_sys_modules = {key: value for key, value in sys.modules.items() if "lightning" not in key}
    sys.modules = new_sys_modules  # replace sys.modules with the filtered copy
    yield
    sys.modules = old_sys_modules  # restore the original modules
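
# Note: a test opts into the clean import state simply by accepting `clean_import` as an
# argument, as `test_import_with_unavailable_dependencies` does below.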


@pytest.mark.parametrize(
    ("patch_name", "new_fn", "to_import"),
    [
        pytest.param(
            "torch.distributed.is_available", _shortcut_patch(is_available, ()), "lightning.pytorch", id="ProcessGroup"
        ),
        pytest.param(
            "lightning_utilities.core.imports.RequirementCache.__bool__",
            _shortcut_patch(RequirementCache.__bool__, ("neptune",), ("requirement",)),
            "lightning.pytorch.loggers.neptune",
            id="neptune",
        ),
        pytest.param(
            "lightning_utilities.core.imports.RequirementCache.__bool__",
            _shortcut_patch(RequirementCache.__bool__, ("jsonargparse[signatures]>=4.12.0",), ("requirement",)),
            "lightning.pytorch.cli",
            id="cli",
        ),
    ],
)
def test_import_with_unavailable_dependencies(patch_name, new_fn, to_import, clean_import):
    """This test simulates the unavailability of certain modules by patching the functions that check for their
    availability.

    When the patch is applied and the module is imported, no errors should be raised. The list of cases to check was
    compiled by finding the else branches of top-level if statements that check for the availability of a module and
    perform imports.
    """
    with mock.patch(patch_name, new=new_fn):
        importlib.import_module(to_import)
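
# The cases above can be selected individually by their ids, e.g.
# `pytest tests_pytorch/utilities/test_imports.py -k neptune` (pytest's `-k`
# expression matching also covers parametrize ids).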


def test_import_pytorch_lightning_with_torch_dist_unavailable():
    """Test that the package can be imported regardless of whether torch.distributed is available."""
    code = dedent(
        """
        import torch

        # pretend torch.distributed is not available
        for name in list(torch.distributed.__dict__.keys()):
            if not name.startswith("__"):
                delattr(torch.distributed, name)
        torch.distributed.is_available = lambda: False

        import lightning.pytorch
        """
    )
    # run in a fresh interpreter, in complete isolation: torch is already imported in this process,
    # so mutating torch.distributed here would corrupt the state for other tests
    assert subprocess.call([sys.executable, "-c", code]) == 0


@RunIf(deepspeed=True)
def test_import_deepspeed_lazily():
    """Test that we are importing deepspeed only when necessary."""
    code = dedent(
        """
        import lightning.pytorch
        import sys

        assert 'deepspeed' not in sys.modules
        from lightning.pytorch.strategies import DeepSpeedStrategy
        from lightning.pytorch.plugins import DeepSpeedPrecision
        assert 'deepspeed' not in sys.modules

        import deepspeed
        assert 'deepspeed' in sys.modules
        """
    )
    # run in complete isolation
    assert subprocess.call([sys.executable, "-c", code]) == 0


@RunIf(min_python="3.9")
def test_import_lightning_multiprocessing_start_method_not_set():
    """Regression test to ensure that importing lightning does not set the multiprocessing start method."""
    # This comparison looks constant, but the string literals are presumably rewritten by Lightning's
    # package-mirroring tooling, so it resolves differently in the `pytorch_lightning` mirror of this file.
    package_name = "pytorch_lightning" if "lightning.pytorch" == "pytorch_lightning" else "lightning"
    # The following would fail with "context has already been set" if the import set the start method
    code = dedent(
        f"""
        import sys
        import multiprocessing as mp

        import {package_name}

        mp.set_start_method("spawn")
        """
    )
    # run in complete isolation
    assert subprocess.call([sys.executable, "-c", code]) == 0