Fix `save_last` type annotation for ModelCheckpoint (#19808)
This commit is contained in:
parent
7668a6bf59
commit
812ffdec84
|
@ -57,7 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
- Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769))
|
||||
|
||||
-
|
||||
|
||||
- Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808))
|
||||
|
||||
|
||||
## [2.2.2] - 2024-04-11
|
||||
|
|
|
@ -27,7 +27,7 @@ import warnings
|
|||
from copy import deepcopy
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Literal, Optional, Set
|
||||
from typing import Any, Dict, Literal, Optional, Set, Union
|
||||
from weakref import proxy
|
||||
|
||||
import torch
|
||||
|
@ -216,7 +216,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
filename: Optional[str] = None,
|
||||
monitor: Optional[str] = None,
|
||||
verbose: bool = False,
|
||||
save_last: Optional[Literal[True, False, "link"]] = None,
|
||||
save_last: Optional[Union[bool, Literal["link"]]] = None,
|
||||
save_top_k: int = 1,
|
||||
save_weights_only: bool = False,
|
||||
mode: str = "min",
|
||||
|
|
|
@ -18,6 +18,7 @@ import re
|
|||
import time
|
||||
from argparse import Namespace
|
||||
from datetime import timedelta
|
||||
from inspect import signature
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from unittest import mock
|
||||
|
@ -28,6 +29,7 @@ import lightning.pytorch as pl
|
|||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
from jsonargparse import ArgumentParser
|
||||
from lightning.fabric.utilities.cloud_io import _load as pl_load
|
||||
from lightning.pytorch import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import ModelCheckpoint
|
||||
|
@ -1601,3 +1603,24 @@ def test_expand_home():
|
|||
# it is possible to have a folder with the name `~`
|
||||
checkpoint = ModelCheckpoint(dirpath="./~/checkpoints")
|
||||
assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("val", "expected"),
|
||||
[
|
||||
("yes", True),
|
||||
("True", True),
|
||||
("true", True),
|
||||
("no", False),
|
||||
("false", False),
|
||||
("False", False),
|
||||
("link", "link"),
|
||||
],
|
||||
)
|
||||
def test_save_last_cli(val, expected):
|
||||
"""Test that the CLI can parse the `save_last` argument correctly (composed type)."""
|
||||
annot = signature(ModelCheckpoint).parameters["save_last"].annotation
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--a", type=annot)
|
||||
args = parser.parse_args(["--a", val])
|
||||
assert args.a == expected
|
||||
|
|
Loading…
Reference in New Issue