lightning/tests/tests_pytorch/utilities/test_compile.py

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from unittest import mock

import pytest
import torch
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.demos.boring_classes import BoringModel
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
from lightning_utilities.core import module_available

from tests_pytorch.conftest import mock_cuda_count
from tests_pytorch.helpers.runif import RunIf


@RunIf(dynamo=True)
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
def test_trainer_compiled_model(_, tmp_path, monkeypatch):
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "fast_dev_run": True,
        "logger": False,
        "enable_checkpointing": False,
        "enable_model_summary": False,
        "enable_progress_bar": False,
    }
    model = BoringModel()
    compiled_model = torch.compile(model)
    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference

    # can train with the compiled model
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(compiled_model)
    assert trainer.model._compiler_ctx["compiler"] == "dynamo"

    # the compiled model can be uncompiled
    to_uncompiled_model = to_uncompiled(compiled_model)
    assert model._compiler_ctx is None
    assert compiled_model._compiler_ctx is None
    assert to_uncompiled_model._compiler_ctx is None

    # `to_uncompiled` requires a compiled LightningModule
    with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
        to_uncompiled(to_uncompiled_model)

    # the uncompiled model can be fitted
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert trainer.model._compiler_ctx is None

    # some strategies do not support compiled models
    if module_available("deepspeed"):
        compiled_model = torch.compile(model)
        mock_cuda_count(monkeypatch, 2)
        trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
        with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
            trainer.fit(compiled_model)

    # ddp does support compiled models
    trainer = Trainer(strategy="ddp", **trainer_kwargs)
    trainer.fit(compiled_model)

    # passing something other than a LightningModule raises an exception
    trainer = Trainer(**trainer_kwargs)
    with pytest.raises(TypeError, match="must be a `Light"):
        trainer.fit(object())


@RunIf(dynamo=True)
def test_compile_uncompile():
    model = BoringModel()
    compiled_model = torch.compile(model)

    def has_dynamo(fn):
        # callables patched by dynamo expose `_torchdynamo*` attributes
        return any(el for el in dir(fn) if el.startswith("_torchdynamo"))

    from_compiled_model = from_compiled(compiled_model)
    assert isinstance(from_compiled_model, LightningModule)
    assert from_compiled_model._compiler_ctx is not None
    assert has_dynamo(from_compiled_model.forward)
    assert has_dynamo(from_compiled_model.training_step)
    assert has_dynamo(from_compiled_model.validation_step)
    assert has_dynamo(from_compiled_model.test_step)
    assert has_dynamo(from_compiled_model.predict_step)

    to_uncompiled_model = to_uncompiled(model)
    assert to_uncompiled_model._compiler_ctx is None
    assert to_uncompiled_model.forward == model.forward
    assert to_uncompiled_model.training_step == model.training_step
    assert to_uncompiled_model.validation_step == model.validation_step
    assert to_uncompiled_model.test_step == model.test_step
    assert to_uncompiled_model.predict_step == model.predict_step
    assert not has_dynamo(to_uncompiled_model.forward)
    assert not has_dynamo(to_uncompiled_model.training_step)
    assert not has_dynamo(to_uncompiled_model.validation_step)
    assert not has_dynamo(to_uncompiled_model.test_step)
    assert not has_dynamo(to_uncompiled_model.predict_step)


@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@RunIf(dynamo=True)
def test_trainer_compiled_model_that_logs(tmp_path):
    class MyModel(BoringModel):
        def training_step(self, batch, batch_idx):
            loss = self.step(batch)
            self.log("loss", loss)
            return loss

    model = MyModel()
    compiled_model = torch.compile(model)

    trainer = Trainer(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer.fit(compiled_model)

    assert set(trainer.callback_metrics) == {"loss"}


@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@RunIf(dynamo=True)
def test_trainer_compiled_model_test(tmp_path):
    model = BoringModel()
    compiled_model = torch.compile(model)

    trainer = Trainer(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer.test(compiled_model)