lightning/pytorch_lightning/trainer/connectors/checkpoint_connector.py

413 lines
17 KiB
Python
Raw Normal View History

2020-09-12 11:05:21 +00:00
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
from pathlib import Path
from typing import Optional, Union
2020-09-12 11:05:21 +00:00
import torch
import pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import (
_APEX_AVAILABLE,
_OMEGACONF_AVAILABLE,
AMPType,
DeviceType,
rank_zero_info,
rank_zero_warn,
)
2020-09-12 11:05:21 +00:00
from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.upgrade_checkpoint import KEYS_MAPPING as DEPRECATED_CHECKPOINT_KEYS
2020-09-12 11:05:21 +00:00
if _APEX_AVAILABLE:
2020-09-12 11:05:21 +00:00
from apex import amp
if _OMEGACONF_AVAILABLE:
2020-09-12 11:05:21 +00:00
from omegaconf import Container
class CheckpointConnector:
def __init__(self, trainer):
    """Bind this connector to ``trainer``, whose state it saves and restores."""
    # used to validate checkpointing logic; nothing has been trained yet
    self.has_trained = False
    self.trainer = trainer
def restore_weights(self) -> None:
    """
    Restore checkpointed state, trying these sources in order of priority:

    1. the highest-numbered ``hpc_ckpt_<N>.ckpt`` file in ``trainer.weights_save_path``
    2. the file named by ``trainer.resume_from_checkpoint``
    3. nothing (no restore happens)

    Afterwards every process waits at a barrier. When the trainer runs on GPU,
    the CUDA cache is emptied both before and after the restore.
    """
    use_gpu = self.trainer._device_type == DeviceType.GPU

    # clear cache before restore
    if use_gpu:
        torch.cuda.empty_cache()

    hpc_dir = str(self.trainer.weights_save_path)
    latest_hpc_suffix = self.max_ckpt_in_folder(hpc_dir, "hpc_ckpt_")

    if latest_hpc_suffix is not None:
        # 1. an HPC checkpoint exists: it takes precedence over `resume_from_checkpoint`
        hpc_path = f'{hpc_dir}/hpc_ckpt_{latest_hpc_suffix}.ckpt'
        self.hpc_load(hpc_path, use_gpu)
        rank_zero_info(f'restored hpc model from: {hpc_path}')
    elif self.trainer.resume_from_checkpoint is not None:
        # 2. fall back to the user-provided checkpoint file
        self.restore(self.trainer.resume_from_checkpoint, on_gpu=use_gpu)

    # wait for all processes to catch up before continuing
    self.trainer.training_type_plugin.barrier('TrainerIOMixin.restore_weights')

    # clear cache after restore
    if use_gpu:
        torch.cuda.empty_cache()
