lightning/tests/loggers/test_base.py

398 lines
13 KiB
Python
Raw Normal View History

2020-10-13 11:18:07 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
from argparse import Namespace
from copy import deepcopy
from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch
Added accumulation of loggers' metrics for the same steps (#1278) * `add_argparse_args` method fixed (argument types added) * autopep8 fixes * --gpus=0 removed from test (for ci tests) * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <joe@huggingface.co> * test_with_accumulate_grad_batches added * agg_and_log_metrics logic added to the base logger class * small format fix * agg metrics strategies removed (not to complicate stuff) * agg metrics: handle zero step * autopep8 * changelog upd * flake fix * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove .item which causes sync issues (#1254) * remove .item which causes sync issues * fixed gradient acc sched * fixed gradient acc sched * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * autopep8 * loggers base.py types fixed * test * test * metrics aggregation for loggers: each key now has a specific function (or default one) * metrics aggregation for loggers: each key now has a specific function (or default one) * docstrings upd * manual typehints removed from docstrings * batch_size decreased for test `test_with_accumulate_grad_batches` * extend running accum * refactor * fix tests * fix tests * allowed_types generator scoped * trainer.py distutils was imported twice, fixed * TensorRunningAccum refactored * TensorRunningAccum added to change log (Changed) * change log pull link added Co-authored-by: Joe Davison <joe@huggingface.co> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: 
William Falcon <waf2107@columbia.edu> Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
2020-04-08 12:35:47 +00:00
import numpy as np
import pytest
import torch
Added accumulation of loggers' metrics for the same steps (#1278) * `add_argparse_args` method fixed (argument types added) * autopep8 fixes * --gpus=0 removed from test (for ci tests) * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <joe@huggingface.co> * test_with_accumulate_grad_batches added * agg_and_log_metrics logic added to the base logger class * small format fix * agg metrics strategies removed (not to complicate stuff) * agg metrics: handle zero step * autopep8 * changelog upd * flake fix * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove .item which causes sync issues (#1254) * remove .item which causes sync issues * fixed gradient acc sched * fixed gradient acc sched * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * autopep8 * loggers base.py types fixed * test * test * metrics aggregation for loggers: each key now has a specific function (or default one) * metrics aggregation for loggers: each key now has a specific function (or default one) * docstrings upd * manual typehints removed from docstrings * batch_size decreased for test `test_with_accumulate_grad_batches` * extend running accum * refactor * fix tests * fix tests * allowed_types generator scoped * trainer.py distutils was imported twice, fixed * TensorRunningAccum refactored * TensorRunningAccum added to change log (Changed) * change log pull link added Co-authored-by: Joe Davison <joe@huggingface.co> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: 
William Falcon <waf2107@columbia.edu> Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
2020-04-08 12:35:47 +00:00
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import LightningLoggerBase, LoggerCollection, TensorBoardLogger
from pytorch_lightning.loggers.base import DummyExperiment, DummyLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _convert_params, _sanitize_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from tests.helpers.boring_model import BoringDataModule, BoringModel
def test_logger_collection():
    """Check that ``LoggerCollection`` exposes its loggers by index and forwards calls to each of them."""
    mock1 = MagicMock()
    mock2 = MagicMock()
    children = (mock1, mock2)

    # Building the collection is expected to emit the v1.6 deprecation warning.
    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        logger = LoggerCollection([mock1, mock2])

    # Indexing the collection and its `experiment` maps onto the wrapped loggers.
    assert logger[0] == mock1
    assert logger[1] == mock2

    assert logger.experiment[0] == mock1.experiment
    assert logger.experiment[1] == mock2.experiment

    assert logger.save_dir is None

    # Each call made on the collection must reach every wrapped logger exactly once.
    logger.update_agg_funcs({"test": np.mean}, np.sum)
    for child in children:
        child.update_agg_funcs.assert_called_once_with({"test": np.mean}, np.sum)

    logger.log_metrics(metrics={"test": 2.0}, step=4)
    for child in children:
        child.log_metrics.assert_called_once_with(metrics={"test": 2.0}, step=4)

    logger.finalize("success")
    for child in children:
        child.finalize.assert_called_once()
def test_logger_collection_unique_names():
    """Two loggers sharing one name should give the collection that same name."""
    unique_name = "name1"
    members = [CustomLogger(name=unique_name), CustomLogger(name=unique_name)]

    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)

    assert collection.name == unique_name
