From 6ddb03922a461f5672de4b844379764e85780549 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 31 Mar 2020 14:57:48 +0200 Subject: [PATCH] Profiler summary (#1259) * refactor and add types * add Prorfiler summary * fix imports * Revert "refactor and add types" This reverts commit b4c552fa * changelog * revert rename * fix test * mute verbose --- CHANGELOG.md | 3 +- pytorch_lightning/core/lightning.py | 2 +- pytorch_lightning/loggers/comet.py | 2 +- pytorch_lightning/profiler/__init__.py | 9 +- .../profiler/{profiler.py => profilers.py} | 104 +++++++++++++----- pytorch_lightning/trainer/data_loading.py | 2 +- .../trainer/distrib_data_parallel.py | 2 +- pytorch_lightning/trainer/distrib_parts.py | 2 +- pytorch_lightning/trainer/evaluation_loop.py | 2 +- pytorch_lightning/trainer/trainer.py | 11 +- pytorch_lightning/trainer/training_loop.py | 2 +- .../utilities/{debugging.py => exceptions.py} | 0 tests/loggers/test_comet.py | 2 +- tests/models/test_amp.py | 2 +- tests/models/test_gpu.py | 2 +- tests/models/test_restore.py | 2 +- tests/test_deprecated.py | 2 + tests/test_profiler.py | 17 ++- tests/trainer/test_dataloaders.py | 2 +- tests/trainer/test_trainer.py | 2 +- 20 files changed, 113 insertions(+), 59 deletions(-) rename pytorch_lightning/profiler/{profiler.py => profilers.py} (65%) rename pytorch_lightning/utilities/{debugging.py => exceptions.py} (100%) diff --git a/CHANGELOG.md b/CHANGELOG.md index 590095fa7f..02110f6e68 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097)) - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211)) - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283)) +- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259)) - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280)) ### Changed @@ -72,7 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950)) - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903)) - Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941)) -- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) +- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029)) - Checkpoint and early stopping now work without val. 
step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041)) - Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019)) - Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), ) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 92dc78782b..8cec41da2b 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -20,7 +20,7 @@ from pytorch_lightning.core.hooks import ModelHooks from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: import torch_xla.core.xla_model as xm diff --git a/pytorch_lightning/loggers/comet.py b/pytorch_lightning/loggers/comet.py index 0109f9dbdd..ee9d65a73c 100644 --- a/pytorch_lightning/loggers/comet.py +++ b/pytorch_lightning/loggers/comet.py @@ -28,7 +28,7 @@ from torch import is_tensor from pytorch_lightning import _logger as log from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException class CometLogger(LightningLoggerBase): diff --git a/pytorch_lightning/profiler/__init__.py b/pytorch_lightning/profiler/__init__.py index 20eece9ad6..683baccafa 100644 --- a/pytorch_lightning/profiler/__init__.py +++ b/pytorch_lightning/profiler/__init__.py @@ -3,7 +3,7 @@ Profiling your training run can help you understand if there are any bottlenecks Built-in checks ----------------- +--------------- PyTorch Lightning supports profiling standard actions in the training loop out of the box, including: @@ -20,7 +20,7 @@ PyTorch Lightning supports profiling standard actions in the training loop out o - on_training_end Enable simple profiling -------------------------- +----------------------- If you only wish to profile the standard actions, you can set `profiler=True` when constructing your `Trainer` object. @@ -113,10 +113,11 @@ to track and the profiler will record performance for code executed within this """ -from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler +from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler __all__ = [ - 'Profiler', + 'BaseProfiler', + 'SimpleProfiler', 'AdvancedProfiler', 'PassThroughProfiler', ] diff --git a/pytorch_lightning/profiler/profiler.py b/pytorch_lightning/profiler/profilers.py similarity index 65% rename from pytorch_lightning/profiler/profiler.py rename to pytorch_lightning/profiler/profilers.py index e2a6e5b200..6f6aa959ac 100644 --- a/pytorch_lightning/profiler/profiler.py +++ b/pytorch_lightning/profiler/profilers.py @@ -1,5 +1,6 @@ import cProfile import io +import os import pstats import time from abc import ABC, abstractmethod @@ -16,6 +17,18 @@ class BaseProfiler(ABC): If you wish to write a custom profiler, you should inhereit from this class. 
""" + def __init__(self, output_streams: list = None): + """ + Params: + stream_out: callable + """ + if output_streams: + if not isinstance(output_streams, (list, tuple)): + output_streams = [output_streams] + else: + output_streams = [] + self.write_streams = output_streams + @abstractmethod def start(self, action_name: str) -> None: """Defines how to start recording an action.""" @@ -57,7 +70,12 @@ class BaseProfiler(ABC): def describe(self) -> None: """Logs a profile report after the conclusion of the training run.""" - pass + for write in self.write_streams: + write(self.summary()) + + @abstractmethod + def summary(self) -> str: + """Create profiler summary in text format.""" class PassThroughProfiler(BaseProfiler): @@ -67,7 +85,7 @@ class PassThroughProfiler(BaseProfiler): """ def __init__(self): - pass + super().__init__(output_streams=None) def start(self, action_name: str) -> None: pass @@ -75,17 +93,31 @@ class PassThroughProfiler(BaseProfiler): def stop(self, action_name: str) -> None: pass + def summary(self) -> str: + return "" -class Profiler(BaseProfiler): + +class SimpleProfiler(BaseProfiler): """ This profiler simply records the duration of actions (in seconds) and reports the mean duration of each action and the total time spent over the entire training run. """ - def __init__(self): + def __init__(self, output_filename: str = None): + """ + Params: + output_filename (str): optionally save profile results to file instead of printing + to std out when training is finished. + """ self.current_actions = {} self.recorded_durations = defaultdict(list) + self.output_fname = output_filename + self.output_file = open(self.output_fname, 'w') if self.output_fname else None + + streaming_out = [self.output_file.write] if self.output_file else [log.info] + super().__init__(output_streams=streaming_out) + def start(self, action_name: str) -> None: if action_name in self.current_actions: raise ValueError( @@ -103,20 +135,31 @@ class Profiler(BaseProfiler): duration = end_time - start_time self.recorded_durations[action_name].append(duration) - def describe(self) -> None: + def summary(self) -> str: output_string = "\n\nProfiler Report\n" def log_row(action, mean, total): - return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}" + return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}" output_string += log_row("Action", "Mean duration (s)", "Total time (s)") - output_string += f"\n{'-' * 65}" + output_string += f"{os.linesep}{'-' * 65}" for action, durations in self.recorded_durations.items(): output_string += log_row( action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}", ) - output_string += "\n" - log.info(output_string) + output_string += os.linesep + return output_string + + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() class AdvancedProfiler(BaseProfiler): @@ -136,9 +179,14 @@ class AdvancedProfiler(BaseProfiler): or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines) """ self.profiled_actions = {} - self.output_filename = output_filename self.line_count_restriction = line_count_restriction + self.output_fname = output_filename + self.output_file = open(self.output_fname, 'w') if self.output_fname else None + + streaming_out = [self.output_file.write] if self.output_file else [log.info] + 
super().__init__(output_streams=streaming_out) + def start(self, action_name: str) -> None: if action_name not in self.profiled_actions: self.profiled_actions[action_name] = cProfile.Profile() @@ -152,22 +200,28 @@ class AdvancedProfiler(BaseProfiler): ) pr.disable() - def describe(self) -> None: - self.recorded_stats = {} + def summary(self) -> str: + recorded_stats = {} for action_name, pr in self.profiled_actions.items(): s = io.StringIO() ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative') ps.print_stats(self.line_count_restriction) - self.recorded_stats[action_name] = s.getvalue() - if self.output_filename is not None: - # save to file - with open(self.output_filename, "w") as f: - for action, stats in self.recorded_stats.items(): - f.write(f"Profile stats for: {action}") - f.write(stats) - else: - # log to standard out - output_string = "\nProfiler Report\n" - for action, stats in self.recorded_stats.items(): - output_string += f"\nProfile stats for: {action}\n{stats}" - log.info(output_string) + recorded_stats[action_name] = s.getvalue() + + # log to standard out + output_string = f"{os.linesep}Profiler Report{os.linesep}" + for action, stats in recorded_stats.items(): + output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}" + + return output_string + + def describe(self): + """Logs a profile report after the conclusion of the training run.""" + super().describe() + if self.output_file: + self.output_file.flush() + + def __del__(self): + """Close profiler's stream.""" + if self.output_file: + self.output_file.close() diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 039c7ee588..83b59d21a7 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -6,7 +6,7 @@ from torch.utils.data import SequentialSampler, DataLoader from torch.utils.data.distributed import DistributedSampler from pytorch_lightning.core import LightningModule -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_data_parallel.py b/pytorch_lightning/trainer/distrib_data_parallel.py index 182b5404b3..95b9a61974 100644 --- a/pytorch_lightning/trainer/distrib_data_parallel.py +++ b/pytorch_lightning/trainer/distrib_data_parallel.py @@ -122,7 +122,7 @@ from typing import Union import torch from pytorch_lightning import _logger as log from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/distrib_parts.py b/pytorch_lightning/trainer/distrib_parts.py index 4718407a5f..fc6007b75d 100644 --- a/pytorch_lightning/trainer/distrib_parts.py +++ b/pytorch_lightning/trainer/distrib_parts.py @@ -344,7 +344,7 @@ from pytorch_lightning.overrides.data_parallel import ( LightningDistributedDataParallel, LightningDataParallel, ) -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: from apex import amp diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index d720f81e6e..4fef97f3c8 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ 
b/pytorch_lightning/trainer/evaluation_loop.py @@ -135,7 +135,7 @@ from tqdm.auto import tqdm from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException try: import torch_xla.distributed.parallel_loader as xla_pl diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 7fbae97f3d..95df77a778 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -18,8 +18,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.profiler import Profiler, PassThroughProfiler -from pytorch_lightning.profiler.profiler import BaseProfiler +from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin @@ -33,7 +32,7 @@ from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin from pytorch_lightning.trainer.training_io import TrainerIOMixin from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningMean try: @@ -364,7 +363,7 @@ class Trainer( # configure profiler if profiler is True: - profiler = Profiler() + profiler = SimpleProfiler() self.profiler = profiler or PassThroughProfiler() # configure early stop callback @@ -490,10 +489,10 @@ class Trainer( ('print_nan_grads', (,), False), ('process_position', (,), 0), ('profiler', - (, + (, ), None), - ... + ... 
""" trainer_default_params = inspect.signature(cls).parameters name_type_default = [] diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index f01e2294c4..02abae9b47 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -145,7 +145,7 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.loggers import LightningLoggerBase -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import TensorRunningMean try: diff --git a/pytorch_lightning/utilities/debugging.py b/pytorch_lightning/utilities/exceptions.py similarity index 100% rename from pytorch_lightning/utilities/debugging.py rename to pytorch_lightning/utilities/exceptions.py diff --git a/tests/loggers/test_comet.py b/tests/loggers/test_comet.py index 1aaf4cb7fd..771ca3b6e7 100644 --- a/tests/loggers/test_comet.py +++ b/tests/loggers/test_comet.py @@ -8,7 +8,7 @@ import torch import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.loggers import CometLogger -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import LightningTestModel diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py index eea3bf9865..a51ea938bd 100644 --- a/tests/models/test_amp.py +++ b/tests/models/test_amp.py @@ -5,7 +5,7 @@ import torch import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( LightningTestModel, ) diff --git a/tests/models/test_gpu.py b/tests/models/test_gpu.py index 0ee77351b1..5de29f647d 100644 --- a/tests/models/test_gpu.py +++ b/tests/models/test_gpu.py @@ -11,7 +11,7 @@ from pytorch_lightning.trainer.distrib_parts import ( parse_gpu_ids, determine_root_gpu_device, ) -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import LightningTestModel PRETEND_N_OF_GPUS = 16 diff --git a/tests/models/test_restore.py b/tests/models/test_restore.py index 07434274c2..d1f5e6ecac 100644 --- a/tests/models/test_restore.py +++ b/tests/models/test_restore.py @@ -8,7 +8,7 @@ import torch import tests.base.utils as tutils from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( LightningTestModel, LightningTestModelWithoutHyperparametersArg, diff --git a/tests/test_deprecated.py b/tests/test_deprecated.py index a3b087c718..378c0f915b 100644 --- a/tests/test_deprecated.py +++ b/tests/test_deprecated.py @@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports(): from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402 from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402 + from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler # noqa: F402 + class 
ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase): diff --git a/tests/test_profiler.py b/tests/test_profiler.py index e60476bc59..ae5dc3eb36 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -1,10 +1,10 @@ -import tempfile +import os import time from pathlib import Path import numpy as np import pytest -from pytorch_lightning.profiler import AdvancedProfiler, Profiler +from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001 @@ -25,13 +25,13 @@ def _sleep_generator(durations): @pytest.fixture def simple_profiler(): - profiler = Profiler() + profiler = SimpleProfiler() return profiler @pytest.fixture -def advanced_profiler(): - profiler = AdvancedProfiler() +def advanced_profiler(tmpdir): + profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt")) return profiler @@ -168,12 +168,9 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler): # record at least one event with advanced_profiler.profile("test"): pass - # log to stdout + # log to stdout and print to file advanced_profiler.describe() - # print to file - advanced_profiler.output_filename = Path(tmpdir, "profiler.txt") - advanced_profiler.describe() - data = Path(advanced_profiler.output_filename).read_text() + data = Path(advanced_profiler.output_fname).read_text() assert len(data) > 0 diff --git a/tests/trainer/test_dataloaders.py b/tests/trainer/test_dataloaders.py index 6d2332cdf5..fd6f05cc92 100644 --- a/tests/trainer/test_dataloaders.py +++ b/tests/trainer/test_dataloaders.py @@ -2,7 +2,7 @@ import pytest import tests.base.utils as tutils from pytorch_lightning import Trainer -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( TestModelBase, LightningTestModel, diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 295ea3bdce..307365223d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -14,7 +14,7 @@ from pytorch_lightning.callbacks import ( ) from pytorch_lightning.core.lightning import load_hparams_from_tags_csv from pytorch_lightning.trainer.logging import TrainerLoggingMixin -from pytorch_lightning.utilities.debugging import MisconfigurationException +from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base import ( TestModelBase, DictHparamsModel,
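
A minimal usage sketch of the `summary()` / `describe()` API this patch adds to the profilers (not part of the diff above; the action names, sleep durations, and the `profiler_report.txt` filename are illustrative assumptions, not values from this PR):

    import time

    from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

    # SimpleProfiler records wall-clock durations per action; summary() returns
    # the report as a string, and describe() writes that report to the profiler's
    # output streams (log.info when no output_filename is given).
    simple = SimpleProfiler()
    with simple.profile("data_loading"):   # illustrative action name
        time.sleep(0.1)
    with simple.profile("forward"):
        time.sleep(0.05)
    print(simple.summary())
    simple.describe()

    # AdvancedProfiler wraps cProfile; with output_filename set, describe()
    # writes and flushes the per-action stats to that file instead of logging.
    advanced = AdvancedProfiler(output_filename="profiler_report.txt")
    with advanced.profile("training_step"):
        sum(i * i for i in range(10000))
    advanced.describe()

As the trainer.py hunk above shows, `Trainer(profiler=True)` now constructs a `SimpleProfiler` (formerly `Profiler`), and a configured profiler instance such as the `AdvancedProfiler` above can be passed directly via the `profiler` argument.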