From cd16aa9854aac6afcf8da929c8312950582b23a8 Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Sat, 12 Sep 2020 08:42:27 -0400
Subject: [PATCH] ref: checkpoint connector methods 4/n (#3474)

* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
* ref: checkpoint connector methods 4/n
---
 .pyrightconfig.json                               | 1 +
 .../trainer/connectors/checkpoint_connector.py    | 6 +++---
 pytorch_lightning/trainer/training_tricks.py      | 4 ----
 pytorch_lightning/tuner/batch_size_scaling.py     | 2 +-
 pytorch_lightning/tuner/lr_finder.py              | 2 +-
 tests/base/develop_pipelines.py                   | 4 ++--
 tests/models/data/horovod/train_default_model.py  | 4 ++--
 tests/models/test_cpu.py                          | 2 +-
 tests/models/test_restore.py                      | 2 +-
 tests/trainer/test_trainer.py                     | 2 +-
 10 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/.pyrightconfig.json b/.pyrightconfig.json
index 893987a2bc..1dfa714532 100644
--- a/.pyrightconfig.json
+++ b/.pyrightconfig.json
@@ -32,6 +32,7 @@
         "pytorch_lightning/trainer/distrib_data_parallel.py",
         "pytorch_lightning/trainer/lr_scheduler_connector.py",
         "pytorch_lightning/trainer/training_loop_temp.py",
+        "pytorch_lightning/trainer/connectors/checkpoint_connector.py",
         "pytorch_lightning/tuner"
     ],

diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 9f931f1e63..f883f8b080 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -116,7 +116,7 @@ class CheckpointConnector:
             amp.load_state_dict(checkpoint['amp_scaling_state'])

         # load training state (affects trainer only)
-        self.trainer.restore_training_state(checkpoint)
+        self.restore_training_state(checkpoint)

     def restore_training_state(self, checkpoint):
         """
@@ -187,7 +187,7 @@ class CheckpointConnector:

         # if hpc weights exist restore model
         if len(hpc_weight_paths) > 0:
-            self.trainer.hpc_load(folderpath, self.trainer.on_gpu)
+            self.hpc_load(folderpath, self.trainer.on_gpu)
             did_restore = True
         return did_restore

@@ -321,7 +321,7 @@ class CheckpointConnector:
         if self.trainer.amp_backend == AMPType.NATIVE and not self.trainer.use_tpu and self.trainer.scaler is not None:
             checkpoint['native_amp_scaling_state'] = self.trainer.scaler.state_dict()
         elif self.trainer.amp_backend == AMPType.APEX:
-            checkpoint['amp_scaling_state'] = self.trainer.state_dict()
+            checkpoint['amp_scaling_state'] = amp.state_dict()

         # add the module_arguments and state_dict from the model
         model = self.trainer.get_model()
diff --git a/pytorch_lightning/trainer/training_tricks.py b/pytorch_lightning/trainer/training_tricks.py
index 00091c169d..3dbab4a78e 100644
--- a/pytorch_lightning/trainer/training_tricks.py
+++ b/pytorch_lightning/trainer/training_tricks.py
@@ -42,10 +42,6 @@ class TrainerTrainingTricksMixin(ABC):
     def get_model(self) -> LightningModule:
         """Warning: this is just empty shell for code implemented in other class."""

-    @abstractmethod
-    def fit(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     def print_nan_gradients(self) -> None:
         model = self.get_model()
         for param in model.parameters():
diff --git a/pytorch_lightning/tuner/batch_size_scaling.py b/pytorch_lightning/tuner/batch_size_scaling.py
index 06ccf77bb4..8b2e05c66b 100644
--- a/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/pytorch_lightning/tuner/batch_size_scaling.py
@@ -108,7 +108,7 @@ def scale_batch_size(trainer,
     log.info(f'Finished batch size finder, will continue with full run using batch size {new_size}')

     # Restore initial state of model
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
     os.remove(save_path)

     # Finish by resetting variables so trainer is ready to fit model
diff --git a/pytorch_lightning/tuner/lr_finder.py b/pytorch_lightning/tuner/lr_finder.py
index 3f843464d9..37a7d2f6b7 100644
--- a/pytorch_lightning/tuner/lr_finder.py
+++ b/pytorch_lightning/tuner/lr_finder.py
@@ -179,7 +179,7 @@ def lr_find(
     lr_finder._total_batch_idx = trainer.total_batch_idx  # for debug purpose

     # Reset model state
-    trainer.restore(str(save_path), on_gpu=trainer.on_gpu)
+    trainer.checkpoint_connector.restore(str(save_path), on_gpu=trainer.on_gpu)
     os.remove(save_path)

     # Finish by resetting variables so trainer is ready to fit model
diff --git a/tests/base/develop_pipelines.py b/tests/base/develop_pipelines.py
index ba698e82c8..3d9c341ec6 100644
--- a/tests/base/develop_pipelines.py
+++ b/tests/base/develop_pipelines.py
@@ -71,8 +71,8 @@ def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, wi
     trainer.init_optimizers(pretrained_model)

     # test HPC loading / saving
-    trainer.hpc_save(save_dir, logger)
-    trainer.hpc_load(save_dir, on_gpu=on_gpu)
+    trainer.checkpoint_connector.hpc_save(save_dir, logger)
+    trainer.checkpoint_connector.hpc_load(save_dir, on_gpu=on_gpu)


 def run_prediction(dataloader, trained_model, dp=False, min_acc=0.50):
diff --git a/tests/models/data/horovod/train_default_model.py b/tests/models/data/horovod/train_default_model.py
index 7138021e8e..ce12008a6f 100644
--- a/tests/models/data/horovod/train_default_model.py
+++ b/tests/models/data/horovod/train_default_model.py
@@ -78,8 +78,8 @@ def run_test_from_config(trainer_options):
         run_prediction(dataloader, pretrained_model)

     # test HPC loading / saving
-    trainer.hpc_save(ckpt_path, trainer.logger)
-    trainer.hpc_load(ckpt_path, on_gpu=args.on_gpu)
+    trainer.checkpoint_connector.hpc_save(ckpt_path, trainer.logger)
+    trainer.checkpoint_connector.hpc_load(ckpt_path, on_gpu=args.on_gpu)

     if args.on_gpu:
         trainer = Trainer(gpus=1, distributed_backend='horovod', max_epochs=1)
diff --git a/tests/models/test_cpu.py b/tests/models/test_cpu.py
index efd59160f2..81905c963c 100644
--- a/tests/models/test_cpu.py
+++ b/tests/models/test_cpu.py
@@ -56,7 +56,7 @@ def test_cpu_slurm_save_load(tmpdir):

     # test HPC saving
     # simulate snapshot on slurm
-    saved_filepath = trainer.hpc_save(trainer.weights_save_path, logger)
+    saved_filepath = trainer.checkpoint_connector.hpc_save(trainer.weights_save_path, logger)
     assert os.path.exists(saved_filepath)

     # new logger file to get meta
diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py
index 391070a3ba..3ec47c778c 100644
--- a/tests/models/test_restore.py
+++ b/tests/models/test_restore.py
@@ -227,7 +227,7 @@ def test_dp_resume(tmpdir):
     # HPC LOAD/SAVE
     # ---------------------------
     # save
-    trainer.hpc_save(tmpdir, logger)
+    trainer.checkpoint_connector.hpc_save(tmpdir, logger)

     # init new trainer
     new_logger = tutils.get_default_logger(tmpdir, version=logger.version)
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index a8f0634179..4fbd8dbd53 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -457,7 +457,7 @@ def test_model_checkpoint_only_weights(tmpdir):

     # assert restoring train state fails
     with pytest.raises(KeyError, match='checkpoint contains only the model'):
-        trainer.restore_training_state(checkpoint)
+        trainer.checkpoint_connector.restore_training_state(checkpoint)


 def test_model_freeze_unfreeze():
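
The sketch below is not part of the patch; it is a minimal, runnable illustration (with simplified, stand-in class and method bodies, not the actual PyTorch Lightning implementation) of the delegation pattern the call-site changes above rely on: checkpoint methods such as restore, restore_training_state, hpc_save, and hpc_load live on a connector object owned by the trainer, so tuner and test code now calls trainer.checkpoint_connector.<method> instead of trainer.<method>, and the connector's own methods call each other via self.

# delegation_sketch.py -- illustrative only, names and bodies are assumptions
class CheckpointConnector:
    def __init__(self, trainer):
        # keep a back-reference so connector methods can read trainer state
        self.trainer = trainer

    def restore(self, checkpoint_path: str, on_gpu: bool) -> None:
        # stand-in for loading a checkpoint file into a dict
        checkpoint = {"path": checkpoint_path, "on_gpu": on_gpu}
        # sibling methods are called on ``self``, no longer on ``self.trainer``
        self.restore_training_state(checkpoint)

    def restore_training_state(self, checkpoint: dict) -> None:
        print(f"restoring training state from {checkpoint['path']}")

    def hpc_save(self, folderpath: str, logger) -> str:
        print(f"saving HPC checkpoint under {folderpath} (logger={logger})")
        return f"{folderpath}/hpc_ckpt_1.ckpt"

    def hpc_load(self, folderpath: str, on_gpu: bool) -> None:
        print(f"loading HPC checkpoint from {folderpath}, on_gpu={on_gpu}")


class Trainer:
    def __init__(self):
        # the trainer owns the connector; checkpoint work is delegated to it
        self.checkpoint_connector = CheckpointConnector(self)
        self.on_gpu = False


if __name__ == "__main__":
    trainer = Trainer()
    # call sites (tuner, tests) now go through the connector explicitly:
    trainer.checkpoint_connector.restore("model.ckpt", on_gpu=trainer.on_gpu)
    ckpt = trainer.checkpoint_connector.hpc_save("/tmp/ckpts", logger=None)
    trainer.checkpoint_connector.hpc_load("/tmp/ckpts", on_gpu=trainer.on_gpu)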