Fix mypy typing errors in pytorch_lightning/callbacks/model_checkpoint.py (#13617)
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
parent
ca1917ec80
commit
b40766c333
|
@ -47,7 +47,6 @@ warn_no_return = "False"
|
|||
# the list can be generated with:
|
||||
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
|
||||
module = [
|
||||
"pytorch_lightning.callbacks.model_checkpoint",
|
||||
"pytorch_lightning.callbacks.progress.rich_progress",
|
||||
"pytorch_lightning.callbacks.quantization",
|
||||
"pytorch_lightning.callbacks.stochastic_weight_avg",
|
||||
|
|
|
@ -135,7 +135,7 @@ class EarlyStopping(Callback):
|
|||
# validation, then we run after validation instead of on train epoch end
|
||||
self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
|
||||
|
||||
def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
|
||||
def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
|
||||
monitor_val = logs.get(self.monitor)
|
||||
|
||||
error_msg = (
|
||||
|
|
|
@ -39,7 +39,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem
|
|||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.logger import _name, _version
|
||||
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
|
||||
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
|
||||
from pytorch_lightning.utilities.warnings import WarningCache
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
@ -231,13 +231,14 @@ class ModelCheckpoint(Checkpoint):
|
|||
self._save_on_train_epoch_end = save_on_train_epoch_end
|
||||
self._last_global_step_saved = 0 # no need to save when no steps were taken
|
||||
self._last_time_checked: Optional[float] = None
|
||||
self.current_score = None
|
||||
self.best_k_models = {}
|
||||
self.current_score: Optional[Tensor] = None
|
||||
self.best_k_models: Dict[str, Tensor] = {}
|
||||
self.kth_best_model_path = ""
|
||||
self.best_model_score = None
|
||||
self.best_model_score: Optional[Tensor] = None
|
||||
self.best_model_path = ""
|
||||
self.last_model_path = ""
|
||||
|
||||
self.kth_value: Tensor
|
||||
self.__init_monitor_mode(mode)
|
||||
self.__init_ckpt_dir(dirpath, filename)
|
||||
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
|
||||
|
@ -256,6 +257,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
|
||||
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
|
||||
self.__resolve_ckpt_dir(trainer)
|
||||
assert self.dirpath is not None
|
||||
if trainer.is_global_zero and stage == "fit":
|
||||
self.__warn_if_dir_not_empty(self.dirpath)
|
||||
|
||||
|
@ -362,7 +364,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
self._save_topk_checkpoint(trainer, monitor_candidates)
|
||||
self._save_last_checkpoint(trainer, monitor_candidates)
|
||||
|
||||
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
|
||||
if self.save_top_k == 0:
|
||||
return
|
||||
|
||||
|
@ -395,7 +397,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
from pytorch_lightning.trainer.states import TrainerFn
|
||||
|
||||
return (
|
||||
trainer.fast_dev_run # disable checkpointing with fast_dev_run
|
||||
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
|
||||
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
|
||||
or trainer.sanity_checking # don't save anything during sanity check
|
||||
or self._last_global_step_saved == trainer.global_step # already saved at the last step
|
||||
|
@ -493,7 +495,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
|
||||
|
||||
# If using multiple devices, make sure all processes are unanimous on the decision.
|
||||
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
|
||||
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
|
||||
|
||||
return should_update_best_and_save
|
||||
|
||||
|
@ -501,7 +503,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
def _format_checkpoint_name(
|
||||
cls,
|
||||
filename: Optional[str],
|
||||
metrics: Dict[str, _METRIC],
|
||||
metrics: Dict[str, Tensor],
|
||||
prefix: str = "",
|
||||
auto_insert_metric_name: bool = True,
|
||||
) -> str:
|
||||
|
@ -522,7 +524,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
filename = filename.replace(group, f"{{0[{name}]")
|
||||
|
||||
if name not in metrics:
|
||||
metrics[name] = 0
|
||||
metrics[name] = torch.tensor(0)
|
||||
filename = filename.format(metrics)
|
||||
|
||||
if prefix:
|
||||
|
@ -531,7 +533,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
return filename
|
||||
|
||||
def format_checkpoint_name(
|
||||
self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
|
||||
self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
|
||||
) -> str:
|
||||
"""Generate a filename according to the defined template.
|
||||
|
||||
|
@ -591,6 +593,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
|
||||
elif trainer.loggers:
|
||||
if len(trainer.loggers) == 1:
|
||||
assert trainer.logger is not None
|
||||
save_dir = trainer.logger.save_dir or trainer.default_root_dir
|
||||
else:
|
||||
save_dir = trainer.default_root_dir
|
||||
|
@ -613,7 +616,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
|
||||
|
||||
def _get_metric_interpolated_filepath_name(
|
||||
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
|
||||
self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
|
||||
) -> str:
|
||||
filepath = self.format_checkpoint_name(monitor_candidates)
|
||||
|
||||
|
@ -624,7 +627,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
|
||||
return filepath
|
||||
|
||||
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
|
||||
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
|
||||
monitor_candidates = deepcopy(trainer.callback_metrics)
|
||||
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
|
||||
# or does not exist we overwrite it as it's likely an error
|
||||
|
@ -634,7 +637,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
|
||||
return monitor_candidates
|
||||
|
||||
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
|
||||
if not self.save_last:
|
||||
return
|
||||
|
||||
|
@ -651,16 +654,18 @@ class ModelCheckpoint(Checkpoint):
|
|||
if previous and previous != filepath:
|
||||
trainer.strategy.remove_checkpoint(previous)
|
||||
|
||||
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
|
||||
assert self.monitor
|
||||
current = monitor_candidates.get(self.monitor)
|
||||
if self.check_monitor_top_k(trainer, current):
|
||||
assert current is not None
|
||||
self._update_best_and_save(current, trainer, monitor_candidates)
|
||||
elif self.verbose:
|
||||
epoch = monitor_candidates["epoch"]
|
||||
step = monitor_candidates["step"]
|
||||
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
|
||||
|
||||
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
|
||||
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
|
||||
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
|
||||
# set the best model path before saving because it will be part of the state.
|
||||
previous, self.best_model_path = self.best_model_path, filepath
|
||||
|
@ -669,7 +674,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
trainer.strategy.remove_checkpoint(previous)
|
||||
|
||||
def _update_best_and_save(
|
||||
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
|
||||
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
|
||||
) -> None:
|
||||
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
|
||||
|
||||
|
@ -691,11 +696,11 @@ class ModelCheckpoint(Checkpoint):
|
|||
if len(self.best_k_models) == k:
|
||||
# monitor dict has reached k elements
|
||||
_op = max if self.mode == "min" else min
|
||||
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
|
||||
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
|
||||
self.kth_value = self.best_k_models[self.kth_best_model_path]
|
||||
|
||||
_op = min if self.mode == "min" else max
|
||||
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
|
||||
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
|
||||
self.best_model_score = self.best_k_models[self.best_model_path]
|
||||
|
||||
if self.verbose:
|
||||
|
@ -715,6 +720,7 @@ class ModelCheckpoint(Checkpoint):
|
|||
file."""
|
||||
best_k = {k: v.item() for k, v in self.best_k_models.items()}
|
||||
if filepath is None:
|
||||
assert self.dirpath
|
||||
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
|
||||
with self._fs.open(filepath, "w") as fp:
|
||||
yaml.dump(best_k, fp)
|
||||
|
|
|
@ -532,7 +532,7 @@ class LightningModule(
|
|||
return torch.tensor(value, device=self.device)
|
||||
|
||||
@staticmethod
|
||||
def __check_numel_1(value: torch.Tensor, name: str) -> None:
|
||||
def __check_numel_1(value: Tensor, name: str) -> None:
|
||||
if not torch.numel(value) == 1:
|
||||
raise ValueError(
|
||||
f"`self.log({name}, {value})` was called, but the tensor must have a single element."
|
||||
|
|
|
@ -285,7 +285,7 @@ class Strategy(ABC):
|
|||
"""
|
||||
|
||||
def reduce_boolean_decision(self, decision: bool) -> bool:
|
||||
"""Reduce the early stopping decision across all processes."""
|
||||
"""Reduce a boolean decision across all processes."""
|
||||
return decision
|
||||
|
||||
def pre_backward(self, closure_loss: Tensor) -> None:
|
||||
|
|
|
@ -169,19 +169,13 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
|
|||
obj = torch.load(buffer)
|
||||
return obj
|
||||
|
||||
def reduce_boolean_decision(self, decision: bool) -> bool:
|
||||
decision = torch.tensor(int(decision), device=self.root_device)
|
||||
decision = self.reduce(decision, reduce_op="sum")
|
||||
decision = bool(decision == self.world_size)
|
||||
return decision
|
||||
|
||||
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
|
||||
if not isinstance(output, Tensor):
|
||||
output = torch.tensor(output, device=self.root_device)
|
||||
|
||||
_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
|
||||
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
|
||||
if _invalid_reduce_op or _invalid_reduce_op_str:
|
||||
invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
|
||||
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
|
||||
if invalid_reduce_op or invalid_reduce_op_str:
|
||||
raise MisconfigurationException(
|
||||
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
|
||||
)
|
||||
|
|
|
@ -529,7 +529,7 @@ class _ResultCollection(dict):
|
|||
result_metric.meta.sync.should = should
|
||||
cache = result_metric._computed
|
||||
if cache is not None:
|
||||
if not isinstance(cache, torch.Tensor):
|
||||
if not isinstance(cache, Tensor):
|
||||
raise ValueError(
|
||||
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
|
||||
f" Found {cache}"
|
||||
|
|
|
@ -2705,7 +2705,9 @@ class Trainer(
|
|||
self._loggers = loggers if loggers else []
|
||||
|
||||
@property
|
||||
def callback_metrics(self) -> dict:
|
||||
def callback_metrics(self) -> Dict[str, Tensor]:
|
||||
# TODO: the true typing return can include dictionaries as defined in
|
||||
# `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
|
||||
return self._logger_connector.callback_metrics
|
||||
|
||||
@property
|
||||
|
|
|
@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
|
|||
return gathered_result
|
||||
|
||||
|
||||
def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
|
||||
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
|
||||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
|
||||
torch.distributed.all_gather(gathered_result, result, group)
|
||||
return gathered_result
|
||||
|
|
Loading…
Reference in New Issue