Ignore parameters causing ValueError when dumping to YAML (#19804)

Björn Barz 2024-06-07 00:36:28 +02:00 committed by GitHub
parent 4f96c83ba0
commit 5fa32d95e3
3 changed files with 12 additions and 2 deletions
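
Background: `yaml.dump` serializes arbitrary Python objects through their pickle protocol, and that path can raise `ValueError` rather than the `TypeError` the code previously caught. A minimal, hypothetical repro (the `Unpicklable` class below is illustrative and not from the patch; TorchMetrics objects fail along a similar line):

```python
import yaml

class Unpicklable:
    """Stand-in (hypothetical) for objects whose pickle protocol
    raises ValueError instead of TypeError."""
    def __reduce_ex__(self, protocol):
        raise ValueError("cannot pickle this object")

try:
    # The default (unsafe) Dumper represents unknown objects via
    # __reduce_ex__, so the ValueError propagates out of yaml.dump.
    yaml.dump(Unpicklable())
except ValueError as err:
    print(f"yaml.dump raised ValueError: {err}")
```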


@@ -60,6 +60,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
+- Fixed an issue causing a `ValueError` for certain objects, such as TorchMetrics, when dumping hyperparameters to YAML ([#19804](https://github.com/Lightning-AI/pytorch-lightning/pull/19804))
+

 ## [2.2.2] - 2024-04-11


@@ -359,7 +359,7 @@ def save_hparams_to_yaml(config_yaml: _PATH, hparams: Union[dict, Namespace], use_omegaconf: bool = True) -> None:
         try:
             v = v.name if isinstance(v, Enum) else v
             yaml.dump(v)
-        except TypeError:
+        except (TypeError, ValueError):
             warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
             hparams[k] = type(v).__name__
         else:
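
For readers outside the diff context, the loop around this hunk behaves roughly as follows. A simplified sketch (the helper name `_sanitize_for_yaml` is hypothetical; the real function's `else:` branch, truncated in the hunk above, writes the probed value back):

```python
from enum import Enum
from warnings import warn

import yaml

def _sanitize_for_yaml(hparams: dict) -> dict:
    """Drop values that cannot be safely dumped to YAML, keeping their type name instead."""
    for k, v in hparams.items():
        try:
            v = v.name if isinstance(v, Enum) else v  # enums are stored by member name
            yaml.dump(v)  # probe: can this value be serialized at all?
        except (TypeError, ValueError):  # ValueError is newly caught by this commit
            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
            hparams[k] = type(v).__name__
        else:
            hparams[k] = v
    return hparams
```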


@@ -552,7 +552,7 @@ def test_hparams_pickle_warning(tmp_path):
     trainer.fit(model)


-def test_hparams_save_yaml(tmp_path):
+def test_save_hparams_to_yaml(tmp_path):
     class Options(str, Enum):
         option1name = "option1val"
         option2name = "option2val"
@@ -590,6 +590,14 @@ def test_hparams_save_yaml(tmp_path):
     _compare_params(load_hparams_from_yaml(path_yaml), hparams)


+def test_save_hparams_to_yaml_warning(tmp_path):
+    """Test that we warn about unserializable parameters that need to be dropped."""
+    path_yaml = tmp_path / "hparams.yaml"
+    hparams = {"torch_type": torch.float32}
+    with pytest.warns(UserWarning, match="Skipping 'torch_type' parameter"):
+        save_hparams_to_yaml(path_yaml, hparams)
+
+
 class NoArgsSubClassBoringModel(CustomBoringModel):
     def __init__(self):
         super().__init__()
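
Taken together, the change turns a hard failure into a warning plus a best-effort placeholder. A usage sketch, assuming the public import path `lightning.pytorch.core.saving`; the placeholder string comes from `type(v).__name__`:

```python
import torch
from lightning.pytorch.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

path = "/tmp/hparams.yaml"  # any existing directory works
save_hparams_to_yaml(path, {"torch_type": torch.float32})  # now warns instead of raising ValueError
print(load_hparams_from_yaml(path))  # e.g. {'torch_type': 'dtype'}: the skipped value's type name
```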