2023-02-10 12:45:40 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2023-02-28 17:38:40 +00:00
|
|
|
import sys
|
2023-02-10 12:45:40 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
2023-02-16 05:12:08 +00:00
|
|
|
from lightning_utilities.core import module_available
|
2023-02-10 12:45:40 +00:00
|
|
|
|
2023-02-28 17:38:40 +00:00
|
|
|
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_1
|
2023-02-10 12:45:40 +00:00
|
|
|
from lightning.pytorch import LightningModule, Trainer
|
|
|
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
|
|
|
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
|
|
|
|
from tests_pytorch.conftest import mock_cuda_count
|
|
|
|
from tests_pytorch.helpers.runif import RunIf
|
|
|
|
|
|
|
|
|
2023-02-28 17:38:40 +00:00
|
|
|
def skip_if_unsupported():
    """Skip the calling test when TorchDynamo cannot be used in this environment.

    With torch >= 2.1 the decision is delegated to
    ``torch._dynamo.eval_frame.is_dynamo_supported``. On older torch versions,
    Windows and Python 3.11+ are treated as unsupported.
    """
    if not _TORCH_GREATER_EQUAL_2_1:
        # older torch: hard-coded list of known-unsupported platforms/interpreters
        unsupported = sys.platform == "win32" or sys.version_info >= (3, 11)
    else:
        # imported lazily: the helper is only looked up on torch >= 2.1
        from torch._dynamo.eval_frame import is_dynamo_supported

        unsupported = not is_dynamo_supported()

    if unsupported:
        pytest.skip("TorchDynamo unsupported")
|
|
|
|
|
|
|
|
|
2023-02-10 12:45:40 +00:00
|
|
|
@RunIf(min_torch="2.0.0")
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
def test_trainer_compiled_model(tmp_path, monkeypatch):
    """Check that the Trainer accepts, unwraps and rejects compiled models as expected.

    Covers: fitting a ``torch.compile``-d LightningModule, uncompiling it with
    ``to_uncompiled``, fitting the plain module afterwards, rejection of compiled
    models by deepspeed (when installed), acceptance by ddp, and the ``TypeError``
    raised for a non-LightningModule.
    """
    skip_if_unsupported()

    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "fast_dev_run": True,
        "logger": False,
        "enable_checkpointing": False,
        "enable_model_summary": False,
        "enable_progress_bar": False,
    }

    model = BoringModel()
    compiled_model = torch.compile(model)
    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference

    # can train with compiled model
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(compiled_model)
    assert trainer.model._compiler_ctx["compiler"] == "dynamo"

    # the compiled model can be uncompiled
    to_uncompiled_model = to_uncompiled(compiled_model)
    assert model._compiler_ctx is None
    assert compiled_model._compiler_ctx is None
    assert to_uncompiled_model._compiler_ctx is None

    # the compiled model needs to be passed
    with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
        to_uncompiled(to_uncompiled_model)

    # the uncompiled model can be fitted
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert trainer.model._compiler_ctx is None

    # Re-compile once so the strategy checks below always start from a freshly
    # compiled module. Previously, without deepspeed, the ddp check reused the
    # module that had already been passed through `to_uncompiled`.
    compiled_model = torch.compile(model)

    # some strategies do not support it
    if module_available("deepspeed"):
        mock_cuda_count(monkeypatch, 2)
        trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
        with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
            trainer.fit(compiled_model)

    # ddp does
    # The accelerator is pinned to CPU because the CUDA device count may still be
    # mocked to 2 above; with "auto" the Trainer could otherwise select CUDA on a
    # machine that has no GPU.
    trainer = Trainer(strategy="ddp", accelerator="cpu", **trainer_kwargs)
    trainer.fit(compiled_model)

    # an exception is raised
    trainer = Trainer(**trainer_kwargs)
    with pytest.raises(TypeError, match="must be a `Light"):
        trainer.fit(object())
|
|
|
|
|
|
|
|
|
|
|
|
@RunIf(min_torch="2.0.0")
def test_compile_uncompile():
    """Round-trip a module through ``from_compiled`` and ``to_uncompiled``.

    After ``from_compiled`` every step hook must carry dynamo wrapping. After
    ``to_uncompiled`` the hooks must match the original module's hooks and no
    longer carry dynamo wrapping.
    """
    skip_if_unsupported()

    hook_names = ("forward", "training_step", "validation_step", "test_step", "predict_step")

    def has_dynamo(fn):
        # treats any attribute whose name starts with `_torchdynamo` as evidence of dynamo wrapping
        return any(attr.startswith("_torchdynamo") for attr in dir(fn))

    model = BoringModel()
    compiled_model = torch.compile(model)

    from_compiled_model = from_compiled(compiled_model)
    assert isinstance(from_compiled_model, LightningModule)
    assert from_compiled_model._compiler_ctx is not None
    for hook in hook_names:
        assert has_dynamo(getattr(from_compiled_model, hook))

    to_uncompiled_model = to_uncompiled(model)
    assert to_uncompiled_model._compiler_ctx is None
    for hook in hook_names:
        assert getattr(to_uncompiled_model, hook) == getattr(model, hook)
    for hook in hook_names:
        assert not has_dynamo(getattr(to_uncompiled_model, hook))
|