ref: checkpoint connector methods 4/n (#3474)

* ref: checkpoint connector methods 4/n
William Falcon 2020-09-12 08:42:27 -04:00 committed by GitHub
parent 4724cdf5e0
commit cd16aa9854
10 changed files with 13 additions and 16 deletions


@@ -32,6 +32,7 @@
     "pytorch_lightning/trainer/distrib_data_parallel.py",
     "pytorch_lightning/trainer/lr_scheduler_connector.py",
     "pytorch_lightning/trainer/training_loop_temp.py",
+    "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
     "pytorch_lightning/tuner"
 ],


@@ -116,7 +116,7 @@ class CheckpointConnector:
             amp.load_state_dict(checkpoint['amp_scaling_state'])

         # load training state (affects trainer only)
-        self.trainer.restore_training_state(checkpoint)
+        self.restore_training_state(checkpoint)

     def restore_training_state(self, checkpoint):
         """
@@ -187,7 +187,7 @@ class CheckpointConnector:
         # if hpc weights exist restore model
         if len(hpc_weight_paths) > 0:
-            self.trainer.hpc_load(folderpath, self.trainer.on_gpu)
+            self.hpc_load(folderpath, self.trainer.on_gpu)
             did_restore = True

         return did_restore
@@ -321,7 +321,7 @@ class CheckpointConnector:
         if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
         elif self.trainer.amp_backend == AMPType.APEX:
-            checkpoint['amp_scaling_state'] = self.trainer.state_dict()
+            checkpoint['amp_scaling_state'] = amp.state_dict()

         # add the module_arguments and state_dict from the model
         model = self.trainer.get_model()
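
Taken together, the checkpoint_connector.py hunks above make the connector call its own restore_training_state / hpc_load instead of reaching back through self.trainer, and serialize the apex scaling state via amp.state_dict(). A minimal sketch of this delegation pattern, using hypothetical stand-in classes rather than the actual Lightning implementation:

```python
# Minimal sketch of the delegation pattern, assuming simplified stand-in
# classes (not the actual Lightning Trainer/CheckpointConnector).
import os
import pickle


class CheckpointConnector:
    def __init__(self, trainer):
        # back-reference to the trainer whose state is being (re)stored
        self.trainer = trainer

    def restore(self, path, on_gpu=False):
        with open(path, "rb") as fp:
            checkpoint = pickle.load(fp)
        # load training state (affects trainer only) -- note the call via
        # `self.`, mirroring the change in the hunk above
        self.restore_training_state(checkpoint)

    def restore_training_state(self, checkpoint):
        self.trainer.global_step = checkpoint["global_step"]

    def hpc_save(self, folderpath, logger=None):
        path = os.path.join(folderpath, "hpc_ckpt.pkl")
        with open(path, "wb") as fp:
            pickle.dump({"global_step": self.trainer.global_step}, fp)
        return path

    def hpc_load(self, folderpath, on_gpu=False):
        # delegate to `self.restore` rather than `self.trainer.restore`
        self.restore(os.path.join(folderpath, "hpc_ckpt.pkl"), on_gpu=on_gpu)


class Trainer:
    def __init__(self):
        self.global_step = 0
        # checkpoint logic is owned by a dedicated connector object
        self.checkpoint_connector = CheckpointConnector(self)
```

Keeping the back-reference on the connector lets the trainer remain the single public entry point (trainer.checkpoint_connector.*) while the checkpoint logic lives in one place.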


@@ -42,10 +42,6 @@ class TrainerTrainingTricksMixin(ABC):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def fit(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():


@@ -108,7 +108,7 @@ def scale_batch_size(trainer,
    log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

    # Restore initial state of model
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model


@@ -179,7 +179,7 @@ def lr_find(
    lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

    # Reset model state
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
    os.remove(save_path)

    # Finish by resetting variables so trainer is ready to fit model
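
Call sites in the tuner (and in the tests below) switch from trainer.restore / trainer.hpc_save to the connector attribute. A hedged usage example of the same round trip, reusing the hypothetical stand-in classes from the sketch above:

```python
# Hypothetical round trip through the connector attribute, reusing the
# stand-in Trainer/CheckpointConnector sketch earlier (not Lightning's API).
import os
import tempfile

trainer = Trainer()
trainer.global_step = 42

with tempfile.TemporaryDirectory() as tmpdir:
    # old call shape was trainer.hpc_save(...) / trainer.restore(...)
    save_path = trainer.checkpoint_connector.hpc_save(tmpdir)
    trainer.global_step = 0
    trainer.checkpoint_connector.restore(save_path, on_gpu=False)
    os.remove(save_path)

assert trainer.global_step == 42
```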


@@ -71,8 +71,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
    trainer.init_optimizers(pretrained_model)

    # test HPC loading / saving
-    trainer.hpc_save(save_dir, logger)
-    trainer.hpc_load(save_dir, on_gpu=on_gpu)
+    trainer.checkpoint_connector.hpc_save(save_dir, logger)
+    trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)


def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):


@@ -78,8 +78,8 @@ def run_test_from_config(trainer_options):
    run_prediction(dataloader, pretrained_model)

    # test HPC loading / saving
-    trainer.hpc_save(ckpt_path, trainer.logger)
-    trainer.hpc_load(ckpt_path, on_gpu=args.on_gpu)
+    trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
+    trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)

    if args.on_gpu:
        trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)


@@ -56,7 +56,7 @@ def test_cpu_slurm_save_load(tmpdir):
    # test HPC saving
    # simulate snapshot on slurm
-    saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
+    saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
    assert os.path.exists(saved_filepath)

    # new logger file to get meta


@@ -227,7 +227,7 @@ def test_dp_resume(tmpdir):
    # HPC LOAD/SAVE
    # ---------------------------
    # save
-    trainer.hpc_save(tmpdir, logger)
+    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

    # init new trainer
    new_logger = tutils.get_default_logger(tmpdir, version=logger.version)


@@ -457,7 +457,7 @@ def test_model_checkpoint_only_weights(tmpdir):
    # assert restoring train state fails
    with pytest.raises(KeyError, match='checkpoint contains only the model'):
-        trainer.restore_training_state(checkpoint)
+        trainer.checkpoint_connector.restore_training_state(checkpoint)


def test_model_freeze_unfreeze():