Re-enable dynamo tests that were fixed in PyTorch 2.1 (#19038)

This commit is contained in:
Adrian Wälchli 2023-11-21 22:30:20 +01:00 committed by GitHub
parent d3df1273b6
commit e3be762538
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 6 deletions

View File

@ -35,6 +35,7 @@ from lightning.fabric.strategies import (
)
from lightning.fabric.strategies.strategy import _Sharded
from lightning.fabric.utilities.exceptions import MisconfigurationException
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
from lightning.fabric.utilities.warnings import PossibleUserWarning
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
@ -1198,7 +1199,7 @@ def test_verify_launch_called():
fabric._validate_launched()
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
@pytest.mark.parametrize(
"kwargs",
@ -1224,7 +1225,7 @@ def test_fabric_with_torchdynamo_fullgraph(kwargs):
a = x * 10
return model(a)
fabric = Fabric(devices=1, **kwargs)
fabric = Fabric(devices=1, accelerator="cpu", **kwargs)
model = MyModel()
fmodel = fabric.setup(model)
# we are compiling a function that calls model.forward() inside

View File

@ -16,6 +16,7 @@ from unittest import mock
import pytest
import torch
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
@ -25,10 +26,10 @@ from tests_pytorch.conftest import mock_cuda_count
from tests_pytorch.helpers.runif import RunIf
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
def test_trainer_compiled_model(_, tmp_path, monkeypatch):
def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
trainer_kwargs = {
"default_root_dir": tmp_path,
"fast_dev_run": True,
@ -111,7 +112,7 @@ def test_compile_uncompile():
assert not has_dynamo(to_uncompiled_model.predict_step)
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
def test_trainer_compiled_model_that_logs(tmp_path):
class MyModel(BoringModel):
@ -129,13 +130,14 @@ def test_trainer_compiled_model_that_logs(tmp_path):
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
accelerator="cpu",
)
trainer.fit(compiled_model)
assert set(trainer.callback_metrics) == {"loss"}
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
@RunIf(dynamo=True)
def test_trainer_compiled_model_test(tmp_path):
model = BoringModel()
@ -147,5 +149,6 @@ def test_trainer_compiled_model_test(tmp_path):
enable_checkpointing=False,
enable_model_summary=False,
enable_progress_bar=False,
accelerator="cpu",
)
trainer.test(compiled_model)