diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index f29f6e9159..fa08057733 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -29,7 +29,6 @@ import cloudpickle import pytest import torch import yaml -from omegaconf import Container, OmegaConf from torch import optim import pytorch_lightning as pl @@ -39,9 +38,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE from tests.helpers import BoringModel from tests.helpers.runif import RunIf +if _OMEGACONF_AVAILABLE: + from omegaconf import Container, OmegaConf + def test_model_checkpoint_state_key(): early_stopping = ModelCheckpoint(monitor="val_loss") @@ -1094,8 +1097,8 @@ def test_current_score_when_nan(tmpdir, mode: str): assert model_checkpoint.current_score == expected -@pytest.mark.parametrize("hparams_type", [dict, Container]) -def test_hparams_type(tmpdir, hparams_type): +@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))]) +def test_hparams_type(tmpdir, use_omegaconf): class TestModel(BoringModel): def __init__(self, hparams): super().__init__() @@ -1113,15 +1116,15 @@ def test_hparams_type(tmpdir, hparams_type): enable_model_summary=False, ) hp = {"test_hp_0": 1, "test_hp_1": 2} - hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp) + hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp) model = TestModel(hp) trainer.fit(model) ckpt = trainer.checkpoint_connector.dump_checkpoint() - if hparams_type == Container: - assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type) + if use_omegaconf: + assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container) else: # make sure it's not AttributeDict - assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type + assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict def test_ckpt_version_after_rerun_new_trainer(tmpdir): diff --git a/tests/core/test_datamodules.py b/tests/core/test_datamodules.py index 59b68a723e..57574c074c 100644 --- a/tests/core/test_datamodules.py +++ b/tests/core/test_datamodules.py @@ -20,12 +20,11 @@ from unittest.mock import call, PropertyMock import pytest import torch -from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.trainer.states import TrainerFn -from pytorch_lightning.utilities import AttributeDict +from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.model_helpers import is_overridden from tests.helpers import BoringDataModule, BoringModel @@ -34,6 +33,9 @@ from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel from tests.helpers.utils import reset_seed +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + @mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock) @mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock) @@ -440,8 +442,9 @@ def test_hyperparameters_saving(): data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar") assert data.hparams == AttributeDict({"hello": "world"}) - data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar") - assert data.hparams == OmegaConf.create({"hello": "world"}) + if _OMEGACONF_AVAILABLE: + data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar") + assert data.hparams == OmegaConf.create({"hello": "world"}) def test_define_as_dataclass(): diff --git a/tests/helpers/runif.py b/tests/helpers/runif.py index 544e86d146..8a9f707e69 100644 --- a/tests/helpers/runif.py +++ b/tests/helpers/runif.py @@ -27,6 +27,7 @@ from pytorch_lightning.utilities import ( _FAIRSCALE_FULLY_SHARDED_AVAILABLE, _HOROVOD_AVAILABLE, _IPU_AVAILABLE, + _OMEGACONF_AVAILABLE, _RICH_AVAILABLE, _TORCH_QUANTIZE_AVAILABLE, _TPU_AVAILABLE, @@ -70,6 +71,7 @@ class RunIf: deepspeed: bool = False, rich: bool = False, skip_49370: bool = False, + omegaconf: bool = False, **kwargs, ): """ @@ -89,9 +91,10 @@ class RunIf: standalone: Mark the test as standalone, our CI will run it in a separate process. fairscale: Require that facebookresearch/fairscale is installed. fairscale_fully_sharded: Require that `fairscale` fully sharded support is available. - deepspeed: Require that Microsoft/DeepSpeed is installed. + deepspeed: Require that microsoft/DeepSpeed is installed. rich: Require that willmcgugan/rich is installed. skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370. + omegaconf: Require that omry/omegaconf is installed. **kwargs: Any :class:`pytest.mark.skipif` keyword arguments. """ conditions = [] @@ -177,6 +180,10 @@ class RunIf: conditions.append(ge_3_9 and old_torch) reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370") + if omegaconf: + conditions.append(not _OMEGACONF_AVAILABLE) + reasons.append("omegaconf") + reasons = [rs for cond, rs in zip(conditions, reasons) if cond] return pytest.mark.skipif( *args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index 0a99c058ef..d0119b3e86 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -21,13 +21,16 @@ import numpy as np import pytest import torch import yaml -from omegaconf import OmegaConf from pytorch_lightning import Trainer from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.loggers.base import LoggerCollection -from pytorch_lightning.utilities.imports import _compare_version +from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE from tests.helpers import BoringModel +from tests.helpers.runif import RunIf + +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf @pytest.mark.skipif( @@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir): logger.log_hyperparams(hparams, metrics) +@RunIf(omegaconf=True) def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir): logger = TensorBoardLogger(tmpdir, default_hp_metric=False) hparams = { @@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir): "bool": True, "dict": {"a": {"b": "c"}}, "list": [1, 2, 3], - # "namespace": Namespace(foo=Namespace(bar="buzz")), - # "layer": torch.nn.BatchNorm1d, } hparams = OmegaConf.create(hparams) diff --git a/tests/models/test_hparams.py b/tests/models/test_hparams.py index bfdf81f64b..71ee44c66d 100644 --- a/tests/models/test_hparams.py +++ b/tests/models/test_hparams.py @@ -24,21 +24,24 @@ import cloudpickle import pytest import torch from fsspec.implementations.local import LocalFileSystem -from omegaconf import Container, OmegaConf -from omegaconf.dictconfig import DictConfig from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.datamodule import LightningDataModule from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml -from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable +from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.helpers import BoringModel, RandomDataset +from tests.helpers.runif import RunIf if _HYDRA_EXPERIMENTAL_AVAILABLE: from hydra.experimental import compose, initialize +if _OMEGACONF_AVAILABLE: + from omegaconf import Container, OmegaConf + from omegaconf.dictconfig import DictConfig + class SaveHparamsModel(BoringModel): """Tests that a model can take an object.""" @@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls): _run_standard_hparams_test(tmpdir, model, cls) +@RunIf(omegaconf=True) @pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel]) def test_omega_conf_hparams(tmpdir, cls): # init model @@ -275,10 +279,18 @@ class UnconventionalArgsBoringModel(CustomBoringModel): obj.save_hyperparameters() -class DictConfSubClassBoringModel(SubClassBoringModel): - def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs): - super().__init__(*args, **kwargs) - self.save_hyperparameters() +if _OMEGACONF_AVAILABLE: + + class DictConfSubClassBoringModel(SubClassBoringModel): + def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs): + super().__init__(*args, **kwargs) + self.save_hyperparameters() + + +else: + + class DictConfSubClassBoringModel: + ... @pytest.mark.parametrize( @@ -290,7 +302,7 @@ class DictConfSubClassBoringModel(SubClassBoringModel): SubSubClassBoringModel, AggSubClassBoringModel, UnconventionalArgsBoringModel, - DictConfSubClassBoringModel, + pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)), ], ) def test_collect_init_arguments(tmpdir, cls): @@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls): assert model.hparams["arg2"] == 2 -# @pytest.mark.parametrize("cls,config", [ -# (SaveHparamsModel, Namespace(my_arg=42)), -# (SaveHparamsModel, dict(my_arg=42)), -# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))), -# (AssignHparamsModel, Namespace(my_arg=42)), -# (AssignHparamsModel, dict(my_arg=42)), -# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))), -# ]) -# def test_single_config_models(tmpdir, cls, config): -# """ Test that the model automatically saves the arguments passed into the constructor """ -# model = cls(config) -# -# # no matter how you do it, it should be assigned -# assert model.hparams.my_arg == 42 -# -# # verify that the checkpoint saved the correct values -# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5) -# trainer.fit(model) -# -# # verify that model loads correctly -# raw_checkpoint_path = _raw_checkpoint_path(trainer) -# model = cls.load_from_checkpoint(raw_checkpoint_path) -# assert model.hparams.my_arg == 42 - - class AnotherArgModel(BoringModel): def __init__(self, arg1): super().__init__() @@ -511,8 +498,9 @@ def test_hparams_save_yaml(tmpdir): save_hparams_to_yaml(path_yaml, AttributeDict(hparams)) _compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams) - save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) - _compare_params(load_hparams_from_yaml(path_yaml), hparams) + if _OMEGACONF_AVAILABLE: + save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams)) + _compare_params(load_hparams_from_yaml(path_yaml), hparams) class NoArgsSubClassBoringModel(CustomBoringModel): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6416ef88fb..e440f5f703 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -26,7 +26,6 @@ from unittest.mock import ANY, call, patch import cloudpickle import pytest import torch -from omegaconf import OmegaConf from torch.nn.parallel.distributed import DistributedDataParallel from torch.optim import SGD from torch.utils.data import DataLoader, IterableDataset @@ -51,6 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities import _AcceleratorType, _StrategyType from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException +from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE from pytorch_lightning.utilities.seed import seed_everything from tests.base import EvalModelTemplate from tests.helpers import BoringModel, RandomDataset @@ -59,6 +59,9 @@ from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf from tests.helpers.simple_models import ClassificationModel +if _OMEGACONF_AVAILABLE: + from omegaconf import OmegaConf + @pytest.mark.parametrize("url_ckpt", [True, False]) def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt): @@ -1271,12 +1274,12 @@ def test_trainer_subclassing(): TrainerSubclass(abcdefg="unknown_arg") -@pytest.mark.parametrize( - "trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))] -) -@RunIf(min_gpus=1) -def test_trainer_omegaconf(trainer_params): - Trainer(**trainer_params) +@RunIf(omegaconf=True) +@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}]) +@mock.patch("torch.cuda.device_count", return_value=1) +def test_trainer_omegaconf(_, trainer_params): + config = OmegaConf.create(trainer_params) + Trainer(**config) def test_trainer_pickle(tmpdir):