Re-enable dynamo tests that were fixed in PyTorch 2.1 (#19038)
This commit is contained in:
parent
d3df1273b6
commit
e3be762538
|
@ -35,6 +35,7 @@ from lightning.fabric.strategies import (
|
|||
)
|
||||
from lightning.fabric.strategies.strategy import _Sharded
|
||||
from lightning.fabric.utilities.exceptions import MisconfigurationException
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.fabric.utilities.seed import pl_worker_init_function, seed_everything
|
||||
from lightning.fabric.utilities.warnings import PossibleUserWarning
|
||||
from lightning.fabric.wrappers import _FabricDataLoader, _FabricModule, _FabricOptimizer
|
||||
|
@ -1198,7 +1199,7 @@ def test_verify_launch_called():
|
|||
fabric._validate_launched()
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
|
||||
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
|
||||
@RunIf(dynamo=True)
|
||||
@pytest.mark.parametrize(
|
||||
"kwargs",
|
||||
|
@ -1224,7 +1225,7 @@ def test_fabric_with_torchdynamo_fullgraph(kwargs):
|
|||
a = x * 10
|
||||
return model(a)
|
||||
|
||||
fabric = Fabric(devices=1, **kwargs)
|
||||
fabric = Fabric(devices=1, accelerator="cpu", **kwargs)
|
||||
model = MyModel()
|
||||
fmodel = fabric.setup(model)
|
||||
# we are compiling a function that calls model.forward() inside
|
||||
|
|
|
@ -16,6 +16,7 @@ from unittest import mock
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
||||
from lightning.pytorch import LightningModule, Trainer
|
||||
from lightning.pytorch.demos.boring_classes import BoringModel
|
||||
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
|
||||
|
@ -25,10 +26,10 @@ from tests_pytorch.conftest import mock_cuda_count
|
|||
from tests_pytorch.helpers.runif import RunIf
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
|
||||
@RunIf(dynamo=True)
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
|
||||
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
|
||||
def test_trainer_compiled_model(_, tmp_path, monkeypatch):
|
||||
def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
|
||||
trainer_kwargs = {
|
||||
"default_root_dir": tmp_path,
|
||||
"fast_dev_run": True,
|
||||
|
@ -111,7 +112,7 @@ def test_compile_uncompile():
|
|||
assert not has_dynamo(to_uncompiled_model.predict_step)
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
|
||||
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
|
||||
@RunIf(dynamo=True)
|
||||
def test_trainer_compiled_model_that_logs(tmp_path):
|
||||
class MyModel(BoringModel):
|
||||
|
@ -129,13 +130,14 @@ def test_trainer_compiled_model_that_logs(tmp_path):
|
|||
enable_checkpointing=False,
|
||||
enable_model_summary=False,
|
||||
enable_progress_bar=False,
|
||||
accelerator="cpu",
|
||||
)
|
||||
trainer.fit(compiled_model)
|
||||
|
||||
assert set(trainer.callback_metrics) == {"loss"}
|
||||
|
||||
|
||||
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
|
||||
@pytest.mark.skipif(sys.platform == "darwin" and not _TORCH_GREATER_EQUAL_2_1, reason="Fix for MacOS in PyTorch 2.1")
|
||||
@RunIf(dynamo=True)
|
||||
def test_trainer_compiled_model_test(tmp_path):
|
||||
model = BoringModel()
|
||||
|
@ -147,5 +149,6 @@ def test_trainer_compiled_model_test(tmp_path):
|
|||
enable_checkpointing=False,
|
||||
enable_model_summary=False,
|
||||
enable_progress_bar=False,
|
||||
accelerator="cpu",
|
||||
)
|
||||
trainer.test(compiled_model)
|
||||
|
|
Loading…
Reference in New Issue