From 812ffdec84189bdee560d4b629df01dd2a6bfb53 Mon Sep 17 00:00:00 2001 From: Mario Vasilev <66969704+mariovas3@users.noreply.github.com> Date: Thu, 6 Jun 2024 01:24:45 +0100 Subject: [PATCH] Fix `save_last` type annotation for ModelCheckpoint (#19808) --- src/lightning/pytorch/CHANGELOG.md | 3 ++- .../pytorch/callbacks/model_checkpoint.py | 4 ++-- .../checkpointing/test_model_checkpoint.py | 23 +++++++++++++++++++ 3 files changed, 27 insertions(+), 3 deletions(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index b47c015928..54ce68c696 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -57,7 +57,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `WandbLogger.log_hyperparameters()` raising an error if hyperparameters are not JSON serializable ([#19769](https://github.com/Lightning-AI/pytorch-lightning/pull/19769)) -- + +- Fixed an issue with the LightningCLI not being able to set the `ModelCheckpoint(save_last=...)` argument ([#19808](https://github.com/Lightning-AI/pytorch-lightning/pull/19808)) ## [2.2.2] - 2024-04-11 diff --git a/src/lightning/pytorch/callbacks/model_checkpoint.py b/src/lightning/pytorch/callbacks/model_checkpoint.py index 6c5dd01df1..ba3014274b 100644 --- a/src/lightning/pytorch/callbacks/model_checkpoint.py +++ b/src/lightning/pytorch/callbacks/model_checkpoint.py @@ -27,7 +27,7 @@ import warnings from copy import deepcopy from datetime import timedelta from pathlib import Path -from typing import Any, Dict, Literal, Optional, Set +from typing import Any, Dict, Literal, Optional, Set, Union from weakref import proxy import torch @@ -216,7 +216,7 @@ class ModelCheckpoint(Checkpoint): filename: Optional[str] = None, monitor: Optional[str] = None, verbose: bool = False, - save_last: Optional[Literal[True, False, "link"]] = None, + save_last: Optional[Union[bool, Literal["link"]]] = None, save_top_k: int = 1, save_weights_only: bool = False, mode: str = "min", diff --git a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py index c911885117..006736e086 100644 --- a/tests/tests_pytorch/checkpointing/test_model_checkpoint.py +++ b/tests/tests_pytorch/checkpointing/test_model_checkpoint.py @@ -18,6 +18,7 @@ import re import time from argparse import Namespace from datetime import timedelta +from inspect import signature from pathlib import Path from typing import Union from unittest import mock @@ -28,6 +29,7 @@ import lightning.pytorch as pl import pytest import torch import yaml +from jsonargparse import ArgumentParser from lightning.fabric.utilities.cloud_io import _load as pl_load from lightning.pytorch import Trainer, seed_everything from lightning.pytorch.callbacks import ModelCheckpoint @@ -1601,3 +1603,24 @@ def test_expand_home(): # it is possible to have a folder with the name `~` checkpoint = ModelCheckpoint(dirpath="./~/checkpoints") assert checkpoint.dirpath == str(Path.cwd() / "~" / "checkpoints") + + +@pytest.mark.parametrize( + ("val", "expected"), + [ + ("yes", True), + ("True", True), + ("true", True), + ("no", False), + ("false", False), + ("False", False), + ("link", "link"), + ], +) +def test_save_last_cli(val, expected): + """Test that the CLI can parse the `save_last` argument correctly (composed type).""" + annot = signature(ModelCheckpoint).parameters["save_last"].annotation + parser = ArgumentParser() + parser.add_argument("--a", type=annot) + args = parser.parse_args(["--a", val]) + assert args.a == expected