def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file.

    All restored states are listed in the return value description of `dump_checkpoint`.

    Args:
        checkpoint_path: location of the checkpoint file, resolved through ``get_filesystem``.
        on_gpu: if ``True``, the model is moved to ``trainer.root_gpu`` after its weights
            are loaded and before the training state is restored.

    Returns:
        ``False`` (after emitting a warning) when no file exists at ``checkpoint_path``;
        ``True`` once model and training states have been restored.
    """
    fs = get_filesystem(checkpoint_path)
    if not fs.exists(checkpoint_path):
        # a missing file is not an error: training simply starts from scratch
        rank_zero_warn("No checkpoint file exists at `resume_from_checkpoint`. Start from scratch")
        return False

    # the identity map_location keeps every storage where it was deserialized (CPU);
    # the model is moved to the GPU explicitly below
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    lightning_module = self.trainer.lightning_module

    # model (and datamodule) first, then device placement, then trainer/optimizer state
    self.restore_model_state(lightning_module, checkpoint)
    if on_gpu:
        lightning_module.cuda(self.trainer.root_gpu)
    self.restore_training_state(checkpoint)

    rank_zero_info(f"Restored states from the checkpoint file at {checkpoint_path}")
    return True
def restore_model_state(self, model: LightningModule, checkpoint) -> None:
    """
    Restore model states from a 'PyTorch-Lightning checkpoint' dictionary object.

    The datamodule's ``on_load_checkpoint`` hook runs first (only when the trainer has a
    datamodule), then the model's ``on_load_checkpoint`` hook, and finally
    ``checkpoint['state_dict']`` is loaded into ``model``.

    Args:
        model: the LightningModule receiving the weights.
        checkpoint: checkpoint dictionary; must contain a ``'state_dict'`` entry.
    """
    datamodule = self.trainer.datamodule
    if datamodule is not None:
        datamodule.on_load_checkpoint(checkpoint)

    # hook: give user access to checkpoint if needed
    model.on_load_checkpoint(checkpoint)

    model.load_state_dict(checkpoint['state_dict'])
def restore_training_state(self, checkpoint):
    """
    Restore trainer state from a 'PyTorch-Lightning checkpoint' dictionary object.

    Restored in this order: AMP loss-scaling state, callback states (through
    ``trainer.on_load_checkpoint``), ``global_step`` and ``current_epoch``, optimizer
    states (moved onto ``trainer.root_gpu`` when it is set) and LR-scheduler states.

    Args:
        checkpoint: checkpoint dictionary as produced by ``dump_checkpoint(weights_only=False)``.

    Raises:
        KeyError: if the checkpoint lacks ``optimizer_states`` or ``lr_schedulers``
            (i.e. it only contains the model).
        ValueError: if the checkpoint contains keys of the outdated schema.
        MisconfigurationException: if the restored epoch is greater than ``trainer.max_epochs``.
    """
    # validation
    if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
        raise KeyError(
            'Trying to restore training state but checkpoint contains only the model.'
            ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
        )

    if any(key in checkpoint for key in DEPRECATED_CHECKPOINT_KEYS):
        raise ValueError(
            "The checkpoint you're attempting to load follows an"
            " outdated schema. You can upgrade to the current schema by running"
            " `python -m pytorch_lightning.utilities.upgrade_checkpoint --file model.ckpt`"
            " where `model.ckpt` is your checkpoint file."
        )

    # restore amp scaling
    if self.trainer.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
        self.trainer.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
    elif self.trainer.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
        amp.load_state_dict(checkpoint['amp_scaling_state'])

    # restore callback states
    self.trainer.on_load_checkpoint(checkpoint)

    self.trainer.global_step = checkpoint['global_step']
    self.trainer.current_epoch = checkpoint['epoch']

    # crash if max_epochs is lower than the current epoch from the checkpoint
    if self.trainer.max_epochs is not None and self.trainer.current_epoch > self.trainer.max_epochs:
        m = f"""
        you restored a checkpoint with current_epoch={self.trainer.current_epoch}
        but the Trainer(max_epochs={self.trainer.max_epochs})
        """
        raise MisconfigurationException(m)

    # Division deals with global step stepping once per accumulated batch
    # Inequality deals with different global step for odd vs even num_training_batches
    num_training_batches = self.trainer.num_training_batches
    n_accum = 1 if self.trainer.accumulate_grad_batches is None else self.trainer.accumulate_grad_batches
    expected_steps = num_training_batches / n_accum
    # The mid-epoch test needs a finite, non-zero epoch length: if num_training_batches is
    # float('inf') (unknown epoch length), `global_step % inf == global_step`, which would
    # trigger the warning on every resume past step 1.
    epoch_length_known = num_training_batches != 0 and num_training_batches != float('inf')
    if epoch_length_known and self.trainer.global_step % expected_steps > 1:
        rank_zero_warn(
            "You're resuming from a checkpoint that ended mid-epoch."
            " Training will start from the beginning of the next epoch."
            " This can cause unreliable results if further training is done,"
            " consider using an end of epoch checkpoint."
        )

    # restore the optimizers; states are matched to optimizers by position, and
    # `zip` ignores any surplus entries on either side
    optimizer_states = checkpoint['optimizer_states']
    for optimizer, opt_state in zip(self.trainer.optimizers, optimizer_states):
        optimizer.load_state_dict(opt_state)

        # move optimizer to GPU 1 weight at a time
        # avoids OOM
        if self.trainer.root_gpu is not None:
            for state in optimizer.state.values():
                for k, v in state.items():
                    if isinstance(v, torch.Tensor):
                        state[k] = v.cuda(self.trainer.root_gpu)

    # restore the lr schedulers (matched by position, like the optimizers)
    lr_schedulers = checkpoint['lr_schedulers']
    for scheduler, lrs_state in zip(self.trainer.lr_schedulers, lr_schedulers):
        scheduler['scheduler'].load_state_dict(lrs_state)
# ----------------------------------
# PRIVATE OPS
# ----------------------------------
def hpc_save(self, folderpath: str, logger):
    """Save an HPC checkpoint named ``hpc_ckpt_<N>.ckpt`` into ``folderpath``.

    ``N`` is one more than the largest ``hpc_ckpt_`` suffix already in the folder (``1``
    if there is none). This is the same file prefix ``restore_weights`` searches for, so
    unrelated files that merely contain ``ckpt_`` do not affect the numbering.

    Args:
        folderpath: target directory (``str`` or path object); created if missing.
        logger: logger whose ``save()`` is called before checkpointing so its metrics are
            persisted. May be ``None``, in which case this step is skipped.

    Returns:
        The path of the written checkpoint file.
    """
    folderpath = str(folderpath)  # because the tests pass a path object
    fs = get_filesystem(folderpath)
    # make sure the checkpoint folder exists
    fs.makedirs(folderpath, exist_ok=True)

    # save logger to make sure we get all the metrics
    if logger is not None:
        logger.save()

    max_suffix = self.max_ckpt_in_folder(folderpath, 'hpc_ckpt_')
    ckpt_number = (max_suffix if max_suffix is not None else 0) + 1
    filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')

    # give model a chance to do something on hpc_save
    model = self.trainer.lightning_module
    checkpoint = self.dump_checkpoint()
    model.on_hpc_save(checkpoint)
    checkpoint = self.trainer.accelerator.on_save(checkpoint)

    # do the actual save
    # TODO: fix for anything with multiprocess DP, DDP, DDP2
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        # the warning below attributes this failure to a non-picklable attribute:
        # drop the hyper-parameters and retry once
        if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
            del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(
            'warning, `hyper_parameters` dropped from checkpoint.'
            f' An attribute is not picklable {err}'
        )
        atomic_save(checkpoint, filepath)

    return filepath
def dump_checkpoint(self, weights_only: bool = False) -> dict:
    """Create a checkpoint dictionary from the current trainer, model and datamodule states.

    Args:
        weights_only: if ``True``, skip callback, optimizer, LR-scheduler and AMP states.

    Return:
        structured dictionary: {
            'epoch': training epoch + 1 (left as-is once ``max_steps`` has been reached)
            'global_step': training global step + 1
            'pytorch-lightning_version': PyTorch Lightning's version
            'callbacks': "callback specific state"[]  # if not weights_only
            'optimizer_states': "PT optim's state_dict"[]  # if not weights_only
            'lr_schedulers': "PT sched's state_dict"[]  # if not weights_only
            'native_amp_scaling_state': PT amp's state_dict  # if not weights_only, native amp, not on TPU
            'amp_scaling_state': Apex's state_dict  # if not weights_only and use apex amp
            'state_dict': Model's state_dict (e.g. network weights)
            CHECKPOINT_HYPER_PARAMS_NAME:  # if model.hparams is non-empty and the model has `_hparams_name`
            CHECKPOINT_HYPER_PARAMS_KEY:  # if model.hparams is non-empty
            CHECKPOINT_HYPER_PARAMS_TYPE:  # if model.hparams is an OmegaConf `Container`
            something_cool_i_want_to_save: anything you define through model.on_save_checkpoint
            ...: anything the datamodule adds through datamodule.on_save_checkpoint
        }
    """
    # dump epoch/global_step/pytorch-lightning_version
    current_epoch = self.trainer.current_epoch
    global_step = self.trainer.global_step
    has_reached_max_steps = self.trainer.max_steps and self.trainer.max_steps <= global_step

    # the stored counters are one ahead of the live trainer values; the epoch is
    # not advanced once max_steps has been reached
    global_step += 1
    if not has_reached_max_steps:
        current_epoch += 1

    model = self.trainer.lightning_module

    checkpoint = {
        'epoch': current_epoch,
        'global_step': global_step,
        'pytorch-lightning_version': pytorch_lightning.__version__,
        'state_dict': model.state_dict(),
    }

    if not weights_only:
        # dump callbacks
        checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

        # rely on the accelerator to dump each optimizer's state
        checkpoint['optimizer_states'] = [
            self.trainer.accelerator.optimizer_state(optimizer) for optimizer in self.trainer.optimizers
        ]

        # dump lr schedulers
        checkpoint['lr_schedulers'] = [
            scheduler_config['scheduler'].state_dict() for scheduler_config in self.trainer.lr_schedulers
        ]

        # dump amp scaling
        if (
            self.trainer.amp_backend == AMPType.NATIVE and self.trainer._device_type != DeviceType.TPU
            and self.trainer.scaler is not None
        ):
            checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
        elif self.trainer.amp_backend == AMPType.APEX:
            checkpoint['amp_scaling_state'] = amp.state_dict()

    # dump hyper-parameters
    if model.hparams:
        if hasattr(model, '_hparams_name'):
            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
        # dump arguments; OmegaConf containers are stored as-is together with their type
        if _OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
        else:
            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

    # give the model a chance to dump a few things
    model.on_save_checkpoint(checkpoint)
    if self.trainer.datamodule is not None:
        self.trainer.datamodule.on_save_checkpoint(checkpoint)

    return checkpoint
def hpc_load(self, checkpoint_path: str, on_gpu: bool):
    """
    Load model/training states from a 'PyTorch-Lightning checkpoint' file for hpc.

    All restored states are listed in the return value description of `dump_checkpoint`.
    Once everything is restored, the model's ``on_hpc_load`` hook receives the checkpoint.

    Args:
        checkpoint_path: path of the HPC checkpoint file to read.
        on_gpu: not read by this method. GPU placement is decided by
            ``trainer.root_gpu`` instead. NOTE(review): ``restore`` uses its
            ``on_gpu`` argument for the same decision; confirm the difference is intended.
    """
    # the identity map_location keeps storages where they were deserialized (CPU)
    checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)

    lightning_module = self.trainer.lightning_module

    # restore model and datamodule state
    self.restore_model_state(lightning_module, checkpoint)

    root_gpu = self.trainer.root_gpu
    if root_gpu is not None:
        lightning_module.cuda(root_gpu)

    # restore training state
    self.restore_training_state(checkpoint)

    # call hpc specific hook
    lightning_module.on_hpc_load(checkpoint)
def max_ckpt_in_folder(self, dir_path: Union[str, Path], name_key: str = 'ckpt_') -> Optional[int]:
    """List files in `dir_path` whose name contains `name_key` and return the maximum suffix number.

    For each such file, the text after the last occurrence of ``name_key`` has every
    non-digit character removed and is parsed as an integer. For example,
    ``hpc_ckpt_12.ckpt`` with ``name_key='hpc_ckpt_'`` yields ``12``. Matching files with
    no digits after ``name_key`` (e.g. ``ckpt_last.ckpt``) are ignored instead of raising.

    Args:
        dir_path: path of directory which may contain files whose name include `name_key`
        name_key: substring identifying numbered checkpoint file names

    Returns:
        None if the directory does not exist or no matching file carries a number,
        else the maximum suffix number
    """
    # check directory existence
    fs = get_filesystem(dir_path)
    if not fs.exists(dir_path):
        return None

    suffixes = []
    for entry in fs.listdir(dir_path):
        name = os.path.basename(entry["name"])
        if name_key not in name:
            continue
        digits = re.sub(r'[^0-9]', '', name.split(name_key)[-1])
        # int('') would raise ValueError for names with no number after `name_key`
        if digits:
            suffixes.append(int(digits))

    return max(suffixes) if suffixes else None
def get_max_ckpt_path_from_folder(self, folder_path: Union[str, Path]) -> str:
    """Get the path of the highest-numbered HPC checkpoint in the folder.

    The number is the largest ``hpc_ckpt_`` suffix in ``folder_path``. This is the same
    prefix as the file name built here and the one ``restore_weights`` searches, so files
    that merely contain ``ckpt_`` cannot produce a path to a non-existent HPC checkpoint.
    When no HPC checkpoint is present, ``0`` is used (that file need not exist).

    Args:
        folder_path: directory holding ``hpc_ckpt_<N>.ckpt`` files.

    Returns:
        ``f'{folder_path}/hpc_ckpt_{N}.ckpt'``
    """
    max_suffix = self.max_ckpt_in_folder(folder_path, 'hpc_ckpt_')
    ckpt_number = max_suffix if max_suffix is not None else 0
    return f'{folder_path}/hpc_ckpt_{ckpt_number}.ckpt'
def save_checkpoint(self, filepath, weights_only: bool = False):
    """Save model/training states as a checkpoint file through state-dump and file-write.

    The state dictionary is built on every rank, but only the global-zero rank
    writes it to ``filepath``. If writing fails with an ``AttributeError`` and the
    checkpoint carries hyperparameters, they are dropped (with a warning) and the
    write is retried once.

    Args:
        filepath: write-target file's path
        weights_only: saving model weights only

    Raises:
        AttributeError: re-raised from ``atomic_save`` when the checkpoint has no
            hyperparameters to drop, or when the retry without them fails as well.
    """
    # dump states as a checkpoint dictionary object; done before the rank check
    # (presumably so every process takes part in the dump — confirm)
    checkpoint = self.dump_checkpoint(weights_only)
    if not self.trainer.is_global_zero:
        return

    # let the training-type plugin process the checkpoint before it is written
    if self.trainer.training_type_plugin:
        checkpoint = self.trainer.training_type_plugin.on_save(checkpoint)

    # write the checkpoint dictionary on the file
    try:
        atomic_save(checkpoint, filepath)
    except AttributeError as err:
        if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
            # nothing can be dropped: a retry would fail identically and the
            # "dropped" warning would be false, so surface the original error
            raise
        del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
        rank_zero_warn(
            'Warning, `hyper_parameters` dropped from checkpoint.'
            f' An attribute is not picklable {err}'
        )
        atomic_save(checkpoint, filepath)