Move logger and profiler finalization to trainer's teardown (#8685)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
This commit is contained in:
Binh Tang 2021-08-05 01:09:43 -07:00 committed by GitHub
parent 963c267646
commit efec3d461c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 90 additions and 19 deletions

View File

@ -117,6 +117,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
## [1.4.0] - 2021-07-27

View File

@ -20,7 +20,6 @@ from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@ -206,10 +205,6 @@ class EvaluationLoop(DataLoaderLoop):
else:
self.trainer.call_hook("on_validation_end", *args, **kwargs)
if self.trainer.state.fn != TrainerFn.FITTING:
# summarize profile results
self.trainer.profiler.describe()
# reset any `torchmetrics.Metric` and the logger connector state
self.trainer.logger_connector.reset(metrics=True)

View File

@ -119,8 +119,6 @@ class PredictionLoop(DataLoaderLoop):
Returns:
the results for all dataloaders
"""
self.trainer.profiler.describe()
results = self.predictions
self.trainer.call_hook("on_predict_epoch_end", results)

View File

@ -225,15 +225,6 @@ class FitLoop(Loop):
# hook
self.trainer.call_hook("on_train_end")
# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu
# kill loggers
if self.trainer.logger is not None:
self.trainer.logger.finalize("success")
# summarize profile results
self.trainer.profiler.describe()
# give accelerators a chance to finish
self.trainer.accelerator.on_train_end()

View File

@ -203,6 +203,9 @@ class DDPSpawnPlugin(ParallelPlugin):
# persist info in ddp_spawn
self.transfer_distrib_spawn_state_on_fit_end(results)
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
def post_dispatch(self):
# restore main state with best weights
best_path = self.mp_queue.get()

View File

@ -172,6 +172,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
if self.local_rank == 0:
time.sleep(2)
# ensure that spawned processes go through teardown before joining
trainer._call_teardown_hook()
@parameter_validation
def model_to_device(self) -> None:
self.model = self.wrapped_model.to(self.root_device)

View File

@ -76,6 +76,7 @@ from pytorch_lightning.utilities import (
)
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.enums import DistributedType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.model_helpers import is_overridden
@ -944,8 +945,10 @@ class Trainer(
if self.state.fn == TrainerFn.FITTING:
self.call_hook("on_fit_end")
# teardown
self._call_teardown_hook()
# teardown if necessary (similar calls for spawn plugins are excluded as they have
# been included at the end of `new_process` functions)
if self._distrib_type not in DistributedType.interactive_compatible_types():
self._call_teardown_hook()
if self.state.status != TrainerStatus.INTERRUPTED:
self.state.status = TrainerStatus.FINISHED
@ -1211,7 +1214,7 @@ class Trainer(
if self.datamodule is not None:
self.datamodule.teardown(stage=fn)
self.profiler.teardown(stage=fn)
self.teardown(stage=fn)
self.lightning_module.teardown(stage=fn)
@ -1220,6 +1223,14 @@ class Trainer(
# these could have become stale if metrics are defined in `setup`
self.lightning_module._metric_attributes = None
# todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
# It might be related to xla tensors blocked when moving the cpu kill loggers.
if self.logger is not None:
self.logger.finalize("success")
# summarize profile results
self.profiler.describe()
def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
if self.lightning_module:
prev_fx_name = self.lightning_module._current_fx_name

View File

@ -12,10 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Any, Dict, Optional, Union
from unittest import mock
from unittest.mock import Mock
import pytorch_lightning as pl
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers.base import LightningLoggerBase
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf
@ -101,3 +104,68 @@ def test_first_logger_call_in_subprocess(tmpdir):
callbacks=[LoggerCallsObserver()],
)
trainer.fit(model)
def test_logger_after_fit_predict_test_calls(tmpdir):
"""
Make sure logger outputs are finalized after fit, prediction, and test calls.
"""
class BufferLogger(LightningLoggerBase):
def __init__(self):
super().__init__()
self.buffer = {}
self.logs = {}
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
self.buffer.update(metrics)
def finalize(self, status: str) -> None:
self.logs.update(self.buffer)
self.buffer = {}
@property
def experiment(self) -> Any:
return None
@property
def version(self) -> Union[int, str]:
return 1
@property
def name(self) -> str:
return "BufferLogger"
def log_hyperparams(self, *args, **kwargs) -> None:
return None
class LoggerCallsObserver(Callback):
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"fit": 1})
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"validate": 1})
def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"predict": 1})
def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
trainer.logger.log_metrics({"test": 1})
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
logger=BufferLogger(),
callbacks=[LoggerCallsObserver()],
)
assert not trainer.logger.logs
trainer.fit(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1}
trainer.test(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
trainer.predict(model)
assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}