ref: checkpoint connector methods 4/n (#3474)
* ref: checkpoint connector methods 4/n

parent 4724cdf5e0
commit cd16aa9854
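This step of the series keeps moving checkpoint logic off the Trainer and onto a dedicated CheckpointConnector, so callers reach it as trainer.checkpoint_connector instead of calling the trainer directly. For orientation, a minimal, purely illustrative sketch of that delegation pattern (simplified stand-ins, not the real PyTorch Lightning classes):

# Illustrative sketch only: simplified stand-ins for the classes touched in this diff.

class CheckpointConnector:
    """Owns checkpoint restore logic on behalf of a trainer."""

    def __init__(self, trainer):
        self.trainer = trainer

    def restore(self, checkpoint_path, on_gpu):
        # The real method loads the file and the model weights; the point here is
        # that training state is restored through *our own* helper rather than by
        # calling back into the trainer (the change made in this commit).
        checkpoint = {"epoch": 3, "global_step": 1200}  # placeholder payload
        self.restore_training_state(checkpoint)

    def restore_training_state(self, checkpoint):
        self.trainer.current_epoch = checkpoint["epoch"]
        self.trainer.global_step = checkpoint["global_step"]


class Trainer:
    def __init__(self):
        self.current_epoch = 0
        self.global_step = 0
        # checkpoint-related calls are reached via this attribute,
        # e.g. trainer.checkpoint_connector.restore(path, on_gpu=...)
        self.checkpoint_connector = CheckpointConnector(self)


trainer = Trainer()
trainer.checkpoint_connector.restore("some.ckpt", on_gpu=False)
assert trainer.current_epoch == 3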
@@ -32,6 +32,7 @@
     "pytorch_lightning/trainer/distrib_data_parallel.py",
     "pytorch_lightning/trainer/lr_scheduler_connector.py",
     "pytorch_lightning/trainer/training_loop_temp.py",
+    "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
     "pytorch_lightning/tuner"
 ],
@@ -116,7 +116,7 @@ class CheckpointConnector:
             amp.load_state_dict(checkpoint['amp_scaling_state'])
 
         # load training state (affects trainer only)
-        self.trainer.restore_training_state(checkpoint)
+        self.restore_training_state(checkpoint)
 
     def restore_training_state(self, checkpoint):
         """
@@ -187,7 +187,7 @@ class CheckpointConnector:
 
         # if hpc weights exist restore model
         if len(hpc_weight_paths) > 0:
-            self.trainer.hpc_load(folderpath, self.trainer.on_gpu)
+            self.hpc_load(folderpath, self.trainer.on_gpu)
             did_restore = True
         return did_restore
@@ -321,7 +321,7 @@ class CheckpointConnector:
         if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
         elif self.trainer.amp_backend == AMPType.APEX:
-            checkpoint['amp_scaling_state'] = self.trainer.state_dict()
+            checkpoint['amp_scaling_state'] = amp.state_dict()
 
         # add the module_arguments and state_dict from the model
         model = self.trainer.get_model()
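The last hunk above is effectively a bug fix: the Apex scaling state should come from amp.state_dict(), not from the trainer. A small hedged sketch of the matching save/load pair, assuming the optional apex package is installed (guarded import; the helper functions are illustrative and not part of this diff):

# Sketch of the save/restore pair for Apex AMP scaling state (apex is optional).
try:
    from apex import amp
    APEX_AVAILABLE = True
except ImportError:
    APEX_AVAILABLE = False


def dump_amp_state(checkpoint):
    # store the loss-scaler state alongside the model weights
    if APEX_AVAILABLE:
        checkpoint['amp_scaling_state'] = amp.state_dict()
    return checkpoint


def load_amp_state(checkpoint):
    # mirror of the save path: hand the stored state back to apex
    if APEX_AVAILABLE and 'amp_scaling_state' in checkpoint:
        amp.load_state_dict(checkpoint['amp_scaling_state'])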
@@ -42,10 +42,6 @@ class TrainerTrainingTricksMixin(ABC):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def fit(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():
@@ -108,7 +108,7 @@ def scale_batch_size(trainer,
     log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')
 
     # Restore initial state of model
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
     os.remove(save_path)
 
     # Finish by resetting variables so trainer is ready to fit model
@@ -179,7 +179,7 @@ def lr_find(
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose
 
     # Reset model state
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
     os.remove(save_path)
 
     # Finish by resetting variables so trainer is ready to fit model
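Both tuner hunks follow the same snapshot-then-restore pattern; only the restore call now routes through the connector. A hedged sketch of that pattern, where with_temporary_snapshot is a hypothetical helper, trainer.save_checkpoint and trainer.checkpoint_connector.restore are assumed to behave as in the diff, and the tuning step itself is elided:

import os


def with_temporary_snapshot(trainer, tune_fn, tmp_dir):
    save_path = os.path.join(tmp_dir, '__temp_tuner_ckpt.ckpt')

    # snapshot the current model/trainer state before the search mutates it
    trainer.save_checkpoint(str(save_path))

    try:
        result = tune_fn(trainer)  # e.g. a batch-size or learning-rate search
    finally:
        # restore the initial state and clean up the temporary file
        trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
        os.remove(save_path)

    return result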
@@ -71,8 +71,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
         trainer.init_optimizers(pretrained_model)
 
     # test HPC loading / saving
-    trainer.hpc_save(save_dir, logger)
-    trainer.hpc_load(save_dir, on_gpu=on_gpu)
+    trainer.checkpoint_connector.hpc_save(save_dir, logger)
+    trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)
 
 
 def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
@@ -78,8 +78,8 @@ def run_test_from_config(trainer_options):
     run_prediction(dataloader, pretrained_model)
 
     # test HPC loading / saving
-    trainer.hpc_save(ckpt_path, trainer.logger)
-    trainer.hpc_load(ckpt_path, on_gpu=args.on_gpu)
+    trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
+    trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)
 
     if args.on_gpu:
         trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
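The two test helpers above exercise an HPC save/load round trip, now routed through the connector. A hedged sketch of the same round trip, where roundtrip_hpc is a hypothetical helper and the hpc_save/hpc_load signatures are assumed to match the diff:

# Hypothetical helper mirroring the round trip the test helpers perform.
def roundtrip_hpc(trainer, ckpt_dir, logger, on_gpu):
    # write a SLURM/HPC-style snapshot of the weights plus training state
    saved_path = trainer.checkpoint_connector.hpc_save(ckpt_dir, logger)

    # read it back into the same trainer to verify nothing was lost
    trainer.checkpoint_connector.hpc_load(ckpt_dir, on_gpu=on_gpu)
    return saved_path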
@@ -56,7 +56,7 @@ def test_cpu_slurm_save_load(tmpdir):
 
     # test HPC saving
     # simulate snapshot on slurm
-    saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
+    saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
     assert os.path.exists(saved_filepath)
 
     # new logger file to get meta
@@ -227,7 +227,7 @@ def test_dp_resume(tmpdir):
     # HPC LOAD/SAVE
     # ---------------------------
     # save
-    trainer.hpc_save(tmpdir, logger)
+    trainer.checkpoint_connector.hpc_save(tmpdir, logger)
 
     # init new trainer
     new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
@@ -457,7 +457,7 @@ def test_model_checkpoint_only_weights(tmpdir):
 
     # assert restoring train state fails
     with pytest.raises(KeyError, match='checkpoint contains only the model'):
-        trainer.restore_training_state(checkpoint)
+        trainer.checkpoint_connector.restore_training_state(checkpoint)
 
 
 def test_model_freeze_unfreeze():
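For context, the test above pins down a contract rather than an implementation detail: a weights-only checkpoint carries no training state, so restoring it must fail loudly. A hedged stand-in sketch (the key names and the exact error message are illustrative; the test only matches on the substring 'checkpoint contains only the model'):

def restore_training_state(checkpoint):
    # simplified stand-in for CheckpointConnector.restore_training_state
    if 'optimizer_states' not in checkpoint or 'lr_schedulers' not in checkpoint:
        raise KeyError(
            'Trying to restore training state but checkpoint contains only the model.'
        )
    # ... restore optimizers, schedulers and epoch counters here ...


weights_only = {'state_dict': {}}  # no optimizer/scheduler entries
try:
    restore_training_state(weights_only)
except KeyError as err:
    print(f'expected failure: {err}')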