diff --git a/.gitignore b/.gitignore
index eb56709276..47b9bfff92 100644
--- a/.gitignore
+++ b/.gitignore
@@ -136,7 +136,7 @@ ENV/
 Datasets/
 mnist/
 MNIST/
-legacy/checkpoints/
+tests/legacy/checkpoints/
 *.gz
 *ubyte
 
diff --git a/docs/source-pytorch/common/checkpointing_expert.rst b/docs/source-pytorch/common/checkpointing_expert.rst
index c1859d60ec..c4a948a34c 100644
--- a/docs/source-pytorch/common/checkpointing_expert.rst
+++ b/docs/source-pytorch/common/checkpointing_expert.rst
@@ -6,7 +6,12 @@ Checkpointing (expert)
 ######################
 
-TODO: I don't understand this...
+*********************************
+Writing your own Checkpoint class
+*********************************
+
+We provide a ``Checkpoint`` base class to make subclassing easier. Users writing a custom ``ModelCheckpoint``-style callback should subclass it, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.
+
 
 ***********************
 Customize Checkpointing
 ***********************
@@ -23,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` met
 what's saved in the checkpoint.
 
+TODO: I don't understand this...
+
 ******************************
 Built-in Checkpoint IO Plugins
 ******************************
 
diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md
index 28695785c3..38da5a36a4 100644
--- a/src/pytorch_lightning/CHANGELOG.md
+++ b/src/pytorch_lightning/CHANGELOG.md
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))
 
+- Added `Checkpoint` class for custom checkpoint callbacks to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024))
+
+
 - Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))
 
diff --git a/src/pytorch_lightning/callbacks/__init__.py b/src/pytorch_lightning/callbacks/__init__.py
index 6e37b84ce2..b3d2035f33 100644
--- a/src/pytorch_lightning/callbacks/__init__.py
+++ b/src/pytorch_lightning/callbacks/__init__.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks.checkpoint import Checkpoint
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -32,6 +33,7 @@ __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
     "Callback",
+    "Checkpoint",
     "DeviceStatsMonitor",
     "EarlyStopping",
     "GradientAccumulationScheduler",
diff --git a/src/pytorch_lightning/callbacks/checkpoint.py b/src/pytorch_lightning/callbacks/checkpoint.py
new file mode 100644
index 0000000000..405f29876c
--- /dev/null
+++ b/src/pytorch_lightning/callbacks/checkpoint.py
@@ -0,0 +1,9 @@
+from pytorch_lightning.callbacks.callback import Callback
+
+
+class Checkpoint(Callback):
+    r"""
+    This is the base class for model checkpointing. Expert users may want to subclass it when writing a
+    custom :class:`~pytorch_lightning.callbacks.Checkpoint` callback, so that
+    the trainer recognizes the custom class as a checkpointing callback.
+    """
diff --git a/src/pytorch_lightning/callbacks/fault_tolerance.py b/src/pytorch_lightning/callbacks/fault_tolerance.py
index 59b8d31f46..9d04fc86b6 100644
--- a/src/pytorch_lightning/callbacks/fault_tolerance.py
+++ b/src/pytorch_lightning/callbacks/fault_tolerance.py
@@ -21,11 +21,11 @@ import os
 from typing import Any
 
 import pytorch_lightning as pl
-from pytorch_lightning import Callback
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.types import _PATH
 
 
-class _FaultToleranceCheckpoint(Callback):
+class _FaultToleranceCheckpoint(Checkpoint):
     """Used to save a fault-tolerance checkpoint on exception."""
 
     FILE_EXTENSION = ".ckpt"
diff --git a/src/pytorch_lightning/callbacks/model_checkpoint.py b/src/pytorch_lightning/callbacks/model_checkpoint.py
index 8522bb49b7..bb6d0a9a9b 100644
--- a/src/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/src/pytorch_lightning/callbacks/model_checkpoint.py
@@ -34,7 +34,7 @@ import yaml
 from torch import Tensor
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version
@@ -46,7 +46,7 @@ log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
 
-class ModelCheckpoint(Callback):
+class ModelCheckpoint(Checkpoint):
     r"""
     Save the model periodically by monitoring a quantity. Every metric logged with
     :meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py
index c1eecb93fc..d532aae413 100644
--- a/src/pytorch_lightning/loggers/logger.py
+++ b/src/pytorch_lightning/loggers/logger.py
@@ -25,7 +25,7 @@ from weakref import ReferenceType
 import numpy as np
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
 
 
@@ -86,7 +86,7 @@ class Logger(ABC):
         else:
             self._agg_default_func = np.mean
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         """Called after model checkpoint callback saves a new checkpoint.
 
         Args:
@@ -221,7 +221,7 @@ class LoggerCollection(Logger):
     def __getitem__(self, index: int) -> Logger:
         return list(self._logger_iterable)[index]
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         for logger in self._logger_iterable:
             logger.after_save_checkpoint(checkpoint_callback)
 
diff --git a/src/pytorch_lightning/loggers/neptune.py b/src/pytorch_lightning/loggers/neptune.py
index 4d2f6897a2..44ae3f0f5b 100644
--- a/src/pytorch_lightning/loggers/neptune.py
+++ b/src/pytorch_lightning/loggers/neptune.py
@@ -31,7 +31,7 @@ import torch
 from torch import Tensor
 
 from pytorch_lightning import __version__
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -534,7 +534,7 @@ class NeptuneLogger(Logger):
         )
 
     @rank_zero_only
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
 
         Args:
@@ -547,19 +547,20 @@ class NeptuneLogger(Logger):
         checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
 
         # save last model
-        if checkpoint_callback.last_model_path:
+        if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
             model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
             file_names.add(model_last_name)
             self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
 
         # save best k models
-        for key in checkpoint_callback.best_k_models.keys():
-            model_name = self._get_full_model_name(key, checkpoint_callback)
-            file_names.add(model_name)
-            self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
+        if hasattr(checkpoint_callback, "best_k_models"):
+            for key in checkpoint_callback.best_k_models.keys():
+                model_name = self._get_full_model_name(key, checkpoint_callback)
+                file_names.add(model_name)
+                self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
 
         # log best model path and checkpoint
-        if checkpoint_callback.best_model_path:
+        if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
             self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
 
             model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
@@ -575,19 +576,22 @@ class NeptuneLogger(Logger):
                 del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
 
         # log best model score
-        if checkpoint_callback.best_model_score:
+        if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
             self.run[self._construct_path_with_prefix("model/best_model_score")] = (
                 checkpoint_callback.best_model_score.cpu().detach().numpy()
             )
 
     @staticmethod
-    def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
+    def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str:
         """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
-        expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
-        if not model_path.startswith(expected_model_path):
-            raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
-        # Remove extension from filepath
-        filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
+        if hasattr(checkpoint_callback, "dirpath"):
+            expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
+            if not model_path.startswith(expected_model_path):
+                raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
+            # Remove extension from filepath
+            filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
+        else:
+            filepath = model_path
 
         return filepath
 
diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py
index 53103dfdfd..88439cd943 100644
--- a/src/pytorch_lightning/loggers/wandb.py
+++ b/src/pytorch_lightning/loggers/wandb.py
@@ -23,7 +23,7 @@ from weakref import ReferenceType
 
 import torch.nn as nn
 
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
@@ -461,9 +461,14 @@ class WandbLogger(Logger):
         # don't create an experiment if we don't have one
         return self._experiment.id if self._experiment else self._id
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         # log checkpoints as artifacts
-        if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
+        if (
+            self._log_model == "all"
+            or self._log_model is True
+            and hasattr(checkpoint_callback, "save_top_k")
+            and checkpoint_callback.save_top_k == -1
+        ):
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
             self._checkpoint_callback = checkpoint_callback
@@ -474,25 +479,33 @@ class WandbLogger(Logger):
         if self._checkpoint_callback:
             self._scan_and_log_checkpoints(self._checkpoint_callback)
 
-    def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         # get checkpoints to be saved with associated score
-        checkpoints = {
-            checkpoint_callback.last_model_path: checkpoint_callback.current_score,
-            checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
-            **checkpoint_callback.best_k_models,
-        }
-        checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file())
+        checkpoints = dict()
+        if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"):
+            checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest")
+
+        if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"):
+            checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best")
+
+        if hasattr(checkpoint_callback, "best_k_models"):
+            for key, value in checkpoint_callback.best_k_models.items():
+                checkpoints[key] = (value, "best_k")
+
+        checkpoints = sorted(
+            (Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
+        )
         checkpoints = [
             c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
         ]
 
         # log iteratively all new checkpoints
-        for t, p, s in checkpoints:
+        for t, p, s, tag in checkpoints:
             metadata = (
                 {
                     "score": s,
                     "original_filename": Path(p).name,
-                    "ModelCheckpoint": {
+                    checkpoint_callback.__class__.__name__: {
                         k: getattr(checkpoint_callback, k)
                         for k in [
                             "monitor",
@@ -511,7 +524,6 @@ class WandbLogger(Logger):
             )
             artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
             artifact.add_file(p, name="model.ckpt")
-            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
-            self.experiment.log_artifact(artifact, aliases=aliases)
+            self.experiment.log_artifact(artifact, aliases=[tag])
             # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
             self._logged_model_time[p] = t
diff --git a/src/pytorch_lightning/strategies/launchers/spawn.py b/src/pytorch_lightning/strategies/launchers/spawn.py
index 6af2688e47..d94909b778 100644
--- a/src/pytorch_lightning/strategies/launchers/spawn.py
+++ b/src/pytorch_lightning/strategies/launchers/spawn.py
@@ -109,7 +109,7 @@ class _SpawnLauncher(_Launcher):
 
     def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
         # transfer back the best path to the trainer
-        if trainer.checkpoint_callback:
+        if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
             trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
 
         # TODO: pass also best score
@@ -131,7 +131,11 @@ class _SpawnLauncher(_Launcher):
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_debug("Finalizing the DDP spawn environment.")
         checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
+        best_model_path = (
+            checkpoint_callback.best_model_path
+            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
+            else None
+        )
 
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()
diff --git a/src/pytorch_lightning/strategies/launchers/xla_spawn.py b/src/pytorch_lightning/strategies/launchers/xla_spawn.py
index b3e1bf3465..13c948577c 100644
--- a/src/pytorch_lightning/strategies/launchers/xla_spawn.py
+++ b/src/pytorch_lightning/strategies/launchers/xla_spawn.py
@@ -115,7 +115,11 @@ class _XLASpawnLauncher(_SpawnLauncher):
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_debug("Finalizing the TPU spawn environment.")
         checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
+        best_model_path = (
+            checkpoint_callback.best_model_path
+            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
+            else None
+        )
 
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()
diff --git a/src/pytorch_lightning/trainer/connectors/callback_connector.py b/src/pytorch_lightning/trainer/connectors/callback_connector.py
index eddc2e2a84..83881905be 100644
--- a/src/pytorch_lightning/trainer/connectors/callback_connector.py
+++ b/src/pytorch_lightning/trainer/connectors/callback_connector.py
@@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from pytorch_lightning.callbacks import (
     Callback,
+    Checkpoint,
     GradientAccumulationScheduler,
     ModelCheckpoint,
     ModelSummary,
@@ -232,18 +233,18 @@ class CallbackConnector:
 
     @staticmethod
     def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
-        """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
+        """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
         checkpoint callbacks is preserved, as well as the order of all other callbacks.
 
         Args:
            callbacks: A list of callbacks.
 
         Return:
-            A new list in which the last elements are ModelCheckpoints if there were any present in the
+            A new list in which the last elements are Checkpoint callbacks if there were any present in the
             input.
         """
-        checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
-        not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
+        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
         return not_checkpoints + checkpoints
 
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py
index e823ff7e08..7201ef5350 100644
--- a/src/pytorch_lightning/trainer/trainer.py
+++ b/src/pytorch_lightning/trainer/trainer.py
@@ -44,7 +44,7 @@ from pytorch_lightning.accelerators import (
     MPSAccelerator,
     TPUAccelerator,
 )
-from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -1406,7 +1406,7 @@ class Trainer(
                 f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
             )
 
-        if not self.checkpoint_callback.best_model_path:
+        if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path:
             if self.fast_dev_run:
                 raise MisconfigurationException(
                     f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
@@ -1416,11 +1416,11 @@ class Trainer(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
             # load best weights
-            ckpt_path = self.checkpoint_callback.best_model_path
+            ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None)
 
         if ckpt_path == "last":
-            candidates = [ft.ckpt_path for ft in ft_checkpoints] + [
-                cb.last_model_path for cb in self.checkpoint_callbacks
+            candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
+                getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks
             ]
             candidates_fs = {path: get_filesystem(path) for path in candidates if path}
             candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
@@ -2308,17 +2308,17 @@ class Trainer(
         return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]
 
     @property
-    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
+    def checkpoint_callback(self) -> Optional[Checkpoint]:
         """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
         Trainer.callbacks list, or ``None`` if it doesn't exist."""
         callbacks = self.checkpoint_callbacks
         return callbacks[0] if len(callbacks) > 0 else None
 
     @property
-    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
+    def checkpoint_callbacks(self) -> List[Checkpoint]:
         """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
         in the Trainer.callbacks list."""
-        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
+        return [c for c in self.callbacks if isinstance(c, Checkpoint)]
 
     @property
     def progress_bar_callback(self) -> Optional[ProgressBarBase]:
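
Usage note: below is a minimal sketch of the subclassing pattern this diff enables. It is illustrative only; the `EpochCheckpoint` name, its `dirpath` argument, and its per-epoch saving logic are assumptions, not part of the PR. Because the callback derives from the new `Checkpoint` base class, the `Trainer` recognizes it via `isinstance(c, Checkpoint)`: it appears in `trainer.checkpoint_callbacks`, is reordered to run after all other callbacks, and the `hasattr` guards added to the loggers and spawn launchers above keep them from crashing even though this subclass defines none of `best_model_path`, `last_model_path`, or `best_k_models`.

import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint


class EpochCheckpoint(Checkpoint):
    """Hypothetical checkpoint callback: saves the full training state at the end of every epoch."""

    def __init__(self, dirpath: str) -> None:
        self.dirpath = dirpath

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # Trainer.save_checkpoint dumps model weights, optimizer states and loop state.
        os.makedirs(self.dirpath, exist_ok=True)
        trainer.save_checkpoint(os.path.join(self.dirpath, f"epoch={trainer.current_epoch}.ckpt"))


trainer = pl.Trainer(max_epochs=3, callbacks=[EpochCheckpoint(dirpath="checkpoints/")])
assert isinstance(trainer.checkpoint_callback, EpochCheckpoint)  # recognized as a checkpointing callback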