def test_logger_collection_names_order():
    """The collection name should join the distinct logger names in first-seen order."""
    loggers = []
    for logger_name in ("name1", "name2", "name1", "name3"):
        loggers.append(CustomLogger(name=logger_name))

    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(loggers)

    # Index 2 repeats "name1", so it is expected to be skipped in the joined name.
    expected_name = "_".join(loggers[idx].name for idx in (0, 1, 3))
    assert collection.name == expected_name
def test_logger_collection_unique_versions():
    """Two loggers sharing one version should give the collection that same version."""
    unique_version = "1"
    members = [CustomLogger(version=unique_version), CustomLogger(version=unique_version)]

    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(members)

    assert collection.version == unique_version
def test_logger_collection_versions_order():
    """The collection version should join the distinct logger versions in first-seen order."""
    loggers = []
    for logger_version in ("1", "2", "1", "3"):
        loggers.append(CustomLogger(version=logger_version))

    with pytest.deprecated_call(match="`LoggerCollection` is deprecated in v1.6"):
        collection = LoggerCollection(loggers)

    # Index 2 repeats version "1", so it is expected to be skipped in the joined version.
    expected_version = "_".join(loggers[idx].version for idx in (0, 1, 3))
    assert collection.version == expected_version
class CustomLogger(LightningLoggerBase):
    """In-memory logger used by the tests in this module.

    Each hook stores what it was called with on the instance so that a test can inspect it afterwards.

    Args:
        experiment: value returned by the ``experiment`` property.
        name: value returned by the ``name`` property.
        version: value returned by the ``version`` property.
    """

    def __init__(self, experiment: str = "test", name: str = "name", version: str = "1"):
        super().__init__()
        self._experiment = experiment
        self._name = name
        self._version = version
        self.hparams_logged = None
        self.metrics_logged = {}
        self.finalized = False
        # Defined up front so that reading it before `finalize` runs gives `None`
        # rather than raising `AttributeError`.
        self.finalized_status: Optional[str] = None
        self.after_save_checkpoint_called = False

    @property
    def experiment(self):
        """Return the experiment value given to the constructor."""
        return self._experiment

    @rank_zero_only
    def log_hyperparams(self, params):
        """Remember the hyperparameters passed in the most recent call."""
        self.hparams_logged = params

    @rank_zero_only
    def log_metrics(self, metrics, step):
        """Remember the metrics passed in the most recent call; ``step`` is not stored."""
        self.metrics_logged = metrics

    @rank_zero_only
    def finalize(self, status):
        """Remember the status the logger was finalized with."""
        self.finalized_status = status

    @property
    def save_dir(self) -> Optional[str]:
        """Return the root directory where experiment logs get saved, or `None` if the logger does not save data
        locally."""
        return None

    @property
    def name(self):
        """Return the name given to the constructor."""
        return self._name

    @property
    def version(self):
        """Return the version given to the constructor."""
        return self._version

    def after_save_checkpoint(self, checkpoint_callback):
        """Record that the checkpoint hook was invoked; the callback argument is not used."""
        self.after_save_checkpoint_called = True
