From 56437e98a6e673f5b978eebdf2babb28c09e41ad Mon Sep 17 00:00:00 2001 From: chaton Date: Tue, 5 Jan 2021 11:01:59 +0100 Subject: [PATCH] [bug-fix] Trainer.test points to latest best_model_path (#5161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * resolve bug * update code * add set -e * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Adrian Wälchli * update test * Update tests/checkpointing/test_trainer_checkpoint.py Co-authored-by: Sean Naren * Update tests/checkpointing/test_trainer_checkpoint.py Co-authored-by: Carlos Mocholí * update on comments * resolve test * convert to set * update * add error triggering * update * update on comments * update * resolve import * update * update * Update pytorch_lightning/plugins/rpc_plugin.py Co-authored-by: Jirka Borovec * update Co-authored-by: Adrian Wälchli Co-authored-by: Sean Naren Co-authored-by: Carlos Mocholí Co-authored-by: Ubuntu Co-authored-by: Jirka Borovec (cherry picked from commit d5b367871fa3924090ec74bf903bd172bd3e2343) --- .drone.yml | 1 + CHANGELOG.md | 3 +- .../callbacks/model_checkpoint.py | 1 + pytorch_lightning/plugins/rpc_plugin.py | 9 +- .../connectors/checkpoint_connector.py | 3 +- pytorch_lightning/trainer/training_loop.py | 2 +- tests/checkpointing/test_model_checkpoint.py | 13 +-- .../checkpointing/test_trainer_checkpoint.py | 87 +++++++++++++++++++ tests/plugins/test_ddp_sequential_plugin.py | 3 +- tests/special_tests.sh | 1 + 10 files changed, 111 insertions(+), 12 deletions(-) create mode 100644 tests/checkpointing/test_trainer_checkpoint.py diff --git a/.drone.yml b/.drone.yml index b0b6c3df1b..472861852c 100644 --- a/.drone.yml +++ b/.drone.yml @@ -30,6 +30,7 @@ steps: MKL_THREADING_LAYER: GNU commands: + - set -e - python --version - pip --version - nvidia-smi diff --git a/CHANGELOG.md b/CHANGELOG.md index f7a11aa7fa..fe57ec5574 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -76,12 +76,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Skip restore from `resume_from_checkpoint` in while `testing` ([#5161](https://github.com/PyTorchLightning/pytorch-lightning/pull/5161)) + - Allowed `log_momentum` for adaptive optimizers in `LearningRateMonitor` ([#5333](https://github.com/PyTorchLightning/pytorch-lightning/pull/5333)) - Disabled checkpointing, earlystopping and logger with `fast_dev_run` ([#5277](https://github.com/PyTorchLightning/pytorch-lightning/pull/5277)) - ## [1.1.2] - 2020-12-23 ### Added diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 8472ff88f4..7fd7a571a4 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -199,6 +199,7 @@ class ModelCheckpoint(Callback): "best_model_score": self.best_model_score, "best_model_path": self.best_model_path, "current_score": self.current_score, + "dirpath": self.dirpath } def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]): diff --git a/pytorch_lightning/plugins/rpc_plugin.py b/pytorch_lightning/plugins/rpc_plugin.py index d9fb6df3d1..8a694557e1 100644 --- a/pytorch_lightning/plugins/rpc_plugin.py +++ b/pytorch_lightning/plugins/rpc_plugin.py @@ -18,10 +18,13 @@ import torch from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import _RPC_AVAILABLE +from pytorch_lightning.utilities import _RPC_AVAILABLE, _module_available +DEFAULT_RPC_TIMEOUT_SEC = 60. if _RPC_AVAILABLE: from torch.distributed import rpc + if _module_available("torch.distributed.rpc.constants") and hasattr(torch.distributed.rpc.constants, "DEFAULT_RPC_TIMEOUT_SEC"): + from torch.distributed.rpc.constants import DEFAULT_RPC_TIMEOUT_SEC class RPCPlugin(DDPPlugin): @@ -33,7 +36,8 @@ class RPCPlugin(DDPPlugin): that need to be addressed when using RPC communication when building custom RPC Plugins. """ - def __init__(self, **kwargs): + def __init__(self, rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC, **kwargs): + self.rpc_timeout_sec = rpc_timeout_sec self.rpc_initialized = False super().__init__(**kwargs) @@ -42,6 +46,7 @@ class RPCPlugin(DDPPlugin): world_size: int) -> None: os.environ['MASTER_PORT'] = os.getenv('RPC_MASTER_PORT', '15000') rpc.init_rpc(f"worker{global_rank}", rank=global_rank, world_size=world_size) + rpc._set_rpc_timeout(self.rpc_timeout_sec) self.rpc_initialized = True def rpc_save_model(self, diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index 6c7d3994f8..58c596181f 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -20,6 +20,7 @@ from typing import Optional, Union import torch import pytorch_lightning +from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import _APEX_AVAILABLE, AMPType, _OMEGACONF_AVAILABLE, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.cloud_io import atomic_save, get_filesystem @@ -62,7 +63,7 @@ class CheckpointConnector: rank_zero_info(f'restored hpc model from: {checkpoint_path}') # 2. Attempt to restore states from `resume_from_checkpoint` file - elif self.trainer.resume_from_checkpoint is not None: + elif self.trainer.resume_from_checkpoint is not None and not self.trainer.testing: self.restore(self.trainer.resume_from_checkpoint, on_gpu=self.trainer.on_gpu) # wait for all to catch up diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index 99eb4c52b1..64eb224a42 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -184,7 +184,7 @@ class TrainLoop: # if cluster resets state, the model will update with the saved weights self.trainer.model = model - # restore training and model before hpc is called + # restore training state and model weights before hpc is called self.trainer.checkpoint_connector.restore_weights(model) # on pretrain routine end diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 7fb62637e1..8926f98b43 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -11,20 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from argparse import Namespace import os +from pathlib import Path import pickle import platform import re -from argparse import Namespace -from pathlib import Path from unittest import mock from unittest.mock import Mock import cloudpickle +from omegaconf import Container, OmegaConf import pytest import torch import yaml -from omegaconf import Container, OmegaConf import pytorch_lightning as pl import tests.base.develop_utils as tutils @@ -34,6 +34,7 @@ from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import BoringModel +import tests.base.develop_utils as tutils class LogInTwoMethods(BoringModel): @@ -745,9 +746,9 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir): model = ExtendedBoringModel() trainer.test(model) assert not trainer.checkpoint_connector.has_trained - assert trainer.global_step == epochs * limit_train_batches - assert trainer.current_epoch == epochs - + # resume_from_checkpoint is resumed when calling `.fit` + assert trainer.global_step == 0 + assert trainer.current_epoch == 0 trainer.fit(model) assert not trainer.checkpoint_connector.has_trained assert trainer.global_step == epochs * limit_train_batches diff --git a/tests/checkpointing/test_trainer_checkpoint.py b/tests/checkpointing/test_trainer_checkpoint.py new file mode 100644 index 0000000000..9e93a8c297 --- /dev/null +++ b/tests/checkpointing/test_trainer_checkpoint.py @@ -0,0 +1,87 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from copy import deepcopy +import os + +import torch + +import pytorch_lightning as pl +from pytorch_lightning import seed_everything, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.utilities.cloud_io import load as pl_load +from tests.base import BoringModel + + +def test_finetuning_with_resume_from_checkpoint(tmpdir): + """ + This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test + """ + + seed_everything(3) + + checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) + + class ExtendedBoringModel(BoringModel): + + def configure_optimizers(self): + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001) + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) + return [optimizer], [lr_scheduler] + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log("val_loss", loss, on_epoch=True, prog_bar=True) + + model = ExtendedBoringModel() + model.validation_epoch_end = None + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=1, + limit_train_batches=12, + limit_val_batches=6, + limit_test_batches=12, + callbacks=[checkpoint_callback], + logger=False, + ) + trainer.fit(model) + assert os.listdir(tmpdir) == ['epoch=00.ckpt'] + + best_model_paths = [checkpoint_callback.best_model_path] + results = [] + + for idx in range(3, 6): + # load from checkpoint + trainer = pl.Trainer( + default_root_dir=tmpdir, + max_epochs=idx, + limit_train_batches=12, + limit_val_batches=12, + limit_test_batches=12, + resume_from_checkpoint=best_model_paths[-1], + progress_bar_refresh_rate=0, + ) + trainer.fit(model) + trainer.test() + results.append(deepcopy(trainer.callback_metrics)) + best_model_paths.append(trainer.checkpoint_callback.best_model_path) + + for idx in range(len(results) - 1): + assert results[idx]["val_loss"] > results[idx + 1]["val_loss"] + + for idx, best_model_path in enumerate(best_model_paths): + if idx == 0: + assert best_model_path.endswith(f"epoch=0{idx}.ckpt") + else: + assert f"epoch={idx + 1}" in best_model_path diff --git a/tests/plugins/test_ddp_sequential_plugin.py b/tests/plugins/test_ddp_sequential_plugin.py index 966b58c63e..460d195f67 100644 --- a/tests/plugins/test_ddp_sequential_plugin.py +++ b/tests/plugins/test_ddp_sequential_plugin.py @@ -47,7 +47,8 @@ def test_ddp_sequential_plugin_ddp_rpc_manual(tmpdir, args=None): limit_test_batches=2, gpus=2, distributed_backend="ddp", - plugins=[DDPSequentialPlugin(balance=[2, 1])], + plugins=[DDPSequentialPlugin(balance=[2, 1], rpc_timeout_sec=5 * 60)], + enable_pl_optimizer=True, ) trainer.fit(model) diff --git a/tests/special_tests.sh b/tests/special_tests.sh index 950e3776bb..8d67cce28b 100644 --- a/tests/special_tests.sh +++ b/tests/special_tests.sh @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # Running special tests +set -e export PL_RUNNING_SPECIAL_TESTS=1 DEFAULTS="-m coverage run --source pytorch_lightning -a -m pytest --verbose --capture=no" python ${DEFAULTS} tests/trainer/optimization/test_manual_optimization.py::test_step_with_optimizer_closure_with_different_frequencies_ddp