Profiler summary (#1259)

* refactor and add types

* add Profiler summary

* fix imports

* Revert "refactor and add types"

This reverts commit b4c552fa

* changelog

* revert rename

* fix test

* mute verbose
Jirka Borovec 2020-03-31 14:57:48 +02:00 committed by GitHub
parent 4dcb9d3e30
commit 6ddb03922a
20 changed files with 113 additions and 59 deletions

View File

@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
- Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
- Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
- Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))
### Changed
@ -72,7 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
- Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
- Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
- Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
- Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019))
- Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), )

View File

@ -20,7 +20,7 @@ from pytorch_lightning.core.hooks import ModelHooks
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
import torch_xla.core.xla_model as xm

View File

@ -28,7 +28,7 @@ from torch import is_tensor
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class CometLogger(LightningLoggerBase):

View File

@ -3,7 +3,7 @@ Profiling your training run can help you understand if there are any bottlenecks
Built-in checks
----------------
---------------
PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
@ -20,7 +20,7 @@ PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:
- on_training_end
Enable simple profiling
-------------------------
-----------------------
If you only wish to profile the standard actions, you can set `profiler=True` when constructing
your `Trainer` object.
@ -113,10 +113,11 @@ to track and the profiler will record performance for code executed within this
"""
from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler
__all__ = [
'Profiler',
'BaseProfiler',
'SimpleProfiler',
'AdvancedProfiler',
'PassThroughProfiler',
]
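
The docstring above keeps `profiler=True` as the simplest entry point, while this diff renames `Profiler` to `SimpleProfiler` and moves the implementations into `pytorch_lightning.profiler.profilers`. A minimal sketch of both ways to enable it (the `max_epochs` value is arbitrary and the model setup is omitted):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.profiler import SimpleProfiler

# profiler=True lets the Trainer build a SimpleProfiler internally ...
trainer = Trainer(profiler=True, max_epochs=1)

# ... which is equivalent to passing an instance explicitly, where extra
# options such as output_filename can also be set.
trainer = Trainer(profiler=SimpleProfiler(), max_epochs=1)
```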

View File

@ -1,5 +1,6 @@
import cProfile
import io
import os
import pstats
import time
from abc import ABC, abstractmethod
@ -16,6 +17,18 @@ class BaseProfiler(ABC):
If you wish to write a custom profiler, you should inherit from this class.
"""
def __init__(self, output_streams: list = None):
"""
Params:
output_streams: callable or list of callables
"""
if output_streams:
if not isinstance(output_streams, (list, tuple)):
output_streams = [output_streams]
else:
output_streams = []
self.write_streams = output_streams
@abstractmethod
def start(self, action_name: str) -> None:
"""Defines how to start recording an action."""
@ -57,7 +70,12 @@ class BaseProfiler(ABC):
def describe(self) -> None:
"""Logs a profile report after the conclusion of the training run."""
pass
for write in self.write_streams:
write(self.summary())
@abstractmethod
def summary(self) -> str:
"""Create profiler summary in text format."""
class PassThroughProfiler(BaseProfiler):
@ -67,7 +85,7 @@ class PassThroughProfiler(BaseProfiler):
"""
def __init__(self):
pass
super().__init__(output_streams=None)
def start(self, action_name: str) -> None:
pass
@ -75,17 +93,31 @@ class PassThroughProfiler(BaseProfiler):
def stop(self, action_name: str) -> None:
pass
def summary(self) -> str:
return ""
class Profiler(BaseProfiler):
class SimpleProfiler(BaseProfiler):
"""
This profiler simply records the duration of actions (in seconds) and reports
the mean duration of each action and the total time spent over the entire training run.
"""
def __init__(self):
def __init__(self, output_filename: str = None):
"""
Params:
output_filename (str): optionally save profile results to file instead of printing
to std out when training is finished.
"""
self.current_actions = {}
self.recorded_durations = defaultdict(list)
self.output_fname = output_filename
self.output_file = open(self.output_fname, 'w') if self.output_fname else None
streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
def start(self, action_name: str) -> None:
if action_name in self.current_actions:
raise ValueError(
@ -103,20 +135,31 @@ class Profiler(BaseProfiler):
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)
def describe(self) -> None:
def summary(self) -> str:
output_string = "\n\nProfiler Report\n"
def log_row(action, mean, total):
return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"\n{'-' * 65}"
output_string += f"{os.linesep}{'-' * 65}"
for action, durations in self.recorded_durations.items():
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
)
output_string += "\n"
log.info(output_string)
output_string += os.linesep
return output_string
def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()
def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()
class AdvancedProfiler(BaseProfiler):
@ -136,9 +179,14 @@ class AdvancedProfiler(BaseProfiler):
or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
"""
self.profiled_actions = {}
self.output_filename = output_filename
self.line_count_restriction = line_count_restriction
self.output_fname = output_filename
self.output_file = open(self.output_fname, 'w') if self.output_fname else None
streaming_out = [self.output_file.write] if self.output_file else [log.info]
super().__init__(output_streams=streaming_out)
def start(self, action_name: str) -> None:
if action_name not in self.profiled_actions:
self.profiled_actions[action_name] = cProfile.Profile()
@ -152,22 +200,28 @@ class AdvancedProfiler(BaseProfiler):
)
pr.disable()
def describe(self) -> None:
self.recorded_stats = {}
def summary(self) -> str:
recorded_stats = {}
for action_name, pr in self.profiled_actions.items():
s = io.StringIO()
ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
ps.print_stats(self.line_count_restriction)
self.recorded_stats[action_name] = s.getvalue()
if self.output_filename is not None:
# save to file
with open(self.output_filename, "w") as f:
for action, stats in self.recorded_stats.items():
f.write(f"Profile stats for: {action}")
f.write(stats)
else:
# log to standard out
output_string = "\nProfiler Report\n"
for action, stats in self.recorded_stats.items():
output_string += f"\nProfile stats for: {action}\n{stats}"
log.info(output_string)
recorded_stats[action_name] = s.getvalue()
# log to standard out
output_string = f"{os.linesep}Profiler Report{os.linesep}"
for action, stats in recorded_stats.items():
output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
return output_string
def describe(self):
"""Logs a profile report after the conclusion of the training run."""
super().describe()
if self.output_file:
self.output_file.flush()
def __del__(self):
"""Close profiler's stream."""
if self.output_file:
self.output_file.close()

View File

@ -6,7 +6,7 @@ from torch.utils.data import SequentialSampler, DataLoader
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning.core import LightningModule
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp

View File

@ -122,7 +122,7 @@ from typing import Union
import torch
from pytorch_lightning import _logger as log
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp

View File

@ -344,7 +344,7 @@ from pytorch_lightning.overrides.data_parallel import (
LightningDistributedDataParallel,
LightningDataParallel,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
from apex import amp

View File

@ -135,7 +135,7 @@ from tqdm.auto import tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
try:
import torch_xla.distributed.parallel_loader as xla_pl

View File

@ -18,8 +18,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import Profiler, PassThroughProfiler
from pytorch_lightning.profiler.profiler import BaseProfiler
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@ -33,7 +32,7 @@ from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.training_io import TrainerIOMixin
from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean
try:
@ -364,7 +363,7 @@ class Trainer(
# configure profiler
if profiler is True:
profiler = Profiler()
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()
# configure early stop callback
@ -490,10 +489,10 @@ class Trainer(
('print_nan_grads', (<class 'bool'>,), False),
('process_position', (<class 'int'>,), 0),
('profiler',
(<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
(<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
<class 'NoneType'>),
None),
...
...
"""
trainer_default_params = inspect.signature(cls).parameters
name_type_default = []
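
Since the `profiler` argument accepts any `BaseProfiler` (with `True` now mapping to `SimpleProfiler`), a custom profiler only needs the three methods left abstract in this diff. A sketch of a hypothetical subclass, assuming the `BaseProfiler` constructor and `output_streams` handling shown above:

```python
import time

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import BaseProfiler


class WallClockProfiler(BaseProfiler):
    """Hypothetical example: accumulate wall-clock time per action."""

    def __init__(self):
        self.start_times = {}
        self.totals = {}
        # No output streams: the report is retrieved via summary() directly.
        super().__init__(output_streams=None)

    def start(self, action_name: str) -> None:
        self.start_times[action_name] = time.monotonic()

    def stop(self, action_name: str) -> None:
        elapsed = time.monotonic() - self.start_times.pop(action_name)
        self.totals[action_name] = self.totals.get(action_name, 0.0) + elapsed

    def summary(self) -> str:
        return "\n".join(f"{name}: {total:.4f}s" for name, total in self.totals.items())


trainer = Trainer(profiler=WallClockProfiler())
```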

View File

@ -145,7 +145,7 @@ from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.supporters import TensorRunningMean
try:

View File

@ -8,7 +8,7 @@ import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CometLogger
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel

View File

@ -5,7 +5,7 @@ import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
LightningTestModel,
)

View File

@ -11,7 +11,7 @@ from pytorch_lightning.trainer.distrib_parts import (
parse_gpu_ids,
determine_root_gpu_device,
)
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import LightningTestModel
PRETEND_N_OF_GPUS = 16

View File

@ -8,7 +8,7 @@ import torch
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
LightningTestModel,
LightningTestModelWithoutHyperparametersArg,

View File

@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
from pytorch_lightning.logging.test_tube import TestTubeLogger # noqa: F402
from pytorch_lightning.logging.wandb import WandbLogger # noqa: F402
from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler # noqa: F402
class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):

View File

@ -1,10 +1,10 @@
import tempfile
import os
import time
from pathlib import Path
import numpy as np
import pytest
from pytorch_lightning.profiler import AdvancedProfiler, Profiler
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler
PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
@ -25,13 +25,13 @@ def _sleep_generator(durations):
@pytest.fixture
def simple_profiler():
profiler = Profiler()
profiler = SimpleProfiler()
return profiler
@pytest.fixture
def advanced_profiler():
profiler = AdvancedProfiler()
def advanced_profiler(tmpdir):
profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
return profiler
@ -168,12 +168,9 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler):
# record at least one event
with advanced_profiler.profile("test"):
pass
# log to stdout
# log to stdout and print to file
advanced_profiler.describe()
# print to file
advanced_profiler.output_filename = Path(tmpdir, "profiler.txt")
advanced_profiler.describe()
data = Path(advanced_profiler.output_filename).read_text()
data = Path(advanced_profiler.output_fname).read_text()
assert len(data) > 0
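
The same pattern should hold for `SimpleProfiler`, since file handling now lives in the profiler constructor rather than in `describe()`. A hedged sketch of an analogous test (the test name and action name are made up, not part of this PR):

```python
import os
from pathlib import Path

from pytorch_lightning.profiler import SimpleProfiler


def test_simple_profiler_writes_report(tmpdir):
    profiler = SimpleProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
    # record at least one event
    with profiler.profile("test"):
        pass
    # describe() writes summary() to the file stream and flushes it
    profiler.describe()
    data = Path(profiler.output_fname).read_text()
    assert len(data) > 0
```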

View File

@ -2,7 +2,7 @@ import pytest
import tests.base.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
TestModelBase,
LightningTestModel,

View File

@ -14,7 +14,7 @@ from pytorch_lightning.callbacks import (
)
from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.utilities.debugging import MisconfigurationException
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import (
TestModelBase,
DictHparamsModel,