def test_custom_logger(tmpdir):
    """Fit a model with a single ``CustomLogger`` and check that its hooks were reached."""

    class CustomModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            train_loss = self.loss(batch, prediction)
            self.log("train_loss", train_loss)
            return {"loss": train_loss}

    logger = CustomLogger()
    model = CustomModel()
    trainer = Trainer(
        max_steps=2,
        log_every_n_steps=1,
        logger=logger,
        default_root_dir=tmpdir,
    )
    trainer.fit(model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert logger.metrics_logged != {}
    assert logger.after_save_checkpoint_called
    assert logger.finalized_status == "success"
def test_multiple_loggers(tmpdir):
    """Fit a model with two ``CustomLogger`` instances and check that both received metrics and finalization."""

    class CustomModel(BoringModel):
        def training_step(self, batch, batch_idx):
            prediction = self.layer(batch)
            train_loss = self.loss(batch, prediction)
            self.log("train_loss", train_loss)
            return {"loss": train_loss}

    model = CustomModel()
    logger1 = CustomLogger()
    logger2 = CustomLogger()

    trainer = Trainer(
        max_steps=2,
        log_every_n_steps=1,
        logger=[logger1, logger2],
        default_root_dir=tmpdir,
    )
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # Both loggers are expected to end up in the same state.
    for attached in (logger1, logger2):
        assert attached.hparams_logged is None
        assert attached.metrics_logged != {}
        assert attached.finalized_status == "success"
def test_multiple_loggers_pickle(tmpdir):
    """Verify that pickling trainer with multiple loggers works."""
    trainer = Trainer(logger=[CustomLogger(), CustomLogger()])

    # Round-trip the trainer through pickle (data produced by this test itself).
    trainer2 = pickle.loads(pickle.dumps(trainer))

    # Log through every restored logger first, then verify each one recorded the metrics.
    for restored in trainer2.loggers:
        restored.log_metrics({"acc": 1.0}, 0)

    for restored in trainer2.loggers:
        assert restored.metrics_logged == {"acc": 1.0}
def test_adding_step_key(tmpdir):
    """Check that a ``"step"`` entry logged from the epoch-end hooks is the step the logger receives.

    The model bumps a counter stored on the logger in every epoch-end hook and logs it under ``"step"``;
    the logger asserts that validation metrics arrive with a matching ``step`` argument.
    """

    class CustomTensorBoardLogger(TensorBoardLogger):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)
            # Counter incremented by the model's epoch-end hooks below.
            self.logged_step = 0

        def log_metrics(self, metrics, step):
            # Only the validation metrics are checked against the counter.
            if "val_acc" in metrics:
                assert step == self.logged_step

            super().log_metrics(metrics, step)

    class CustomModel(BoringModel):
        def training_epoch_end(self, outputs):
            self.logger.logged_step += 1
            self.log_dict({"step": self.logger.logged_step, "train_acc": self.logger.logged_step / 10})

        def validation_epoch_end(self, outputs):
            self.logger.logged_step += 1
            self.log_dict({"step": self.logger.logged_step, "val_acc": self.logger.logged_step / 10})

    model = CustomModel()
    trainer = Trainer(
        max_epochs=3,
        logger=CustomTensorBoardLogger(save_dir=tmpdir),
        default_root_dir=tmpdir,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        num_sanity_val_steps=0,
    )
    trainer.fit(model)
Added accumulation of loggers' metrics for the same steps (#1278) * `add_argparse_args` method fixed (argument types added) * autopep8 fixes * --gpus=0 removed from test (for ci tests) * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <joe@huggingface.co> * test_with_accumulate_grad_batches added * agg_and_log_metrics logic added to the base logger class * small format fix * agg metrics strategies removed (not to complicate stuff) * agg metrics: handle zero step * autopep8 * changelog upd * flake fix * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * metrics aggregators factored out, metrics_agg.py added + tests * metrics agg default value added * Update pytorch_lightning/loggers/metrics_agg.py Co-Authored-By: Jirka Borovec <Borda@users.noreply.github.com> * remove .item which causes sync issues (#1254) * remove .item which causes sync issues * fixed gradient acc sched * fixed gradient acc sched * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * test_metrics_agg.py removed (all tested in doctrings), agg metrics refactored * autopep8 * loggers base.py types fixed * test * test * metrics aggregation for loggers: each key now has a specific function (or default one) * metrics aggregation for loggers: each key now has a specific function (or default one) * docstrings upd * manual typehints removed from docstrings * batch_size decreased for test `test_with_accumulate_grad_batches` * extend running accum * refactor * fix tests * fix tests * allowed_types generator scoped * trainer.py distutils was imported twice, fixed * TensorRunningAccum refactored * TensorRunningAccum added to change log (Changed) * change log pull link added Co-authored-by: Joe Davison <joe@huggingface.co> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: 
William Falcon <waf2107@columbia.edu> Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
2020-04-08 12:35:47 +00:00
def test_dummyexperiment_support_indexing():
    """Test that the DummyExperiment can imitate indexing the experiment in a LoggerCollection."""
    dummy = DummyExperiment()
    first_item = dummy[0]
    assert first_item == dummy
def test_dummylogger_support_indexing():
    """Test that the DummyLogger can imitate indexing of a LoggerCollection."""
    dummy = DummyLogger()
    first_item = dummy[0]
    assert first_item == dummy
def test_dummylogger_empty_iterable():
    """Test that DummyLogger represents an empty iterable."""
    dummy = DummyLogger()
    # Fails as soon as iteration yields a single element.
    assert not any(True for _ in dummy)
def test_dummylogger_noop_method_calls():
    """Test that the DummyLogger methods can be called with arbitrary arguments."""
    dummy = DummyLogger()
    # The test passes as long as neither call raises.
    for noop in (dummy.log_hyperparams, dummy.log_metrics):
        noop("1", 2, three="three")
def test_dummyexperiment_support_item_assignment():
    """Test that the DummyExperiment supports item assignment."""
    dummy = DummyExperiment()
    key, value = "variable", "value"
    dummy[key] = value
    # this is only a stateless mock experiment, so the assigned value must not be read back
    assert dummy[key] != value
