Fix mypy typing errors in pytorch_lightning/callbacks/model_checkpoint.py ()

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
This commit is contained in:
Lee Jungwon 2022-07-21 02:07:38 +09:00 committed by GitHub
parent ca1917ec80
commit b40766c333
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 35 additions and 34 deletions
pyproject.toml
src/pytorch_lightning

View File

@ -47,7 +47,6 @@ warn_no_return = "False"
# the list can be generated with:
# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",'
module = [
"pytorch_lightning.callbacks.model_checkpoint",
"pytorch_lightning.callbacks.progress.rich_progress",
"pytorch_lightning.callbacks.quantization",
"pytorch_lightning.callbacks.stochastic_weight_avg",

View File

@ -135,7 +135,7 @@ class EarlyStopping(Callback):
# validation, then we run after validation instead of on train epoch end
self._check_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
def _validate_condition_metric(self, logs: Dict[str, float]) -> bool:
def _validate_condition_metric(self, logs: Dict[str, Tensor]) -> bool:
monitor_val = logs.get(self.monitor)
error_msg = (

View File

@ -39,7 +39,7 @@ from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache
log = logging.getLogger(__name__)
@ -231,13 +231,14 @@ class ModelCheckpoint(Checkpoint):
self._save_on_train_epoch_end = save_on_train_epoch_end
self._last_global_step_saved = 0 # no need to save when no steps were taken
self._last_time_checked: Optional[float] = None
self.current_score = None
self.best_k_models = {}
self.current_score: Optional[Tensor] = None
self.best_k_models: Dict[str, Tensor] = {}
self.kth_best_model_path = ""
self.best_model_score = None
self.best_model_score: Optional[Tensor] = None
self.best_model_path = ""
self.last_model_path = ""
self.kth_value: Tensor
self.__init_monitor_mode(mode)
self.__init_ckpt_dir(dirpath, filename)
self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
@ -256,6 +257,7 @@ class ModelCheckpoint(Checkpoint):
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
self.__resolve_ckpt_dir(trainer)
assert self.dirpath is not None
if trainer.is_global_zero and stage == "fit":
self.__warn_if_dir_not_empty(self.dirpath)
@ -362,7 +364,7 @@ class ModelCheckpoint(Checkpoint):
self._save_topk_checkpoint(trainer, monitor_candidates)
self._save_last_checkpoint(trainer, monitor_candidates)
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if self.save_top_k == 0:
return
@ -395,7 +397,7 @@ class ModelCheckpoint(Checkpoint):
from pytorch_lightning.trainer.states import TrainerFn
return (
trainer.fast_dev_run # disable checkpointing with fast_dev_run
bool(trainer.fast_dev_run) # disable checkpointing with fast_dev_run
or trainer.state.fn != TrainerFn.FITTING # don't save anything during non-fit
or trainer.sanity_checking # don't save anything during sanity check
or self._last_global_step_saved == trainer.global_step # already saved at the last step
@ -493,7 +495,7 @@ class ModelCheckpoint(Checkpoint):
should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])
# If using multiple devices, make sure all processes are unanimous on the decision.
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(should_update_best_and_save)
should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))
return should_update_best_and_save
@ -501,7 +503,7 @@ class ModelCheckpoint(Checkpoint):
def _format_checkpoint_name(
cls,
filename: Optional[str],
metrics: Dict[str, _METRIC],
metrics: Dict[str, Tensor],
prefix: str = "",
auto_insert_metric_name: bool = True,
) -> str:
@ -522,7 +524,7 @@ class ModelCheckpoint(Checkpoint):
filename = filename.replace(group, f"{{0[{name}]")
if name not in metrics:
metrics[name] = 0
metrics[name] = torch.tensor(0)
filename = filename.format(metrics)
if prefix:
@ -531,7 +533,7 @@ class ModelCheckpoint(Checkpoint):
return filename
def format_checkpoint_name(
self, metrics: Dict[str, _METRIC], filename: Optional[str] = None, ver: Optional[int] = None
self, metrics: Dict[str, Tensor], filename: Optional[str] = None, ver: Optional[int] = None
) -> str:
"""Generate a filename according to the defined template.
@ -591,6 +593,7 @@ class ModelCheckpoint(Checkpoint):
ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
elif trainer.loggers:
if len(trainer.loggers) == 1:
assert trainer.logger is not None
save_dir = trainer.logger.save_dir or trainer.default_root_dir
else:
save_dir = trainer.default_root_dir
@ -613,7 +616,7 @@ class ModelCheckpoint(Checkpoint):
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
def _get_metric_interpolated_filepath_name(
self, monitor_candidates: Dict[str, _METRIC], trainer: "pl.Trainer", del_filepath: Optional[str] = None
self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
) -> str:
filepath = self.format_checkpoint_name(monitor_candidates)
@ -624,7 +627,7 @@ class ModelCheckpoint(Checkpoint):
return filepath
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, _METRIC]:
def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
monitor_candidates = deepcopy(trainer.callback_metrics)
# cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
# or does not exist we overwrite it as it's likely an error
@ -634,7 +637,7 @@ class ModelCheckpoint(Checkpoint):
monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
return monitor_candidates
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
if not self.save_last:
return
@ -651,16 +654,18 @@ class ModelCheckpoint(Checkpoint):
if previous and previous != filepath:
trainer.strategy.remove_checkpoint(previous)
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
assert self.monitor
current = monitor_candidates.get(self.monitor)
if self.check_monitor_top_k(trainer, current):
assert current is not None
self._update_best_and_save(current, trainer, monitor_candidates)
elif self.verbose:
epoch = monitor_candidates["epoch"]
step = monitor_candidates["step"]
rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
# set the best model path before saving because it will be part of the state.
previous, self.best_model_path = self.best_model_path, filepath
@ -669,7 +674,7 @@ class ModelCheckpoint(Checkpoint):
trainer.strategy.remove_checkpoint(previous)
def _update_best_and_save(
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]
self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
) -> None:
k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k
@ -691,11 +696,11 @@ class ModelCheckpoint(Checkpoint):
if len(self.best_k_models) == k:
# monitor dict has reached k elements
_op = max if self.mode == "min" else min
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.kth_value = self.best_k_models[self.kth_best_model_path]
_op = min if self.mode == "min" else max
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get) # type: ignore[arg-type]
self.best_model_score = self.best_k_models[self.best_model_path]
if self.verbose:
@ -715,6 +720,7 @@ class ModelCheckpoint(Checkpoint):
file."""
best_k = {k: v.item() for k, v in self.best_k_models.items()}
if filepath is None:
assert self.dirpath
filepath = os.path.join(self.dirpath, "best_k_models.yaml")
with self._fs.open(filepath, "w") as fp:
yaml.dump(best_k, fp)

