diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index cc1ec88618..0fe4b68d3d 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -35,6 +35,7 @@ from lightning.fabric.strategies import ( ) from lightning.fabric.strategies.strategy import _Sharded from lightning.fabric.utilities.exceptions import MisconfigurationException +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything from lightning.fabric.utilities.warnings import PossibleUserWarning from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer @@ -1198,7 +1199,7 @@ def test_verify_launch_called(): fabric._validate_launched() -@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708") +@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") @RunIf(dynamo=True) @pytest.mark.parametrize( "kwargs", @@ -1224,7 +1225,7 @@ def test_fabric_with_torchdynamo_fullgraph(kwargs): a = x * 10 return model(a) - fabric = Fabric(devices=1, **kwargs) + fabric = Fabric(devices=1, accelerator="cpu", **kwargs) model = MyModel() fmodel = fabric.setup(model) # we are compiling a function that calls model.forward() inside diff --git a/tests/tests_pytorch/utilities/test_compile.py b/tests/tests_pytorch/utilities/test_compile.py index 9a97d2d74b..7fe03ddf51 100644 --- a/tests/tests_pytorch/utilities/test_compile.py +++ b/tests/tests_pytorch/utilities/test_compile.py @@ -16,6 +16,7 @@ from unittest import mock import pytest import torch +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1 from lightning.pytorch import LightningModule, Trainer from lightning.pytorch.demos.boring_classes import BoringModel from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled @@ -25,10 +26,10 @@ from tests_pytorch.conftest import mock_cuda_count from tests_pytorch.helpers.runif import RunIf +@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") @RunIf(dynamo=True) -@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708") @mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt") -def test_trainer_compiled_model(_, tmp_path, monkeypatch): +def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0): trainer_kwargs = { "default_root_dir": tmp_path, "fast_dev_run": True, @@ -111,7 +112,7 @@ def test_compile_uncompile(): assert not has_dynamo(to_uncompiled_model.predict_step) -@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708") +@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") @RunIf(dynamo=True) def test_trainer_compiled_model_that_logs(tmp_path): class MyModel(BoringModel): @@ -129,13 +130,14 @@ def test_trainer_compiled_model_that_logs(tmp_path): enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False, + accelerator="cpu", ) trainer.fit(compiled_model) assert set(trainer.callback_metrics) == {"loss"} -@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708") +@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1") @RunIf(dynamo=True) def test_trainer_compiled_model_test(tmp_path): model = BoringModel() @@ -147,5 +149,6 @@ def test_trainer_compiled_model_test(tmp_path): enable_checkpointing=False, enable_model_summary=False, enable_progress_bar=False, + accelerator="cpu", ) trainer.test(compiled_model)