Move logger and profiler finalization to trainer's teardown (#8685)

Co-authored-by: ananthsub <ananth.subramaniam@gmail.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>

parent 963c267646
commit efec3d461c
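At a glance, the change replaces per-loop finalization (FitLoop, EvaluationLoop, and PredictionLoop each calling `logger.finalize` / `profiler.describe`) with a single call site in the trainer's teardown hook. A minimal sketch of that pattern follows; `MiniTrainer` is illustrative only and not the real Trainer class:

    # Sketch only: the real Trainer is far larger; `MiniTrainer` is illustrative.
    class MiniTrainer:
        def __init__(self, logger, profiler):
            self.logger = logger
            self.profiler = profiler

        def fit(self, model):
            ...  # run the fit/validation loops; no finalization here anymore
            self._call_teardown_hook()

        def predict(self, model):
            ...  # run the prediction loop; no finalization here anymore
            self._call_teardown_hook()

        def _call_teardown_hook(self):
            # one place to flush outputs, whichever entry point was used
            if self.logger is not None:
                self.logger.finalize("success")
            self.profiler.describe()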
@@ -117,6 +117,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))
 
+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+
 
 ## [1.4.0] - 2021-07-27
@@ -20,7 +20,6 @@ from torch.utils.data.dataloader import DataLoader
 
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT
@@ -206,10 +205,6 @@ class EvaluationLoop(DataLoaderLoop):
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)
 
-        if self.trainer.state.fn != TrainerFn.FITTING:
-            # summarize profile results
-            self.trainer.profiler.describe()
-
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)
@@ -119,8 +119,6 @@ class PredictionLoop(DataLoaderLoop):
 
         Returns:
             the results for all dataloaders
         """
-        self.trainer.profiler.describe()
-
         results = self.predictions
 
         self.trainer.call_hook("on_predict_epoch_end", results)
@@ -225,15 +225,6 @@ class FitLoop(Loop):
 
         # hook
         self.trainer.call_hook("on_train_end")
 
-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()
@@ -203,6 +203,9 @@ class DDPSpawnPlugin(ParallelPlugin):
         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()
@@ -172,6 +172,9 @@ class TPUSpawnPlugin(DDPSpawnPlugin):
         if self.local_rank == 0:
             time.sleep(2)
 
+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
@@ -76,6 +76,7 @@ from pytorch_lightning.utilities import (
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -944,8 +945,10 @@ class Trainer(
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_end")
 
-        # teardown
-        self._call_teardown_hook()
+        # teardown if necessary (similar calls for spawn plugins are excluded as they have
+        # been included at the end of `new_process` functions)
+        if self._distrib_type not in DistributedType.interactive_compatible_types():
+            self._call_teardown_hook()
 
         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED
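The new guard above works together with the `trainer._call_teardown_hook()` calls added to the spawn plugins: for spawn-based distributed types the hook runs inside each spawned worker (see the DDPSpawnPlugin and TPUSpawnPlugin hunks above), so the parent process skips it here. A rough sketch of that control flow; the helper names `is_spawn_type`, `run_spawned_workers`, and `run_in_current_process` are hypothetical stand-ins for the plugin machinery:

    # Illustrative only; these helpers do not exist under these names.
    def run_and_teardown(trainer, model):
        if is_spawn_type(trainer._distrib_type):
            # each spawned worker calls trainer._call_teardown_hook() itself,
            # right before it joins, so it is not repeated in the parent process
            run_spawned_workers(trainer, model)
        else:
            run_in_current_process(trainer, model)
            trainer._call_teardown_hook()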
@@ -1211,7 +1214,7 @@
 
         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
-
+        self.profiler.teardown(stage=fn)
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)
@@ -1220,6 +1223,14 @@
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None
 
+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu kill loggers.
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
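Reassembling the two trainer hunks above, the teardown hook after this commit looks roughly as follows; lines not shown in the diff are elided or marked as assumptions:

    # Rough reassembly from the hunks above; surrounding code is omitted.
    def _call_teardown_hook(self) -> None:
        fn = ...  # the current stage; determined outside the hunks shown above

        if self.datamodule is not None:
            self.datamodule.teardown(stage=fn)
        self.profiler.teardown(stage=fn)
        self.teardown(stage=fn)
        self.lightning_module.teardown(stage=fn)

        # (other module state resets elided)
        self.lightning_module._metric_attributes = None

        # logger and profiler finalization now live here, for every entry point
        if self.logger is not None:
            self.logger.finalize("success")
        self.profiler.describe()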
@@ -12,10 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
 from typing import Any, Dict, Optional, Union
 from unittest import mock
 from unittest.mock import Mock
 
 import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -101,3 +104,68 @@ def test_first_logger_call_in_subprocess(tmpdir):
         callbacks=[LoggerCallsObserver()],
     )
     trainer.fit(model)
+
+
+def test_logger_after_fit_predict_test_calls(tmpdir):
+    """
+    Make sure logger outputs are finalized after fit, prediction, and test calls.
+    """
+
+    class BufferLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.buffer = {}
+            self.logs = {}
+
+        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+            self.buffer.update(metrics)
+
+        def finalize(self, status: str) -> None:
+            self.logs.update(self.buffer)
+            self.buffer = {}
+
+        @property
+        def experiment(self) -> Any:
+            return None
+
+        @property
+        def version(self) -> Union[int, str]:
+            return 1
+
+        @property
+        def name(self) -> str:
+            return "BufferLogger"
+
+        def log_hyperparams(self, *args, **kwargs) -> None:
+            return None
+
+    class LoggerCallsObserver(Callback):
+        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"fit": 1})
+
+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"validate": 1})
+
+        def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"predict": 1})
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"test": 1})
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=BufferLogger(),
+        callbacks=[LoggerCallsObserver()],
+    )
+
+    assert not trainer.logger.logs
+    trainer.fit(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1}
+    trainer.test(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
+    trainer.predict(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}
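The new test exercises what a custom logger now observes: `finalize()` runs as part of trainer teardown after fit, test, and predict. A minimal, hypothetical usage example along the same lines; the `PrintFinalizeLogger` name is illustrative only, and `BoringModel` is the same test helper used in the hunk above:

    from typing import Any, Dict, Optional, Union

    from pytorch_lightning import Trainer
    from pytorch_lightning.loggers.base import LightningLoggerBase
    from tests.helpers import BoringModel


    class PrintFinalizeLogger(LightningLoggerBase):
        """Illustrative logger that only reports when it is finalized."""

        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
            pass

        def log_hyperparams(self, *args, **kwargs) -> None:
            pass

        def finalize(self, status: str) -> None:
            print(f"finalize called with status={status}")

        @property
        def experiment(self) -> Any:
            return None

        @property
        def version(self) -> Union[int, str]:
            return 0

        @property
        def name(self) -> str:
            return "print-finalize"


    trainer = Trainer(limit_train_batches=1, limit_val_batches=1, max_epochs=1, logger=PrintFinalizeLogger())
    trainer.fit(BoringModel())
    # Before this commit the message above was not guaranteed after prediction runs;
    # now finalize runs in teardown for predict as well.
    trainer.predict(BoringModel())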