Do not require omegaconf to run tests (#10832)
This commit is contained in:
parent
a81accb2ad
commit
38ed26ec5a
|
@ -29,7 +29,6 @@ import cloudpickle
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import Container, OmegaConf
|
||||
from torch import optim
|
||||
|
||||
import pytorch_lightning as pl
|
||||
|
@ -39,9 +38,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
|
|||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import Container, OmegaConf
|
||||
|
||||
|
||||
def test_model_checkpoint_state_key():
|
||||
early_stopping = ModelCheckpoint(monitor="val_loss")
|
||||
|
@ -1094,8 +1097,8 @@ def test_current_score_when_nan(tmpdir, mode: str):
|
|||
assert model_checkpoint.current_score == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("hparams_type", [dict, Container])
|
||||
def test_hparams_type(tmpdir, hparams_type):
|
||||
@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
|
||||
def test_hparams_type(tmpdir, use_omegaconf):
|
||||
class TestModel(BoringModel):
|
||||
def __init__(self, hparams):
|
||||
super().__init__()
|
||||
|
@ -1113,15 +1116,15 @@ def test_hparams_type(tmpdir, hparams_type):
|
|||
enable_model_summary=False,
|
||||
)
|
||||
hp = {"test_hp_0": 1, "test_hp_1": 2}
|
||||
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
|
||||
hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp)
|
||||
model = TestModel(hp)
|
||||
trainer.fit(model)
|
||||
ckpt = trainer.checkpoint_connector.dump_checkpoint()
|
||||
if hparams_type == Container:
|
||||
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
|
||||
if use_omegaconf:
|
||||
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container)
|
||||
else:
|
||||
# make sure it's not AttributeDict
|
||||
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type
|
||||
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict
|
||||
|
||||
|
||||
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
|
||||
|
|
|
@ -20,12 +20,11 @@ from unittest.mock import call, PropertyMock
|
|||
|
||||
import pytest
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from pytorch_lightning import LightningDataModule, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
from pytorch_lightning.utilities import AttributeDict
|
||||
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_helpers import is_overridden
|
||||
from tests.helpers import BoringDataModule, BoringModel
|
||||
|
@ -34,6 +33,9 @@ from tests.helpers.runif import RunIf
|
|||
from tests.helpers.simple_models import ClassificationModel
|
||||
from tests.helpers.utils import reset_seed
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
|
||||
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
|
||||
|
@ -440,8 +442,9 @@ def test_hyperparameters_saving():
|
|||
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
|
||||
assert data.hparams == AttributeDict({"hello": "world"})
|
||||
|
||||
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
|
||||
assert data.hparams == OmegaConf.create({"hello": "world"})
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
|
||||
assert data.hparams == OmegaConf.create({"hello": "world"})
|
||||
|
||||
|
||||
def test_define_as_dataclass():
|
||||
|
|
|
@ -27,6 +27,7 @@ from pytorch_lightning.utilities import (
|
|||
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
|
||||
_HOROVOD_AVAILABLE,
|
||||
_IPU_AVAILABLE,
|
||||
_OMEGACONF_AVAILABLE,
|
||||
_RICH_AVAILABLE,
|
||||
_TORCH_QUANTIZE_AVAILABLE,
|
||||
_TPU_AVAILABLE,
|
||||
|
@ -70,6 +71,7 @@ class RunIf:
|
|||
deepspeed: bool = False,
|
||||
rich: bool = False,
|
||||
skip_49370: bool = False,
|
||||
omegaconf: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
|
@ -89,9 +91,10 @@ class RunIf:
|
|||
standalone: Mark the test as standalone, our CI will run it in a separate process.
|
||||
fairscale: Require that facebookresearch/fairscale is installed.
|
||||
fairscale_fully_sharded: Require that `fairscale` fully sharded support is available.
|
||||
deepspeed: Require that Microsoft/DeepSpeed is installed.
|
||||
deepspeed: Require that microsoft/DeepSpeed is installed.
|
||||
rich: Require that willmcgugan/rich is installed.
|
||||
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
|
||||
omegaconf: Require that omry/omegaconf is installed.
|
||||
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
|
||||
"""
|
||||
conditions = []
|
||||
|
@ -177,6 +180,10 @@ class RunIf:
|
|||
conditions.append(ge_3_9 and old_torch)
|
||||
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
|
||||
|
||||
if omegaconf:
|
||||
conditions.append(not _OMEGACONF_AVAILABLE)
|
||||
reasons.append("omegaconf")
|
||||
|
||||
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
|
||||
return pytest.mark.skipif(
|
||||
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
|
||||
|
|
|
@ -21,13 +21,16 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from pytorch_lightning.loggers.base import LoggerCollection
|
||||
from pytorch_lightning.utilities.imports import _compare_version
|
||||
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
|
||||
from tests.helpers import BoringModel
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
|
@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
|
|||
logger.log_hyperparams(hparams, metrics)
|
||||
|
||||
|
||||
@RunIf(omegaconf=True)
|
||||
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
|
||||
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
|
||||
hparams = {
|
||||
|
@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
|
|||
"bool": True,
|
||||
"dict": {"a": {"b": "c"}},
|
||||
"list": [1, 2, 3],
|
||||
# "namespace": Namespace(foo=Namespace(bar="buzz")),
|
||||
# "layer": torch.nn.BatchNorm1d,
|
||||
}
|
||||
hparams = OmegaConf.create(hparams)
|
||||
|
||||
|
|
|
@ -24,21 +24,24 @@ import cloudpickle
|
|||
import pytest
|
||||
import torch
|
||||
from fsspec.implementations.local import LocalFileSystem
|
||||
from omegaconf import Container, OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning import LightningModule, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.core.datamodule import LightningDataModule
|
||||
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
|
||||
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
|
||||
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
from tests.helpers.runif import RunIf
|
||||
|
||||
if _HYDRA_EXPERIMENTAL_AVAILABLE:
|
||||
from hydra.experimental import compose, initialize
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import Container, OmegaConf
|
||||
from omegaconf.dictconfig import DictConfig
|
||||
|
||||
|
||||
class SaveHparamsModel(BoringModel):
|
||||
"""Tests that a model can take an object."""
|
||||
|
@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
|
|||
_run_standard_hparams_test(tmpdir, model, cls)
|
||||
|
||||
|
||||
@RunIf(omegaconf=True)
|
||||
@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel])
|
||||
def test_omega_conf_hparams(tmpdir, cls):
|
||||
# init model
|
||||
|
@ -275,10 +279,18 @@ class UnconventionalArgsBoringModel(CustomBoringModel):
|
|||
obj.save_hyperparameters()
|
||||
|
||||
|
||||
class DictConfSubClassBoringModel(SubClassBoringModel):
|
||||
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_hyperparameters()
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
|
||||
class DictConfSubClassBoringModel(SubClassBoringModel):
|
||||
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.save_hyperparameters()
|
||||
|
||||
|
||||
else:
|
||||
|
||||
class DictConfSubClassBoringModel:
|
||||
...
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -290,7 +302,7 @@ class DictConfSubClassBoringModel(SubClassBoringModel):
|
|||
SubSubClassBoringModel,
|
||||
AggSubClassBoringModel,
|
||||
UnconventionalArgsBoringModel,
|
||||
DictConfSubClassBoringModel,
|
||||
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
|
||||
],
|
||||
)
|
||||
def test_collect_init_arguments(tmpdir, cls):
|
||||
|
@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
|
|||
assert model.hparams["arg2"] == 2
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("cls,config", [
|
||||
# (SaveHparamsModel, Namespace(my_arg=42)),
|
||||
# (SaveHparamsModel, dict(my_arg=42)),
|
||||
# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
|
||||
# (AssignHparamsModel, Namespace(my_arg=42)),
|
||||
# (AssignHparamsModel, dict(my_arg=42)),
|
||||
# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
|
||||
# ])
|
||||
# def test_single_config_models(tmpdir, cls, config):
|
||||
# """ Test that the model automatically saves the arguments passed into the constructor """
|
||||
# model = cls(config)
|
||||
#
|
||||
# # no matter how you do it, it should be assigned
|
||||
# assert model.hparams.my_arg == 42
|
||||
#
|
||||
# # verify that the checkpoint saved the correct values
|
||||
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
|
||||
# trainer.fit(model)
|
||||
#
|
||||
# # verify that model loads correctly
|
||||
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
|
||||
# model = cls.load_from_checkpoint(raw_checkpoint_path)
|
||||
# assert model.hparams.my_arg == 42
|
||||
|
||||
|
||||
class AnotherArgModel(BoringModel):
|
||||
def __init__(self, arg1):
|
||||
super().__init__()
|
||||
|
@ -511,8 +498,9 @@ def test_hparams_save_yaml(tmpdir):
|
|||
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
|
||||
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
|
||||
|
||||
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
|
||||
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
|
||||
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
|
||||
|
||||
|
||||
class NoArgsSubClassBoringModel(CustomBoringModel):
|
||||
|
|
|
@ -26,7 +26,6 @@ from unittest.mock import ANY, call, patch
|
|||
import cloudpickle
|
||||
import pytest
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.optim import SGD
|
||||
from torch.utils.data import DataLoader, IterableDataset
|
||||
|
@ -51,6 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn
|
|||
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
|
||||
from pytorch_lightning.utilities.cloud_io import load as pl_load
|
||||
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
|
||||
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
|
||||
from pytorch_lightning.utilities.seed import seed_everything
|
||||
from tests.base import EvalModelTemplate
|
||||
from tests.helpers import BoringModel, RandomDataset
|
||||
|
@ -59,6 +59,9 @@ from tests.helpers.datamodules import ClassifDataModule
|
|||
from tests.helpers.runif import RunIf
|
||||
from tests.helpers.simple_models import ClassificationModel
|
||||
|
||||
if _OMEGACONF_AVAILABLE:
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
|
||||
@pytest.mark.parametrize("url_ckpt", [True, False])
|
||||
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
|
||||
|
@ -1271,12 +1274,12 @@ def test_trainer_subclassing():
|
|||
TrainerSubclass(abcdefg="unknown_arg")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))]
|
||||
)
|
||||
@RunIf(min_gpus=1)
|
||||
def test_trainer_omegaconf(trainer_params):
|
||||
Trainer(**trainer_params)
|
||||
@RunIf(omegaconf=True)
|
||||
@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}])
|
||||
@mock.patch("torch.cuda.device_count", return_value=1)
|
||||
def test_trainer_omegaconf(_, trainer_params):
|
||||
config = OmegaConf.create(trainer_params)
|
||||
Trainer(**config)
|
||||
|
||||
|
||||
def test_trainer_pickle(tmpdir):
|
||||
|
|
Loading…
Reference in New Issue