From 5fa32d95e3c28354c360d23dbee85ddd8507e5f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Bj=C3=B6rn=20Barz?= Date: Fri, 7 Jun 2024 00:36:28 +0200 Subject: [PATCH] Ignore parameters causing ValueError when dumping to YAML (#19804) --- src/lightning/pytorch/CHANGELOG.md | 2 ++ src/lightning/pytorch/core/saving.py | 2 +- tests/tests_pytorch/models/test_hparams.py | 10 +++++++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 54ce68c696..2b76b36902 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808)) +- Fixed an issue causing ValueError for certain objects such as TorchMetrics when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804)) + ## [2.2.2] - 2024-04-11 diff --git a/src/lightning/pytorch/core/saving.py b/src/lightning/pytorch/core/saving.py index f8e9c83003..521192f500 100644 --- a/src/lightning/pytorch/core/saving.py +++ b/src/lightning/pytorch/core/saving.py @@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], us try: v = v.name if isinstance(v, Enum) else v yaml.dump(v) - except TypeError: + except (TypeError, ValueError): warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.") hparams[k] = type(v).__name__ else: diff --git a/tests/tests_pytorch/models/test_hparams.py b/tests/tests_pytorch/models/test_hparams.py index 0d7fced3b8..e8a3cf6801 100644 --- a/tests/tests_pytorch/models/test_hparams.py +++ b/tests/tests_pytorch/models/test_hparams.py @@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path): trainer.fit(model) -def 
test_hparams_save_yaml(tmp_path): +def test_save_hparams_to_yaml(tmp_path): class Options(str, Enum): option1name = "option1val" option2name = "option2val" @@ -590,6 +590,14 @@ def test_hparams_save_yaml(tmp_path): _compare_params(load_hparams_from_yaml(path_yaml), hparams) +def test_save_hparams_to_yaml_warning(tmp_path): + """Test that we warn about unserializable parameters that need to be dropped.""" + path_yaml = tmp_path / "hparams.yaml" + hparams = {"torch_type": torch.float32} + with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"): + save_hparams_to_yaml(path_yaml, hparams) + + class NoArgsSubClassBoringModel(CustomBoringModel): def __init__(self): super().__init__()