View File

@ -532,7 +532,7 @@ class LightningModule(
return torch.tensor(value, device=self.device)
@staticmethod
def __check_numel_1(value: torch.Tensor, name: str) -> None:
def __check_numel_1(value: Tensor, name: str) -> None:
if not torch.numel(value) == 1:
raise ValueError(
f"`self.log({name}, {value})` was called, but the tensor must have a single element."

View File

@ -285,7 +285,7 @@ class Strategy(ABC):
"""
def reduce_boolean_decision(self, decision: bool) -> bool:
"""Reduce the early stopping decision across all processes."""
"""Reduce a boolean decision across all processes."""
return decision
def pre_backward(self, closure_loss: Tensor) -> None:

View File

@ -169,19 +169,13 @@ class TPUSpawnStrategy(DDPSpawnStrategy):
obj = torch.load(buffer)
return obj
def reduce_boolean_decision(self, decision: bool) -> bool:
decision = torch.tensor(int(decision), device=self.root_device)
decision = self.reduce(decision, reduce_op="sum")
decision = bool(decision == self.world_size)
return decision
def reduce(self, output, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = None):
if not isinstance(output, Tensor):
output = torch.tensor(output, device=self.root_device)
_invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
_invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if _invalid_reduce_op or _invalid_reduce_op_str:
invalid_reduce_op = isinstance(reduce_op, ReduceOp) and reduce_op != ReduceOp.SUM
invalid_reduce_op_str = isinstance(reduce_op, str) and reduce_op.lower() not in ("sum", "mean", "avg")
if invalid_reduce_op or invalid_reduce_op_str:
raise MisconfigurationException(
"Currently, TPUSpawn Strategy only support `sum`, `mean`, `avg` reduce operation."
)

View File

@ -529,7 +529,7 @@ class _ResultCollection(dict):
result_metric.meta.sync.should = should
cache = result_metric._computed
if cache is not None:
if not isinstance(cache, torch.Tensor):
if not isinstance(cache, Tensor):
raise ValueError(
f"The `.compute()` return of the metric logged as {result_metric.meta.name!r} must be a tensor."
f" Found {cache}"

View File

@ -2705,7 +2705,9 @@ class Trainer(
self._loggers = loggers if loggers else []
@property
def callback_metrics(self) -> dict:
def callback_metrics(self) -> Dict[str, Tensor]:
# TODO: the true typing return can include dictionaries as defined in
# `pytorch_lightning.trainer.connectors.logger_connector.result._OUT_DICT`
return self._logger_connector.callback_metrics
@property

View File

@ -99,7 +99,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
return gathered_result
def _simple_gather_all_tensors(result: torch.Tensor, group: Any, world_size: int) -> List[torch.Tensor]:
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
return gathered_result