2021-06-11 11:47:00 +00:00
|
|
|
# Copyright The Lightning AI team.
|
|
|
|
#
|
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
# you may not use this file except in compliance with the License.
|
|
|
|
# You may obtain a copy of the License at
|
|
|
|
#
|
|
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
#
|
|
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
# See the License for the specific language governing permissions and
|
|
|
|
# limitations under the License.
|
|
|
|
import pytest
|
2023-02-02 10:06:45 +00:00
|
|
|
from lightning.pytorch import LightningDataModule
|
|
|
|
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel
|
2023-09-20 17:45:12 +00:00
|
|
|
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod, is_overridden
|
ruff: replace isort with ruff +TPU (#17684)
* ruff: replace isort with ruff
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing & imports
* lines in warning test
* docs
* fix enum import
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fixing
* import
* fix lines
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* type ClusterEnvironment
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2023-09-26 15:54:55 +00:00
|
|
|
from lightning_utilities import module_available
|
2021-06-11 11:47:00 +00:00
|
|
|
|
|
|
|
|
|
|
|
def test_is_overridden():
    """Check ``is_overridden`` on degenerate inputs, unknown method names, and real hook overrides."""
    # Degenerate inputs: no instance at all yields False, while a plain ``object()``
    # (for which no parent class can be resolved) must raise.
    assert not is_overridden("whatever", None)
    with pytest.raises(ValueError, match="Expected a parent"):
        is_overridden("whatever", object())

    # A method name that does not exist is never reported as overridden,
    # whether the parent is inferred or passed explicitly.
    boring_model = BoringModel()
    assert not is_overridden("whatever", boring_model)
    assert not is_overridden("whatever", boring_model, parent=LightningDataModule)

    # Hooks that the demo classes implement must be reported as overridden.
    assert is_overridden("training_step", boring_model)
    boring_dm = BoringDataModule()
    assert is_overridden("train_dataloader", boring_dm)
|
2023-05-05 11:16:19 +00:00
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.skipif(
    not module_available("lightning") or not module_available("pytorch_lightning"),
    reason="This test is ONLY relevant for the UNIFIED package",
)
def test_mixed_imports_unified():
    """Passing ``pytorch_lightning`` objects to ``lightning.pytorch`` helpers raises a descriptive ``TypeError``.

    Only runs when both the unified ``lightning`` package and the standalone ``pytorch_lightning``
    package are importable, since it mixes objects from the two namespaces.
    """
    from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized as new_unwrap
    from lightning.pytorch.utilities.model_helpers import is_overridden as new_is_overridden
    from pytorch_lightning.callbacks import EarlyStopping as OldEarlyStopping
    from pytorch_lightning.demos.boring_classes import BoringModel as OldBoringModel

    model = OldBoringModel()
    # `match` is a regex (applied with re.search): escape the literal dot in the module path
    # so it cannot match an arbitrary character.
    with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(BoringModel\) to a `lightning\.pytorch`"):
        new_unwrap(model)

    with pytest.raises(TypeError, match=r"`pytorch_lightning` object \(EarlyStopping\) to a `lightning\.pytorch`"):
        new_is_overridden("on_fit_start", OldEarlyStopping("foo"))
|
2023-09-20 17:45:12 +00:00
|
|
|
|
|
|
|
|
|
|
|
class RestrictedClass:
    """Test fixture with one ``_restricted_classmethod`` and one ordinary ``classmethod``."""

    @_restricted_classmethod
    def restricted_cmethod(cls):
        """No-op that may only be invoked on the class itself; calling it on an instance should raise ``TypeError``."""

    @classmethod
    def cmethod(cls):
        """No-op regular classmethod, callable on either the class or an instance."""
|
|
|
|
|
|
|
|
|
|
|
|
def test_restricted_classmethod():
    """A ``_restricted_classmethod`` rejects instance calls but accepts calls on the class."""
    instance = RestrictedClass()
    with pytest.raises(TypeError, match="cannot be called on an instance"):
        instance.restricted_cmethod()

    # Invoking through the class object itself must not raise.
    RestrictedClass.restricted_cmethod()
|