Add BaseModelCheckpoint class to inherit from (#13024)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
otaj 2022-06-30 12:07:46 +02:00 committed by GitHub
parent a743d96350
commit 663d4c9c28
14 changed files with 99 additions and 53 deletions

.gitignore

@@ -136,7 +136,7 @@ ENV/
Datasets/
mnist/
MNIST/
legacy/checkpoints/
tests/legacy/checkpoints/
*.gz
*ubyte


@@ -6,7 +6,12 @@
Checkpointing (expert)
######################
TODO: I don't understand this...
*********************************
Writing your own Checkpoint class
*********************************
We provide a ``Checkpoint`` base class for easier subclassing. Users writing a custom ``ModelCheckpoint``-style callback may want to subclass it so that the ``Trainer`` recognizes the custom class as a checkpointing callback.
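For example, a minimal subclass might look like the following (a sketch only: the class name and the save-every-N-batches logic are illustrative, not part of the library):

from pytorch_lightning.callbacks import Checkpoint


class PeriodicCheckpoint(Checkpoint):
    """Illustrative callback that saves a checkpoint every ``every_n_batches`` training batches."""

    def __init__(self, every_n_batches: int = 100, dirpath: str = "checkpoints"):
        super().__init__()
        self.every_n_batches = every_n_batches
        self.dirpath = dirpath

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Trainer.save_checkpoint is the standard API for writing a checkpoint to disk.
        if (batch_idx + 1) % self.every_n_batches == 0:
            trainer.save_checkpoint(f"{self.dirpath}/batch-{batch_idx}.ckpt")

Because the ``Trainer`` checks ``isinstance(callback, Checkpoint)``, such a subclass is treated like any other checkpointing callback: it is moved to the end of the callback list and returned by ``trainer.checkpoint_callbacks``.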
***********************
Customize Checkpointing
@@ -23,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` met
what's saved in the checkpoint.
TODO: I don't understand this...
******************************
Built-in Checkpoint IO Plugins
******************************


@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))
- Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024))
- Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks.checkpoint import Checkpoint
from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning
@@ -32,6 +33,7 @@ __all__ = [
"BackboneFinetuning",
"BaseFinetuning",
"Callback",
"Checkpoint",
"DeviceStatsMonitor",
"EarlyStopping",
"GradientAccumulationScheduler",


@@ -0,0 +1,9 @@
from pytorch_lightning.callbacks.callback import Callback
class Checkpoint(Callback):
r"""
This is the base class for model checkpointing. Expert users may want to subclass it when writing a
custom :class:`~pytorch_lightning.callbacks.ModelCheckpoint`-style callback, so that
the trainer recognizes the custom class as a checkpointing callback.
"""


@@ -21,11 +21,11 @@ import os
from typing import Any
import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.types import _PATH
class _FaultToleranceCheckpoint(Callback):
class _FaultToleranceCheckpoint(Checkpoint):
"""Used to save a fault-tolerance checkpoint on exception."""
FILE_EXTENSION = ".ckpt"


@@ -34,7 +34,7 @@ import yaml
from torch import Tensor
import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
@@ -46,7 +46,7 @@ log = logging.getLogger(__name__)
warning_cache = WarningCache()
class ModelCheckpoint(Callback):
class ModelCheckpoint(Checkpoint):
r"""
Save the model periodically by monitoring a quantity. Every metric logged with
:meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
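For context, this callback is typically configured as follows (standard usage; the monitored metric name is illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the three best checkpoints, ranked by the logged "val_loss" metric.
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=3)
trainer = Trainer(callbacks=[checkpoint_cb])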


@@ -25,7 +25,7 @@ from weakref import ReferenceType
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
@@ -86,7 +86,7 @@ class Logger(ABC):
else:
self._agg_default_func = np.mean
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
"""Called after model checkpoint callback saves a new checkpoint.
Args:
@@ -221,7 +221,7 @@ class LoggerCollection(Logger):
def __getitem__(self, index: int) -> Logger:
return list(self._logger_iterable)[index]
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
for logger in self._logger_iterable:
logger.after_save_checkpoint(checkpoint_callback)
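Since the callback passed to ``after_save_checkpoint`` is now only guaranteed to be a ``Checkpoint``, custom loggers should read ``ModelCheckpoint``-specific attributes defensively. A minimal sketch of that pattern (the helper name and the print are illustrative):

def report_best_checkpoint(checkpoint_callback) -> None:
    # A bare Checkpoint does not define `best_model_path`, so guard the access.
    best_path = getattr(checkpoint_callback, "best_model_path", None)
    if best_path:
        print(f"Best checkpoint so far: {best_path}")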


@@ -31,7 +31,7 @@ import torch
from torch import Tensor
from pytorch_lightning import __version__
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params
@@ -534,7 +534,7 @@ class NeptuneLogger(Logger):
)
@rank_zero_only
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
"""Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
Args:
@@ -547,19 +547,20 @@ class NeptuneLogger(Logger):
checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
# save last model
if checkpoint_callback.last_model_path:
if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
file_names.add(model_last_name)
self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
# save best k models
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
if hasattr(checkpoint_callback, "best_k_models"):
for key in checkpoint_callback.best_k_models.keys():
model_name = self._get_full_model_name(key, checkpoint_callback)
file_names.add(model_name)
self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
# log best model path and checkpoint
if checkpoint_callback.best_model_path:
if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)
@@ -575,19 +576,22 @@ class NeptuneLogger(Logger):
del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
# log best model score
if checkpoint_callback.best_model_score:
if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
self.run[self._construct_path_with_prefix("model/best_model_score")] = (
checkpoint_callback.best_model_score.cpu().detach().numpy()
)
@staticmethod
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str:
"""Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
if hasattr(checkpoint_callback, "dirpath"):
expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
if not model_path.startswith(expected_model_path):
raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
# Remove extension from filepath
filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
else:
filepath = model_path
return filepath
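For illustration (hypothetical paths): with checkpoint_callback.dirpath set to "/ckpts", a model_path of "/ckpts/epoch=3-step=100.ckpt" yields the name "epoch=3-step=100"; if the callback exposes no dirpath attribute, the unmodified model_path is returned.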


@@ -23,7 +23,7 @@ from weakref import ReferenceType
import torch.nn as nn
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
@@ -461,9 +461,14 @@ class WandbLogger(Logger):
# don't create an experiment if we don't have one
return self._experiment.id if self._experiment else self._id
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
# log checkpoints as artifacts
if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
if (
self._log_model == "all"
or self._log_model is True
and hasattr(checkpoint_callback, "save_top_k")
and checkpoint_callback.save_top_k == -1
):
self._scan_and_log_checkpoints(checkpoint_callback)
elif self._log_model is True:
self._checkpoint_callback = checkpoint_callback
@@ -474,25 +479,33 @@ class WandbLogger(Logger):
if self._checkpoint_callback:
self._scan_and_log_checkpoints(self._checkpoint_callback)
def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
# get checkpoints to be saved with associated score
checkpoints = {
checkpoint_callback.last_model_path: checkpoint_callback.current_score,
checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
**checkpoint_callback.best_k_models,
}
checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file())
checkpoints = dict()
if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"):
checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest")
if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"):
checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best")
if hasattr(checkpoint_callback, "best_k_models"):
for key, value in checkpoint_callback.best_k_models.items():
checkpoints[key] = (value, "best_k")
checkpoints = sorted(
(Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
)
checkpoints = [
c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
]
# log iteratively all new checkpoints
for t, p, s in checkpoints:
for t, p, s, tag in checkpoints:
metadata = (
{
"score": s,
"original_filename": Path(p).name,
"ModelCheckpoint": {
checkpoint_callback.__class__.__name__: {
k: getattr(checkpoint_callback, k)
for k in [
"monitor",
@@ -511,7 +524,6 @@ class WandbLogger(Logger):
)
artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
artifact.add_file(p, name="model.ckpt")
aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
self.experiment.log_artifact(artifact, aliases=aliases)
self.experiment.log_artifact(artifact, aliases=[tag])
# remember logged models - timestamp needed in case filename didn't change ("last.ckpt" or custom name)
self._logged_model_time[p] = t


@@ -109,7 +109,7 @@ class _SpawnLauncher(_Launcher):
def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
# transfer back the best path to the trainer
if trainer.checkpoint_callback:
if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
# TODO: pass also best score
@@ -131,7 +131,11 @@
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the DDP spawn environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
best_model_path = (
checkpoint_callback.best_model_path
if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
else None
)
# requires to compute the state_dict on all processes in case Metrics are present
state_dict = trainer.lightning_module.state_dict()


@@ -115,7 +115,11 @@ class _XLASpawnLauncher(_SpawnLauncher):
def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
rank_zero_debug("Finalizing the TPU spawn environment.")
checkpoint_callback = trainer.checkpoint_callback
best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
best_model_path = (
checkpoint_callback.best_model_path
if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
else None
)
# requires to compute the state_dict on all processes in case Metrics are present
state_dict = trainer.lightning_module.state_dict()


@@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Sequence, Union
from pytorch_lightning.callbacks import (
Callback,
Checkpoint,
GradientAccumulationScheduler,
ModelCheckpoint,
ModelSummary,
@@ -232,18 +233,18 @@ class CallbackConnector:
@staticmethod
def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
"""Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
"""Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
checkpoint callbacks is preserved, as well as the order of all other callbacks.
Args:
callbacks: A list of callbacks.
Return:
A new list in which the last elements are ModelCheckpoints if there were any present in the
A new list in which the last elements are Checkpoint callbacks if there were any present in the
input.
"""
checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
return not_checkpoints + checkpoints
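As an illustration of the effect (the callback instances are arbitrary examples, and calling the private helper directly is for demonstration only):

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector

callbacks = [ModelCheckpoint(), EarlyStopping(monitor="val_loss"), LearningRateMonitor()]
reordered = CallbackConnector._reorder_callbacks(callbacks)
# Non-checkpoint callbacks keep their relative order and come first; every
# Checkpoint instance (here only the ModelCheckpoint) is moved to the end:
# [EarlyStopping, LearningRateMonitor, ModelCheckpoint]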


@@ -44,7 +44,7 @@ from pytorch_lightning.accelerators import (
MPSAccelerator,
TPUAccelerator,
)
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.optimizer import LightningOptimizer
@@ -1406,7 +1406,7 @@ class Trainer(
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
)
if not self.checkpoint_callback.best_model_path:
if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path:
if self.fast_dev_run:
raise MisconfigurationException(
f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'
@@ -1416,11 +1416,11 @@
f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
)
# load best weights
ckpt_path = self.checkpoint_callback.best_model_path
ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None)
if ckpt_path == "last":
candidates = [ft.ckpt_path for ft in ft_checkpoints] + [
cb.last_model_path for cb in self.checkpoint_callbacks
candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks
]
candidates_fs = {path: get_filesystem(path) for path in candidates if path}
candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}
@@ -2308,17 +2308,17 @@
return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]
@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
def checkpoint_callback(self) -> Optional[Checkpoint]:
"""The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
Trainer.callbacks list, or ``None`` if it doesn't exist."""
callbacks = self.checkpoint_callbacks
return callbacks[0] if len(callbacks) > 0 else None
@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
def checkpoint_callbacks(self) -> List[Checkpoint]:
"""A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
in the Trainer.callbacks list."""
return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
return [c for c in self.callbacks if isinstance(c, Checkpoint)]
@property
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
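Returning to the `checkpoint_callback` and `checkpoint_callbacks` properties above, a brief sketch of what they now allow (the custom class name is made up for illustration):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Checkpoint, ModelCheckpoint


class MyCheckpoint(Checkpoint):
    """Illustrative custom checkpointing callback."""


trainer = Trainer(callbacks=[ModelCheckpoint(), MyCheckpoint()])
# Both callbacks are Checkpoint instances, so both are collected:
assert len(trainer.checkpoint_callbacks) == 2
# `checkpoint_callback` returns the first of them:
assert isinstance(trainer.checkpoint_callback, ModelCheckpoint)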