def test_np_sanitization():
    """Check that NumPy scalar hyperparameters are converted and sanitized into the expected plain values."""

    class CustomParamsLogger(CustomLogger):
        def __init__(self):
            super().__init__()
            self.logged_params = None

        @rank_zero_only
        def log_hyperparams(self, params):
            self.logged_params = _sanitize_params(_convert_params(params))

    # (hyperparameter name, raw NumPy value, value expected after sanitization)
    cases = [
        ("np.bool_", np.bool_(1), True),
        ("np.byte", np.byte(2), 2),
        ("np.intc", np.intc(3), 3),
        ("np.int_", np.int_(4), 4),
        ("np.longlong", np.longlong(5), 5),
        ("np.single", np.single(6.0), 6.0),
        ("np.double", np.double(8.9), 8.9),
        ("np.csingle", np.csingle(7 + 2j), "(7+2j)"),
        ("np.cdouble", np.cdouble(9 + 4j), "(9+4j)"),
    ]
    np_params = {key: raw for key, raw, _ in cases}
    sanitized_params = {key: expected for key, _, expected in cases}

    logger = CustomParamsLogger()
    logger.log_hyperparams(Namespace(**np_params))
    assert logger.logged_params == sanitized_params
@pytest.mark.parametrize("logger", [True, False])
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
def test_log_hyperparams_being_called(log_hyperparams_mock, tmpdir, logger):
    """Check that ``log_hyperparams`` is called only when ``save_hyperparameters`` is given ``logger=True``."""

    class TestModel(BoringModel):
        def __init__(self, param_one, param_two):
            super().__init__()
            self.save_hyperparameters(logger=logger)

    model = TestModel("pytorch", "lightning")
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=0.1,
        limit_val_batches=0.1,
        num_sanity_val_steps=0,
    )
    trainer.fit(model)

    # Choose the mock assertion that matches the parametrized flag.
    if not logger:
        log_hyperparams_mock.assert_not_called()
        return
    log_hyperparams_mock.assert_called()
@patch("pytorch_lightning.loggers.tensorboard.TensorBoardLogger.log_hyperparams")
def test_log_hyperparams_key_collision(log_hyperparams_mock, tmpdir):
    """Exercise hyperparameters saved by both a model and a datamodule under the same keys.

    The first two fits are expected to finish without raising. The last two pass a datamodule whose values
    differ from the model's and expect a ``MisconfigurationException`` mentioning the hparams merge.
    """

    class TestModel(BoringModel):
        def __init__(self, hparams: Dict[str, Any]) -> None:
            super().__init__()
            self.save_hyperparameters(hparams)

    class TestDataModule(BoringDataModule):
        def __init__(self, hparams: Dict[str, Any]) -> None:
            super().__init__()
            self.save_hyperparameters(hparams)

    class _Test:
        ...

    def new_trainer():
        # Every scenario uses the same trainer configuration.
        return Trainer(
            default_root_dir=tmpdir,
            max_epochs=1,
            limit_train_batches=0.1,
            limit_val_batches=0.1,
            num_sanity_val_steps=0,
            enable_checkpointing=False,
            enable_progress_bar=False,
            enable_model_summary=False,
        )

    same_params = {1: 1, "2": 2, "three": 3.0, "test": _Test(), "4": torch.tensor(4)}
    model = TestModel(same_params)
    dm = TestDataModule(same_params)

    trainer = new_trainer()
    # there should be no exceptions raised for the same key/value pair in the hparams of both
    # the lightning module and data module
    # NOTE(review): `dm` is created but not passed to `fit` in this and the next scenario -- confirm this is intended.
    trainer.fit(model)

    # Same keys, but the "test" entry is replaced by a fresh `_Test` instance; the trainer is reused.
    obj_params = deepcopy(same_params)
    obj_params["test"] = _Test()
    model = TestModel(same_params)
    dm = TestDataModule(obj_params)
    trainer.fit(model)

    # A differing value under key `1` (plus a fresh `_Test` instance) must be rejected.
    diff_params = deepcopy(same_params)
    diff_params[1] = 0
    diff_params["test"] = _Test()
    model = TestModel(same_params)
    dm = TestDataModule(diff_params)
    trainer = new_trainer()
    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
        trainer.fit(model, dm)

    # A tensor holding a different value under key "4" must be rejected as well.
    tensor_params = deepcopy(same_params)
    tensor_params["4"] = torch.tensor(3)
    model = TestModel(same_params)
    dm = TestDataModule(tensor_params)
    trainer = new_trainer()
    with pytest.raises(MisconfigurationException, match="Error while merging hparams"):
        trainer.fit(model, dm)