Add BaseModelCheckpoint class to inherit from (#13024)

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>

commit 663d4c9c28 (parent a743d96350)
@@ -136,7 +136,7 @@ ENV/
 Datasets/
 mnist/
 MNIST/
-legacy/checkpoints/
+tests/legacy/checkpoints/
 *.gz
 *ubyte
@@ -6,7 +6,12 @@
 Checkpointing (expert)
 ######################
 
 TODO: I don't understand this...
 
+*********************************
+Writing your own Checkpoint class
+*********************************
+
+We provide a ``Checkpoint`` class for easier subclassing. Users may want to subclass it when writing a custom ``ModelCheckpoint`` callback, so that the ``Trainer`` recognizes the custom class as a checkpointing callback.
 
 ***********************
 Customize Checkpointing

@@ -23,6 +28,8 @@ and :meth:`~pytorch_lightning.core.hooks.CheckpointHooks.on_load_checkpoint` methods to customize
 what's saved in the checkpoint.
 
 
+TODO: I don't understand this...
+
 ******************************
 Built-in Checkpoint IO Plugins
 ******************************
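A minimal sketch (not part of the diff) of what the new docs section describes: subclassing ``Checkpoint`` so that the ``Trainer`` treats the callback as a checkpointing callback. The ``EpochEndCheckpoint`` name and its ``dirpath`` argument are made up for illustration.

import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Checkpoint


class EpochEndCheckpoint(Checkpoint):
    """Illustrative callback that writes a full checkpoint at the end of every training epoch."""

    def __init__(self, dirpath: str) -> None:
        self.dirpath = dirpath

    def on_train_epoch_end(self, trainer, pl_module) -> None:
        os.makedirs(self.dirpath, exist_ok=True)
        # Let the Trainer assemble and write the checkpoint file.
        trainer.save_checkpoint(os.path.join(self.dirpath, f"epoch-{trainer.current_epoch}.ckpt"))


# Because the class derives from Checkpoint, the Trainer recognizes it as a checkpointing callback.
trainer = Trainer(callbacks=[EpochEndCheckpoint(dirpath="custom_checkpoints")], max_epochs=1)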
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added breaking of lazy graph across training, validation, test and predict steps when training with habana accelerators to ensure better performance ([#12938](https://github.com/PyTorchLightning/pytorch-lightning/pull/12938))
 
 
+- Added `Checkpoint` class to inherit from ([#13024](https://github.com/PyTorchLightning/pytorch-lightning/pull/13024))
+
+
 - Added CPU metric tracking to `DeviceStatsMonitor` ([#11795](https://github.com/PyTorchLightning/pytorch-lightning/pull/11795))
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks.checkpoint import Checkpoint
 from pytorch_lightning.callbacks.device_stats_monitor import DeviceStatsMonitor
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
 from pytorch_lightning.callbacks.finetuning import BackboneFinetuning, BaseFinetuning

@@ -32,6 +33,7 @@ __all__ = [
     "BackboneFinetuning",
     "BaseFinetuning",
     "Callback",
+    "Checkpoint",
     "DeviceStatsMonitor",
     "EarlyStopping",
     "GradientAccumulationScheduler",
@@ -0,0 +1,9 @@
+from pytorch_lightning.callbacks.callback import Callback
+
+
+class Checkpoint(Callback):
+    r"""
+    This is the base class for model checkpointing. Expert users may want to subclass it when writing a
+    custom :class:`~pytorch_lightning.callbacks.ModelCheckpoint` callback, so that
+    the trainer recognizes the custom class as a checkpointing callback.
+    """
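Note (not part of the diff): a plain ``Checkpoint`` subclass only inherits the ``Callback`` interface, so it is not required to expose ``ModelCheckpoint`` attributes such as ``best_model_path`` or ``save_top_k``; that is why the logger and trainer changes further down guard those accesses with ``hasattr``/``getattr``. A quick illustration, with ``MyCheckpoint`` as a made-up subclass:

from pytorch_lightning.callbacks import Checkpoint, ModelCheckpoint


class MyCheckpoint(Checkpoint):
    """A bare subclass: recognized as a checkpointing callback, but it tracks no 'best' model."""


assert isinstance(ModelCheckpoint(), Checkpoint)        # the existing callback now derives from the base class
assert not hasattr(MyCheckpoint(), "best_model_path")   # so downstream code cannot assume this attribute exists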
@@ -21,11 +21,11 @@ import os
 from typing import Any
 
 import pytorch_lightning as pl
-from pytorch_lightning import Callback
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.types import _PATH
 
 
-class _FaultToleranceCheckpoint(Callback):
+class _FaultToleranceCheckpoint(Checkpoint):
     """Used to save a fault-tolerance checkpoint on exception."""
 
     FILE_EXTENSION = ".ckpt"
@@ -34,7 +34,7 @@ import yaml
 from torch import Tensor
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks.callback import Callback
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import _name, _version

@@ -46,7 +46,7 @@ log = logging.getLogger(__name__)
 warning_cache = WarningCache()
 
 
-class ModelCheckpoint(Callback):
+class ModelCheckpoint(Checkpoint):
     r"""
     Save the model periodically by monitoring a quantity. Every metric logged with
     :meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
@@ -25,7 +25,7 @@ from weakref import ReferenceType
 import numpy as np
 
 import pytorch_lightning as pl
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
 
 

@@ -86,7 +86,7 @@ class Logger(ABC):
         else:
             self._agg_default_func = np.mean
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         """Called after model checkpoint callback saves a new checkpoint.
 
         Args:

@@ -221,7 +221,7 @@ class LoggerCollection(Logger):
     def __getitem__(self, index: int) -> Logger:
         return list(self._logger_iterable)[index]
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         for logger in self._logger_iterable:
             logger.after_save_checkpoint(checkpoint_callback)
 
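A hedged sketch of how a third-party logger can use the widened hook signature; ``VerboseCSVLogger`` is a hypothetical name, and the ``getattr`` guard mirrors the pattern the built-in loggers adopt below.

from weakref import ReferenceType

from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.loggers import CSVLogger


class VerboseCSVLogger(CSVLogger):
    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
        # A plain Checkpoint subclass may not track a best model, so guard the attribute access.
        best = getattr(checkpoint_callback, "best_model_path", None)
        if best:
            print(f"Best checkpoint so far: {best}")

Usage would then look like ``Trainer(logger=VerboseCSVLogger(save_dir="logs"))``.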
@@ -31,7 +31,7 @@ import torch
 from torch import Tensor
 
 from pytorch_lightning import __version__
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.imports import _NEPTUNE_AVAILABLE, _NEPTUNE_GREATER_EQUAL_0_9
 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _sanitize_callable_params

@@ -534,7 +534,7 @@ class NeptuneLogger(Logger):
         )
 
     @rank_zero_only
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         """Automatically log checkpointed model. Called after model checkpoint callback saves a new checkpoint.
 
         Args:

@@ -547,19 +547,20 @@ class NeptuneLogger(Logger):
         checkpoints_namespace = self._construct_path_with_prefix("model/checkpoints")
 
         # save last model
-        if checkpoint_callback.last_model_path:
+        if hasattr(checkpoint_callback, "last_model_path") and checkpoint_callback.last_model_path:
             model_last_name = self._get_full_model_name(checkpoint_callback.last_model_path, checkpoint_callback)
             file_names.add(model_last_name)
             self.run[f"{checkpoints_namespace}/{model_last_name}"].upload(checkpoint_callback.last_model_path)
 
         # save best k models
-        for key in checkpoint_callback.best_k_models.keys():
-            model_name = self._get_full_model_name(key, checkpoint_callback)
-            file_names.add(model_name)
-            self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
+        if hasattr(checkpoint_callback, "best_k_models"):
+            for key in checkpoint_callback.best_k_models.keys():
+                model_name = self._get_full_model_name(key, checkpoint_callback)
+                file_names.add(model_name)
+                self.run[f"{checkpoints_namespace}/{model_name}"].upload(key)
 
         # log best model path and checkpoint
-        if checkpoint_callback.best_model_path:
+        if hasattr(checkpoint_callback, "best_model_path") and checkpoint_callback.best_model_path:
             self.run[self._construct_path_with_prefix("model/best_model_path")] = checkpoint_callback.best_model_path
 
             model_name = self._get_full_model_name(checkpoint_callback.best_model_path, checkpoint_callback)

@@ -575,19 +576,22 @@ class NeptuneLogger(Logger):
                 del self.run[f"{checkpoints_namespace}/{file_to_drop}"]
 
         # log best model score
-        if checkpoint_callback.best_model_score:
+        if hasattr(checkpoint_callback, "best_model_score") and checkpoint_callback.best_model_score:
             self.run[self._construct_path_with_prefix("model/best_model_score")] = (
                 checkpoint_callback.best_model_score.cpu().detach().numpy()
             )
 
     @staticmethod
-    def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> str:
+    def _get_full_model_name(model_path: str, checkpoint_callback: "ReferenceType[Checkpoint]") -> str:
         """Returns model name which is string `model_path` appended to `checkpoint_callback.dirpath`."""
-        expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
-        if not model_path.startswith(expected_model_path):
-            raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
-        # Remove extension from filepath
-        filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
+        if hasattr(checkpoint_callback, "dirpath"):
+            expected_model_path = f"{checkpoint_callback.dirpath}{os.path.sep}"
+            if not model_path.startswith(expected_model_path):
+                raise ValueError(f"{model_path} was expected to start with {expected_model_path}.")
+            # Remove extension from filepath
+            filepath, _ = os.path.splitext(model_path[len(expected_model_path) :])
+        else:
+            filepath = model_path
 
         return filepath
 
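Illustrative behaviour of the updated ``_get_full_model_name`` (not from the diff; ``_DummyCheckpoint`` is a made-up stand-in and POSIX paths are assumed): when the callback exposes ``dirpath``, the directory prefix and the ``.ckpt`` extension are stripped; otherwise the path is returned unchanged.

from pytorch_lightning.loggers.neptune import NeptuneLogger


class _DummyCheckpoint:
    dirpath = "/tmp/ckpts"


# Prefix and extension are stripped when `dirpath` is available ...
assert NeptuneLogger._get_full_model_name("/tmp/ckpts/epoch=3.ckpt", _DummyCheckpoint()) == "epoch=3"

# ... and the raw path is used for callbacks that do not define `dirpath`.
assert NeptuneLogger._get_full_model_name("some/other/path.ckpt", object()) == "some/other/path.ckpt"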
@@ -23,7 +23,7 @@ from weakref import ReferenceType
 
 import torch.nn as nn
 
-from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
+from pytorch_lightning.callbacks import Checkpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10

@@ -461,9 +461,14 @@ class WandbLogger(Logger):
         # don't create an experiment if we don't have one
         return self._experiment.id if self._experiment else self._id
 
-    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         # log checkpoints as artifacts
-        if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
+        if (
+            self._log_model == "all"
+            or self._log_model is True
+            and hasattr(checkpoint_callback, "save_top_k")
+            and checkpoint_callback.save_top_k == -1
+        ):
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
             self._checkpoint_callback = checkpoint_callback

@@ -474,25 +479,33 @@ class WandbLogger(Logger):
         if self._checkpoint_callback:
             self._scan_and_log_checkpoints(self._checkpoint_callback)
 
-    def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
+    def _scan_and_log_checkpoints(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
         # get checkpoints to be saved with associated score
-        checkpoints = {
-            checkpoint_callback.last_model_path: checkpoint_callback.current_score,
-            checkpoint_callback.best_model_path: checkpoint_callback.best_model_score,
-            **checkpoint_callback.best_k_models,
-        }
-        checkpoints = sorted((Path(p).stat().st_mtime, p, s) for p, s in checkpoints.items() if Path(p).is_file())
+        checkpoints = dict()
+        if hasattr(checkpoint_callback, "last_model_path") and hasattr(checkpoint_callback, "current_score"):
+            checkpoints[checkpoint_callback.last_model_path] = (checkpoint_callback.current_score, "latest")
+
+        if hasattr(checkpoint_callback, "best_model_path") and hasattr(checkpoint_callback, "best_model_score"):
+            checkpoints[checkpoint_callback.best_model_path] = (checkpoint_callback.best_model_score, "best")
+
+        if hasattr(checkpoint_callback, "best_k_models"):
+            for key, value in checkpoint_callback.best_k_models.items():
+                checkpoints[key] = (value, "best_k")
+
+        checkpoints = sorted(
+            (Path(p).stat().st_mtime, p, s, tag) for p, (s, tag) in checkpoints.items() if Path(p).is_file()
+        )
         checkpoints = [
            c for c in checkpoints if c[1] not in self._logged_model_time.keys() or self._logged_model_time[c[1]] < c[0]
         ]
 
         # log iteratively all new checkpoints
-        for t, p, s in checkpoints:
+        for t, p, s, tag in checkpoints:
             metadata = (
                 {
                     "score": s,
                     "original_filename": Path(p).name,
-                    "ModelCheckpoint": {
+                    checkpoint_callback.__class__.__name__: {
                         k: getattr(checkpoint_callback, k)
                         for k in [
                             "monitor",

@@ -511,7 +524,6 @@ class WandbLogger(Logger):
             )
             artifact = wandb.Artifact(name=f"model-{self.experiment.id}", type="model", metadata=metadata)
             artifact.add_file(p, name="model.ckpt")
-            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
-            self.experiment.log_artifact(artifact, aliases=aliases)
+            self.experiment.log_artifact(artifact, aliases=[tag])
             # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
             self._logged_model_time[p] = t
@@ -109,7 +109,7 @@ class _SpawnLauncher(_Launcher):
 
     def _recover_results_in_main_process(self, spawn_output: "_SpawnOutput", trainer: "pl.Trainer") -> None:
         # transfer back the best path to the trainer
-        if trainer.checkpoint_callback:
+        if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
             trainer.checkpoint_callback.best_model_path = str(spawn_output.best_model_path)
 
         # TODO: pass also best score

@@ -131,7 +131,11 @@ class _SpawnLauncher(_Launcher):
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_debug("Finalizing the DDP spawn environment.")
         checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
+        best_model_path = (
+            checkpoint_callback.best_model_path
+            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
+            else None
+        )
 
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()
@@ -115,7 +115,11 @@ class _XLASpawnLauncher(_SpawnLauncher):
     def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_SpawnOutput"]:
         rank_zero_debug("Finalizing the TPU spawn environment.")
         checkpoint_callback = trainer.checkpoint_callback
-        best_model_path = checkpoint_callback.best_model_path if checkpoint_callback else None
+        best_model_path = (
+            checkpoint_callback.best_model_path
+            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
+            else None
+        )
 
         # requires to compute the state_dict on all processes in case Metrics are present
         state_dict = trainer.lightning_module.state_dict()
@@ -19,6 +19,7 @@ from typing import Dict, List, Optional, Sequence, Union
 
 from pytorch_lightning.callbacks import (
     Callback,
+    Checkpoint,
     GradientAccumulationScheduler,
     ModelCheckpoint,
     ModelSummary,

@@ -232,18 +233,18 @@ class CallbackConnector:
 
     @staticmethod
     def _reorder_callbacks(callbacks: List[Callback]) -> List[Callback]:
-        """Moves all ModelCheckpoint callbacks to the end of the list. The sequential order within the group of
+        """Moves all Checkpoint callbacks to the end of the list. The sequential order within the group of
         checkpoint callbacks is preserved, as well as the order of all other callbacks.
 
         Args:
             callbacks: A list of callbacks.
 
         Return:
-            A new list in which the last elements are ModelCheckpoints if there were any present in the
+            A new list in which the last elements are Checkpoint if there were any present in the
             input.
         """
-        checkpoints = [c for c in callbacks if isinstance(c, ModelCheckpoint)]
-        not_checkpoints = [c for c in callbacks if not isinstance(c, ModelCheckpoint)]
+        checkpoints = [c for c in callbacks if isinstance(c, Checkpoint)]
+        not_checkpoints = [c for c in callbacks if not isinstance(c, Checkpoint)]
        return not_checkpoints + checkpoints
 
 
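An illustrative check of the reordering contract described in the docstring above (calling the private staticmethod directly, purely for demonstration): checkpoint callbacks, including custom ``Checkpoint`` subclasses, are moved to the end of the list while relative order is preserved.

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector

ckpt_a = ModelCheckpoint()
ckpt_b = ModelCheckpoint(monitor="val_loss")
early = EarlyStopping(monitor="val_loss")

# Non-checkpoint callbacks come first; the two checkpoints keep their original relative order.
reordered = CallbackConnector._reorder_callbacks([ckpt_a, early, ckpt_b])
assert reordered == [early, ckpt_a, ckpt_b]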
@@ -44,7 +44,7 @@ from pytorch_lightning.accelerators import (
     MPSAccelerator,
     TPUAccelerator,
 )
-from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint, ProgressBarBase
+from pytorch_lightning.callbacks import Callback, Checkpoint, EarlyStopping, ProgressBarBase
 from pytorch_lightning.callbacks.prediction_writer import BasePredictionWriter
 from pytorch_lightning.core.datamodule import LightningDataModule
 from pytorch_lightning.core.optimizer import LightningOptimizer

@@ -1406,7 +1406,7 @@ class Trainer(
                     f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
                 )
 
-            if not self.checkpoint_callback.best_model_path:
+            if hasattr(self.checkpoint_callback, "best_model_path") and not self.checkpoint_callback.best_model_path:
                 if self.fast_dev_run:
                     raise MisconfigurationException(
                         f'You cannot execute `.{fn}(ckpt_path="best")` with `fast_dev_run=True`.'

@@ -1416,11 +1416,11 @@ class Trainer(
                         f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                     )
             # load best weights
-            ckpt_path = self.checkpoint_callback.best_model_path
+            ckpt_path = getattr(self.checkpoint_callback, "best_model_path", None)
 
         if ckpt_path == "last":
-            candidates = [ft.ckpt_path for ft in ft_checkpoints] + [
-                cb.last_model_path for cb in self.checkpoint_callbacks
+            candidates = [getattr(ft, "ckpt_path", None) for ft in ft_checkpoints] + [
+                getattr(cb, "last_model_path", None) for cb in self.checkpoint_callbacks
             ]
             candidates_fs = {path: get_filesystem(path) for path in candidates if path}
             candidates_ts = {path: fs.modified(path) for path, fs in candidates_fs.items() if fs.exists(path)}

@@ -2308,17 +2308,17 @@ class Trainer(
         return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]
 
     @property
-    def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
+    def checkpoint_callback(self) -> Optional[Checkpoint]:
         """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
         Trainer.callbacks list, or ``None`` if it doesn't exist."""
         callbacks = self.checkpoint_callbacks
         return callbacks[0] if len(callbacks) > 0 else None
 
     @property
-    def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
+    def checkpoint_callbacks(self) -> List[Checkpoint]:
         """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
         in the Trainer.callbacks list."""
-        return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]
+        return [c for c in self.callbacks if isinstance(c, Checkpoint)]
 
     @property
     def progress_bar_callback(self) -> Optional[ProgressBarBase]:
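A small usage sketch for the widened properties (not part of the diff); ``MyCheckpoint`` is a made-up subclass. With a custom ``Checkpoint`` in the callbacks list, the Trainer now reports it through ``checkpoint_callback`` and ``checkpoint_callbacks``.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Checkpoint


class MyCheckpoint(Checkpoint):
    """Recognized by the Trainer as a checkpointing callback even though it is not a ModelCheckpoint."""


trainer = Trainer(callbacks=[MyCheckpoint()])
assert isinstance(trainer.checkpoint_callback, MyCheckpoint)
assert all(isinstance(cb, Checkpoint) for cb in trainer.checkpoint_callbacks)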