Do not require omegaconf to run tests (#10832)

Carlos Mocholí 2021-11-30 15:48:03 +01:00 committed by GitHub
parent a81accb2ad
commit 38ed26ec5a
6 changed files with 64 additions and 58 deletions
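The changes below all apply one pattern: import omegaconf behind an availability flag, then either skip the omegaconf-specific test via RunIf(omegaconf=True) or branch around the omegaconf-specific assertions. A minimal, self-contained sketch of that guard (hypothetical helper, not the Lightning code; the flag only mirrors _OMEGACONF_AVAILABLE):

# Hypothetical sketch of the optional-dependency guard used throughout these files.
try:
    from omegaconf import OmegaConf
    _OMEGACONF_AVAILABLE = True
except ImportError:
    _OMEGACONF_AVAILABLE = False


def make_hparams(use_omegaconf: bool):
    # Fall back to a plain dict when omegaconf is not installed.
    hp = {"lr": 0.01, "batch_size": 32}
    return OmegaConf.create(hp) if (use_omegaconf and _OMEGACONF_AVAILABLE) else hp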

View File

@@ -29,7 +29,6 @@ import cloudpickle
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf
from torch import optim
import pytorch_lightning as pl
@@ -39,9 +38,13 @@ from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
def test_model_checkpoint_state_key():
early_stopping = ModelCheckpoint(monitor="val_loss")
@@ -1094,8 +1097,8 @@ def test_current_score_when_nan(tmpdir, mode: str):
assert model_checkpoint.current_score == expected
@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=RunIf(omegaconf=True))])
def test_hparams_type(tmpdir, use_omegaconf):
class TestModel(BoringModel):
def __init__(self, hparams):
super().__init__()
@@ -1113,15 +1116,15 @@ def test_hparams_type(tmpdir, hparams_type):
enable_model_summary=False,
)
hp = {"test_hp_0": 1, "test_hp_1": 2}
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
hp = OmegaConf.create(hp) if use_omegaconf else Namespace(**hp)
model = TestModel(hp)
trainer.fit(model)
ckpt = trainer.checkpoint_connector.dump_checkpoint()
if hparams_type == Container:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
if use_omegaconf:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], Container)
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is hparams_type
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) is dict
def test_ckpt_version_after_rerun_new_trainer(tmpdir):
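The parametrization above drops the hard Container dependency: the omegaconf case is wrapped in pytest.param with a skip mark, so the plain-dict case still runs without the package. A hedged sketch of the same pytest pattern, with a plain skipif mark standing in for Lightning's RunIf helper:

# Sketch only: the real test uses RunIf(omegaconf=True); a plain skipif mark stands in here.
import pytest

try:
    from omegaconf import OmegaConf
    _OMEGACONF_AVAILABLE = True
except ImportError:
    _OMEGACONF_AVAILABLE = False

needs_omegaconf = pytest.mark.skipif(not _OMEGACONF_AVAILABLE, reason="Requires: [omegaconf]")


@pytest.mark.parametrize("use_omegaconf", [False, pytest.param(True, marks=needs_omegaconf)])
def test_hparams_roundtrip(use_omegaconf):
    hp = {"test_hp_0": 1, "test_hp_1": 2}
    if use_omegaconf:
        hp = OmegaConf.create(hp)
    # Both flavours support mapping-style access, so the assertion is shared.
    assert hp["test_hp_0"] == 1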

View File

@@ -20,12 +20,11 @@ from unittest.mock import call, PropertyMock
import pytest
import torch
from omegaconf import OmegaConf
from pytorch_lightning import LightningDataModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import AttributeDict
from pytorch_lightning.utilities import _OMEGACONF_AVAILABLE, AttributeDict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from tests.helpers import BoringDataModule, BoringModel
@@ -34,6 +33,9 @@ from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
from tests.helpers.utils import reset_seed
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.node_rank", new_callable=PropertyMock)
@mock.patch("pytorch_lightning.trainer.trainer.Trainer.local_rank", new_callable=PropertyMock)
@@ -440,8 +442,9 @@ def test_hyperparameters_saving():
data = DataModuleWithHparams_1({"hello": "world"}, "foo", kwarg0="bar")
assert data.hparams == AttributeDict({"hello": "world"})
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})
if _OMEGACONF_AVAILABLE:
data = DataModuleWithHparams_1(OmegaConf.create({"hello": "world"}), "foo", kwarg0="bar")
assert data.hparams == OmegaConf.create({"hello": "world"})
def test_define_as_dataclass():
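In the datamodule test the omegaconf case is folded into an existing test rather than split out: the dict assertion always runs, and only the OmegaConf half is gated on the flag. A short sketch of that in-test branch (the holder class is a made-up stand-in for DataModuleWithHparams_1):

# Sketch: keep the plain-dict assertion unconditional, gate only the OmegaConf half.
try:
    from omegaconf import OmegaConf
    _OMEGACONF_AVAILABLE = True
except ImportError:
    _OMEGACONF_AVAILABLE = False


class HparamsHolder:
    def __init__(self, hparams):
        self.hparams = hparams


def test_hparams_holder():
    data = HparamsHolder({"hello": "world"})
    assert data.hparams == {"hello": "world"}

    if _OMEGACONF_AVAILABLE:
        data = HparamsHolder(OmegaConf.create({"hello": "world"}))
        assert data.hparams == OmegaConf.create({"hello": "world"})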

View File

@@ -27,6 +27,7 @@ from pytorch_lightning.utilities import (
_FAIRSCALE_FULLY_SHARDED_AVAILABLE,
_HOROVOD_AVAILABLE,
_IPU_AVAILABLE,
_OMEGACONF_AVAILABLE,
_RICH_AVAILABLE,
_TORCH_QUANTIZE_AVAILABLE,
_TPU_AVAILABLE,
@@ -70,6 +71,7 @@ class RunIf:
deepspeed: bool = False,
rich: bool = False,
skip_49370: bool = False,
omegaconf: bool = False,
**kwargs,
):
"""
@@ -89,9 +91,10 @@ class RunIf:
standalone: Mark the test as standalone, our CI will run it in a separate process.
fairscale: Require that facebookresearch/fairscale is installed.
fairscale_fully_sharded: Require that `fairscale` fully sharded support is available.
deepspeed: Require that Microsoft/DeepSpeed is installed.
deepspeed: Require that microsoft/DeepSpeed is installed.
rich: Require that willmcgugan/rich is installed.
skip_49370: Skip the test as it's impacted by https://github.com/pytorch/pytorch/issues/49370.
omegaconf: Require that omry/omegaconf is installed.
**kwargs: Any :class:`pytest.mark.skipif` keyword arguments.
"""
conditions = []
@@ -177,6 +180,10 @@ class RunIf:
conditions.append(ge_3_9 and old_torch)
reasons.append("Impacted by https://github.com/pytorch/pytorch/issues/49370")
if omegaconf:
conditions.append(not _OMEGACONF_AVAILABLE)
reasons.append("omegaconf")
reasons = [rs for cond, rs in zip(conditions, reasons) if cond]
return pytest.mark.skipif(
*args, condition=any(conditions), reason=f"Requires: [{' + '.join(reasons)}]", **kwargs
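The new omegaconf flag plugs into RunIf's existing aggregation: each requested requirement appends one condition and one reason, and the resulting skipif mark reports only the unmet ones. A reduced, hypothetical sketch of that aggregation (not the full RunIf class):

# Reduced sketch of the conditions/reasons aggregation shown above (hypothetical helper).
import pytest

try:
    import omegaconf  # noqa: F401
    _OMEGACONF_AVAILABLE = True
except ImportError:
    _OMEGACONF_AVAILABLE = False


def run_if(*, omegaconf: bool = False, min_gpus: int = 0):
    conditions, reasons = [], []
    if omegaconf:
        conditions.append(not _OMEGACONF_AVAILABLE)
        reasons.append("omegaconf")
    if min_gpus:
        import torch

        conditions.append(torch.cuda.device_count() < min_gpus)
        reasons.append(f"GPUs>={min_gpus}")
    unmet = [rs for cond, rs in zip(conditions, reasons) if cond]
    return pytest.mark.skipif(condition=any(conditions), reason=f"Requires: [{' + '.join(unmet)}]")


@run_if(omegaconf=True)
def test_needs_omegaconf():
    from omegaconf import OmegaConf

    assert OmegaConf.create({"a": 1}).a == 1

Applying the returned mark directly as a decorator works because pytest.mark.skipif(...) returns a MarkDecorator, which attaches the mark when called on a test function.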

View File

@@ -21,13 +21,16 @@ import numpy as np
import pytest
import torch
import yaml
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.loggers.base import LoggerCollection
from pytorch_lightning.utilities.imports import _compare_version
from pytorch_lightning.utilities.imports import _compare_version, _OMEGACONF_AVAILABLE
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
@pytest.mark.skipif(
@@ -205,6 +208,7 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
logger.log_hyperparams(hparams, metrics)
@RunIf(omegaconf=True)
def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
logger = TensorBoardLogger(tmpdir, default_hp_metric=False)
hparams = {
@@ -214,8 +218,6 @@ def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
"bool": True,
"dict": {"a": {"b": "c"}},
"list": [1, 2, 3],
# "namespace": Namespace(foo=Namespace(bar="buzz")),
# "layer": torch.nn.BatchNorm1d,
}
hparams = OmegaConf.create(hparams)

View File

@@ -24,21 +24,24 @@ import cloudpickle
import pytest
import torch
from fsspec.implementations.local import LocalFileSystem
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities import _HYDRA_EXPERIMENTAL_AVAILABLE, _OMEGACONF_AVAILABLE, AttributeDict, is_picklable
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
if _HYDRA_EXPERIMENTAL_AVAILABLE:
from hydra.experimental import compose, initialize
if _OMEGACONF_AVAILABLE:
from omegaconf import Container, OmegaConf
from omegaconf.dictconfig import DictConfig
class SaveHparamsModel(BoringModel):
"""Tests that a model can take an object."""
@@ -117,6 +120,7 @@ def test_dict_hparams(tmpdir, cls):
_run_standard_hparams_test(tmpdir, model, cls)
@RunIf(omegaconf=True)
@pytest.mark.parametrize("cls", [SaveHparamsModel, SaveHparamsDecoratedModel])
def test_omega_conf_hparams(tmpdir, cls):
# init model
@@ -275,10 +279,18 @@ class UnconventionalArgsBoringModel(CustomBoringModel):
obj.save_hyperparameters()
class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
if _OMEGACONF_AVAILABLE:
class DictConfSubClassBoringModel(SubClassBoringModel):
def __init__(self, *args, dict_conf=OmegaConf.create(dict(my_param="something")), **kwargs):
super().__init__(*args, **kwargs)
self.save_hyperparameters()
else:
class DictConfSubClassBoringModel:
...
@pytest.mark.parametrize(
@@ -290,7 +302,7 @@ class DictConfSubClassBoringModel(SubClassBoringModel):
SubSubClassBoringModel,
AggSubClassBoringModel,
UnconventionalArgsBoringModel,
DictConfSubClassBoringModel,
pytest.param(DictConfSubClassBoringModel, marks=RunIf(omegaconf=True)),
],
)
def test_collect_init_arguments(tmpdir, cls):
@@ -383,31 +395,6 @@ def test_collect_init_arguments_with_local_vars(cls):
assert model.hparams["arg2"] == 2
# @pytest.mark.parametrize("cls,config", [
# (SaveHparamsModel, Namespace(my_arg=42)),
# (SaveHparamsModel, dict(my_arg=42)),
# (SaveHparamsModel, OmegaConf.create(dict(my_arg=42))),
# (AssignHparamsModel, Namespace(my_arg=42)),
# (AssignHparamsModel, dict(my_arg=42)),
# (AssignHparamsModel, OmegaConf.create(dict(my_arg=42))),
# ])
# def test_single_config_models(tmpdir, cls, config):
# """ Test that the model automatically saves the arguments passed into the constructor """
# model = cls(config)
#
# # no matter how you do it, it should be assigned
# assert model.hparams.my_arg == 42
#
# # verify that the checkpoint saved the correct values
# trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, overfit_batches=0.5)
# trainer.fit(model)
#
# # verify that model loads correctly
# raw_checkpoint_path = _raw_checkpoint_path(trainer)
# model = cls.load_from_checkpoint(raw_checkpoint_path)
# assert model.hparams.my_arg == 42
class AnotherArgModel(BoringModel):
def __init__(self, arg1):
super().__init__()
@@ -511,8 +498,9 @@ def test_hparams_save_yaml(tmpdir):
save_hparams_to_yaml(path_yaml, AttributeDict(hparams))
_compare_params(load_hparams_from_yaml(path_yaml, use_omegaconf=False), hparams)
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
if _OMEGACONF_AVAILABLE:
save_hparams_to_yaml(path_yaml, OmegaConf.create(hparams))
_compare_params(load_hparams_from_yaml(path_yaml), hparams)
class NoArgsSubClassBoringModel(CustomBoringModel):
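Defining DictConfSubClassBoringModel only when omegaconf is importable keeps the module collectible without the package; the `...` placeholder lets the parametrize list above still reference the name, and the RunIf(omegaconf=True) mark guarantees the placeholder is never instantiated. A hedged, self-contained illustration of that shape (class names are made up):

# Illustration only; Base / DictConfModel are stand-ins for the real model classes above.
try:
    from omegaconf import OmegaConf
    _OMEGACONF_AVAILABLE = True
except ImportError:
    _OMEGACONF_AVAILABLE = False


class Base:
    def __init__(self, *args, **kwargs):
        self.init_args = args


if _OMEGACONF_AVAILABLE:

    class DictConfModel(Base):
        def __init__(self, *args, dict_conf=OmegaConf.create({"my_param": "something"}), **kwargs):
            super().__init__(*args, **kwargs)
            self.dict_conf = dict_conf

else:

    class DictConfModel:  # placeholder: importable at collection time, never instantiated
        ...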

View File

@@ -26,7 +26,6 @@ from unittest.mock import ANY, call, patch
import cloudpickle
import pytest
import torch
from omegaconf import OmegaConf
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.optim import SGD
from torch.utils.data import DataLoader, IterableDataset
@@ -51,6 +50,7 @@ from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _AcceleratorType, _StrategyType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import DeadlockDetectedException, MisconfigurationException
from pytorch_lightning.utilities.imports import _OMEGACONF_AVAILABLE
from pytorch_lightning.utilities.seed import seed_everything
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel, RandomDataset
@@ -59,6 +59,9 @@ from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
from tests.helpers.simple_models import ClassificationModel
if _OMEGACONF_AVAILABLE:
from omegaconf import OmegaConf
@pytest.mark.parametrize("url_ckpt", [True, False])
def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
@@ -1271,12 +1274,12 @@ def test_trainer_subclassing():
TrainerSubclass(abcdefg="unknown_arg")
@pytest.mark.parametrize(
"trainer_params", [OmegaConf.create(dict(max_epochs=1, gpus=1)), OmegaConf.create(dict(max_epochs=1, gpus=[0]))]
)
@RunIf(min_gpus=1)
def test_trainer_omegaconf(trainer_params):
Trainer(**trainer_params)
@RunIf(omegaconf=True)
@pytest.mark.parametrize("trainer_params", [{"max_epochs": 1, "gpus": 1}, {"max_epochs": 1, "gpus": [0]}])
@mock.patch("torch.cuda.device_count", return_value=1)
def test_trainer_omegaconf(_, trainer_params):
config = OmegaConf.create(trainer_params)
Trainer(**config)
def test_trainer_pickle(tmpdir):
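The rewritten test_trainer_omegaconf above builds the OmegaConf object inside the test and unpacks it with Trainer(**config), which works because DictConfig supports the mapping protocol; the torch.cuda.device_count patch replaces the old RunIf(min_gpus=1) so the GPU-flavoured configs can run on CPU-only CI. A standalone sketch of the unpacking behaviour (the receiver is a made-up stand-in for Trainer):

# Standalone sketch: a DictConfig unpacks into **kwargs like a plain dict.
from omegaconf import OmegaConf


def fake_trainer(max_epochs=None, gpus=None):
    # Stand-in for Trainer(...); it just records what it was given.
    return {"max_epochs": max_epochs, "gpus": gpus}


config = OmegaConf.create({"max_epochs": 1, "gpus": [0]})
received = fake_trainer(**config)
assert received["max_epochs"] == 1
assert list(received["gpus"]) == [0]  # the list arrives as a ListConfig; list() normalizes it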