2023-02-10 12:45:40 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
2023-02-28 17:38:40 +00:00
|
|
|
import sys
|
2023-07-26 19:41:58 +00:00
|
|
|
from unittest import mock
|
2023-02-10 12:45:40 +00:00
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
from lightning.pytorch import LightningModule, Trainer
|
|
|
|
from lightning.pytorch.demos.boring_classes import BoringModel
|
|
|
|
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
|
ruff: replace isort with ruff +TPU (#17684)
* ruff: replace isort with ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing & imports
* lines in warning test
* docs
* fix enum import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing
* import
* fix lines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* type ClusterEnvironment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
|
|
|
from lightning_utilities.core import module_available
|
|
|
|
|
2023-02-10 12:45:40 +00:00
|
|
|
from tests_pytorch.conftest import mock_cuda_count
|
|
|
|
from tests_pytorch.helpers.runif import RunIf
|
|
|
|
|
|
|
|
|
2023-04-18 23:09:42 +00:00
|
|
|
@RunIf(dynamo=True)
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@mock.patch("lightning.pytorch.trainer.call._call_and_handle_interrupt")
def test_trainer_compiled_model(_, tmp_path, monkeypatch):
    """Check how ``Trainer.fit`` handles a ``torch.compile``-d LightningModule and its uncompiled counterpart.

    ``lightning.pytorch.trainer.call._call_and_handle_interrupt`` is replaced with a mock, so the ``trainer.fit``
    calls below presumably stop before the actual fit loop runs; what is exercised is the handling done before
    that call (e.g. recording the compiler context and rejecting unsupported strategies).
    """
    trainer_kwargs = {
        "default_root_dir": tmp_path,
        "fast_dev_run": True,
        "logger": False,
        "enable_checkpointing": False,
        "enable_model_summary": False,
        "enable_progress_bar": False,
    }

    model = BoringModel()
    compiled_model = torch.compile(model)
    assert model._compiler_ctx is compiled_model._compiler_ctx  # shared reference

    # a compiled model is accepted by `fit`, which records the dynamo compiler context on it
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(compiled_model)
    assert trainer.model._compiler_ctx["compiler"] == "dynamo"

    # the compiled model can be uncompiled
    to_uncompiled_model = to_uncompiled(compiled_model)
    assert model._compiler_ctx is None
    assert compiled_model._compiler_ctx is None
    assert to_uncompiled_model._compiler_ctx is None

    # uncompiling a model that is not compiled raises
    with pytest.raises(ValueError, match="required to be a compiled LightningModule"):
        to_uncompiled(to_uncompiled_model)

    # the uncompiled model can be fitted
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    assert trainer.model._compiler_ctx is None

    # some strategies do not support compiled models
    if module_available("deepspeed"):
        compiled_model = torch.compile(model)
        mock_cuda_count(monkeypatch, 2)
        trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
        with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):
            trainer.fit(compiled_model)

    # ddp does support them. Compile afresh so this check does not depend on whether the optional
    # deepspeed branch above ran; otherwise `compiled_model` could be the wrapper uncompiled earlier.
    compiled_model = torch.compile(model)
    trainer = Trainer(strategy="ddp", **trainer_kwargs)
    trainer.fit(compiled_model)

    # passing something that is not a LightningModule raises
    trainer = Trainer(**trainer_kwargs)
    with pytest.raises(TypeError, match="must be a `Light"):
        trainer.fit(object())
|
|
|
|
|
|
|
|
|
2023-04-18 23:09:42 +00:00
|
|
|
@RunIf(dynamo=True)
def test_compile_uncompile():
    """``from_compiled`` wraps the step methods with dynamo; ``to_uncompiled`` restores the originals."""
    model = BoringModel()
    method_names = ("forward", "training_step", "validation_step", "test_step", "predict_step")
    # Capture the original bound methods *before* compiling. `to_uncompiled(model)` below operates on `model`
    # itself, so comparing against `model.<method>` afterwards would compare an object with itself.
    original_methods = {name: getattr(model, name) for name in method_names}
    compiled_model = torch.compile(model)

    def has_dynamo(fn):
        # treat the presence of any `_torchdynamo*` attribute as evidence of a dynamo wrapper
        return any(attr.startswith("_torchdynamo") for attr in dir(fn))

    from_compiled_model = from_compiled(compiled_model)
    assert isinstance(from_compiled_model, LightningModule)
    assert from_compiled_model._compiler_ctx is not None
    for name in method_names:
        assert has_dynamo(getattr(from_compiled_model, name)), name

    to_uncompiled_model = to_uncompiled(model)
    assert to_uncompiled_model._compiler_ctx is None
    for name in method_names:
        restored = getattr(to_uncompiled_model, name)
        assert restored == original_methods[name], name
        assert not has_dynamo(restored), name
|
2023-03-09 00:57:31 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@RunIf(dynamo=True)
def test_trainer_compiled_model_that_logs(tmp_path):
    """A value logged via ``self.log`` inside a compiled ``training_step`` ends up in ``callback_metrics``."""

    class LoggingModel(BoringModel):
        def training_step(self, batch, batch_idx):
            step_loss = self.step(batch)
            self.log("loss", step_loss)
            return step_loss

    compiled = torch.compile(LoggingModel())

    trainer_config = dict(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )
    trainer = Trainer(**trainer_config)
    trainer.fit(compiled)

    # exactly the one metric logged in `training_step` is expected
    assert trainer.callback_metrics.keys() == {"loss"}
|
2023-03-29 19:43:28 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(sys.platform == "darwin", reason="https://github.com/pytorch/pytorch/issues/95708")
@RunIf(dynamo=True)
def test_trainer_compiled_model_test(tmp_path):
    """``Trainer.test`` runs on a ``torch.compile``-d LightningModule without raising."""
    compiled = torch.compile(BoringModel())
    Trainer(
        default_root_dir=tmp_path,
        fast_dev_run=True,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    ).test(compiled)
|