diff --git a/CHANGELOG.md b/CHANGELOG.md
index d7b6ce0dbf..758c573a1a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -117,6 +117,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
 
+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+
 
 ## [1.4.0] - 2021-07-27
diff --git a/pytorch_lightning/loops/dataloader/evaluation_loop.py b/pytorch_lightning/loops/dataloader/evaluation_loop.py
index c8c4af2624..5128a584b0 100644
--- a/pytorch_lightning/loops/dataloader/evaluation_loop.py
+++ b/pytorch_lightning/loops/dataloader/evaluation_loop.py
@@ -20,7 +20,6 @@ from torch.utils.data.dataloader import DataLoader
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -206,10 +205,6 @@ class EvaluationLoop(DataLoaderLoop):
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
-        if self.trainer.state.fn != TrainerFn.FITTING:
-            # summarize profile results
-            self.trainer.profiler.describe()
-
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)
diff --git a/pytorch_lightning/loops/dataloader/prediction_loop.py b/pytorch_lightning/loops/dataloader/prediction_loop.py
index c5011fd993..8010948b7e 100644
--- a/pytorch_lightning/loops/dataloader/prediction_loop.py
+++ b/pytorch_lightning/loops/dataloader/prediction_loop.py
@@ -119,8 +119,6 @@ class PredictionLoop(DataLoaderLoop):
         Returns:
             the results for all dataloaders
         """
-        self.trainer.profiler.describe()
-
         results = self.predictions
 
         self.trainer.call_hook("on_predict_epoch_end", results)
diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py
index 94bf413c5e..ed26d50d6c 100644
--- a/pytorch_lightning/loops/fit_loop.py
+++ b/pytorch_lightning/loops/fit_loop.py
@@ -225,15 +225,6 @@ class FitLoop(Loop):
         # hook
         self.trainer.call_hook("on_train_end")
 
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()
diff --git a/pytorch_lightning/plugins/training_type/ddp_spawn.py b/pytorch_lightning/plugins/training_type/ddp_spawn.py
index b69fea03b5..c39d35e8d1 100644
--- a/pytorch_lightning/plugins/training_type/ddp_spawn.py
+++ b/pytorch_lightning/plugins/training_type/ddp_spawn.py
@@ -203,6 +203,9 @@ class DDPSpawnPlugin(ParallelPlugin):
         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()
diff --git a/pytorch_lightning/plugins/training_type/tpu_spawn.py b/pytorch_lightning/plugins/training_type/tpu_spawn.py
index d8b9457ffe..faec805773 100644
--- a/pytorch_lightning/plugins/training_type/tpu_spawn.py
+++ b/pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -172,6 +172,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if self.local_rank == 0:
             time.sleep(2)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 90f64b9df5..fb1f93fdd9 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -76,6 +76,7 @@ from pytorch_lightning.utilities import (
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -944,8 +945,10 @@ class Trainer(
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_end")
 
-        # teardown
-        self._call_teardown_hook()
+        # teardown if necessary (similar calls for spawn plugins are excluded as they have
+        # been included at the end of `new_process` functions)
+        if self._distrib_type not in DistributedType.interactive_compatible_types():
+            self._call_teardown_hook()
 
         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED
@@ -1211,7 +1214,7 @@
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
 
-        self.profiler.teardown(stage=fn)
+        self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)
@@ -1220,6 +1223,14 @@
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None
 
+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu kill loggers.
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
diff --git a/tests/trainer/logging_/test_distributed_logging.py b/tests/trainer/logging_/test_distributed_logging.py
index 03ebfa8f92..b67f924a6f 100644
--- a/tests/trainer/logging_/test_distributed_logging.py
+++ b/tests/trainer/logging_/test_distributed_logging.py
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any, Dict, Optional, Union
 from unittest import mock
 from unittest.mock import Mock
 
+import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers.base import LightningLoggerBase
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -101,3 +104,68 @@ def test_first_logger_call_in_subprocess(tmpdir):
         callbacks=[LoggerCallsObserver()],
     )
     trainer.fit(model)
+
+
+def test_logger_after_fit_predict_test_calls(tmpdir):
+    """
+    Make sure logger outputs are finalized after fit, prediction, and test calls.
+    """
+
+    class BufferLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.buffer = {}
+            self.logs = {}
+
+        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+            self.buffer.update(metrics)
+
+        def finalize(self, status: str) -> None:
+            self.logs.update(self.buffer)
+            self.buffer = {}
+
+        @property
+        def experiment(self) -> Any:
+            return None
+
+        @property
+        def version(self) -> Union[int, str]:
+            return 1
+
+        @property
+        def name(self) -> str:
+            return "BufferLogger"
+
+        def log_hyperparams(self, *args, **kwargs) -> None:
+            return None
+
+    class LoggerCallsObserver(Callback):
+        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"fit": 1})
+
+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"validate": 1})
+
+        def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"predict": 1})
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"test": 1})
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=BufferLogger(),
+        callbacks=[LoggerCallsObserver()],
+    )
+
+    assert not trainer.logger.logs
+    trainer.fit(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1}
+    trainer.test(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
+    trainer.predict(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}