Profiler summary (#1259)
* refactor and add types
* add Profiler summary
* fix imports
* Revert "refactor and add types" (reverts commit b4c552fa)
* changelog
* revert rename
* fix test
* mute verbose
parent 4dcb9d3e30
commit 6ddb03922a
@@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added a check that stops the training when loss or weights contain `NaN` or `inf` values. ([#1097](https://github.com/PyTorchLightning/pytorch-lightning/pull/1097))
 - Updated references to self.forward() to instead use the `__call__` interface. ([#1211](https://github.com/PyTorchLightning/pytorch-lightning/pull/1211))
 - Added support for `IterableDataset` when `val_check_interval=1.0` (default), this will trigger validation at the end of each epoch. ([#1283](https://github.com/PyTorchLightning/pytorch-lightning/pull/1283))
+- Added `summary` method to Profilers. ([#1259](https://github.com/PyTorchLightning/pytorch-lightning/pull/1259))
 - Added informative errors if user defined dataloader has zero length ([#1280](https://github.com/PyTorchLightning/pytorch-lightning/pull/1280))

 ### Changed
@@ -72,7 +73,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
 - Added support for step-based learning rate scheduling ([#941](https://github.com/PyTorchLightning/pytorch-lightning/pull/941))
-- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
+- Added support for logging `hparams` as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
 - Checkpoint and early stopping now work without val. step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
 - Support graceful training cleanup after Keyboard Interrupt ([#856](https://github.com/PyTorchLightning/pytorch-lightning/pull/856), [#1019](https://github.com/PyTorchLightning/pytorch-lightning/pull/1019))
 - Added type hints for function arguments ([#912](https://github.com/PyTorchLightning/pytorch-lightning/pull/912), )
@@ -20,7 +20,7 @@ from pytorch_lightning.core.hooks import ModelHooks
 from pytorch_lightning.core.memory import ModelSummary
 from pytorch_lightning.core.saving import ModelIO, load_hparams_from_tags_csv
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     import torch_xla.core.xla_model as xm
@@ -28,7 +28,7 @@ from torch import is_tensor

 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_only
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException


 class CometLogger(LightningLoggerBase):
@@ -3,7 +3,7 @@ Profiling your training run can help you understand if there are any bottlenecks


 Built-in checks
-----------------
+---------------

 PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

@@ -20,7 +20,7 @@ PyTorch Lightning supports profiling standard actions in the training loop out o
 - on_training_end

 Enable simple profiling
--------------------------
+-----------------------

 If you only wish to profile the standard actions, you can set `profiler=True` when constructing
 your `Trainer` object.
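As the documentation change above says, the built-in profiler is switched on straight from the Trainer. A minimal sketch, assuming a `model` LightningModule has already been defined (the variable name is a placeholder):

    from pytorch_lightning import Trainer

    # profiler=True now resolves to the renamed SimpleProfiler (see the trainer diff below)
    trainer = Trainer(profiler=True)
    trainer.fit(model)  # the profiler report is logged when training finishes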
@@ -113,10 +113,11 @@ to track and the profiler will record performance for code executed within this

 """

-from pytorch_lightning.profiler.profiler import Profiler, AdvancedProfiler, PassThroughProfiler
+from pytorch_lightning.profiler.profilers import SimpleProfiler, AdvancedProfiler, PassThroughProfiler, BaseProfiler

 __all__ = [
-    'Profiler',
+    'BaseProfiler',
+    'SimpleProfiler',
     'AdvancedProfiler',
     'PassThroughProfiler',
 ]
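After the module rename from `profiler.py` to `profilers.py`, the public names are still re-exported from `pytorch_lightning.profiler`, so user imports keep working. A quick sanity-check sketch (not part of the diff):

    from pytorch_lightning.profiler import (
        AdvancedProfiler,
        BaseProfiler,
        PassThroughProfiler,
        SimpleProfiler,
    )

    # every concrete profiler derives from the newly exported BaseProfiler
    assert issubclass(SimpleProfiler, BaseProfiler)
    assert issubclass(AdvancedProfiler, BaseProfiler)
    assert issubclass(PassThroughProfiler, BaseProfiler)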
@@ -1,5 +1,6 @@
 import cProfile
 import io
+import os
 import pstats
 import time
 from abc import ABC, abstractmethod
@@ -16,6 +17,18 @@ class BaseProfiler(ABC):
     If you wish to write a custom profiler, you should inhereit from this class.
     """

+    def __init__(self, output_streams: list = None):
+        """
+        Params:
+            stream_out: callable
+        """
+        if output_streams:
+            if not isinstance(output_streams, (list, tuple)):
+                output_streams = [output_streams]
+        else:
+            output_streams = []
+        self.write_streams = output_streams
+
     @abstractmethod
     def start(self, action_name: str) -> None:
         """Defines how to start recording an action."""
@@ -57,7 +70,12 @@ class BaseProfiler(ABC):

     def describe(self) -> None:
         """Logs a profile report after the conclusion of the training run."""
-        pass
+        for write in self.write_streams:
+            write(self.summary())
+
+    @abstractmethod
+    def summary(self) -> str:
+        """Create profiler summary in text format."""


 class PassThroughProfiler(BaseProfiler):
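With `describe()` now delegating to the new abstract `summary()`, a custom profiler only has to format its report as a string; the base class decides where it gets written. A minimal sketch of such a subclass; the call-counting logic and the class name are hypothetical, not part of this PR:

    from pytorch_lightning.profiler import BaseProfiler


    class CallCountProfiler(BaseProfiler):
        """Hypothetical profiler that only counts how often each action is started."""

        def __init__(self):
            self.counts = {}
            # print the report at the end; a file stream would work the same way
            super().__init__(output_streams=[print])

        def start(self, action_name: str) -> None:
            self.counts[action_name] = self.counts.get(action_name, 0) + 1

        def stop(self, action_name: str) -> None:
            pass

        def summary(self) -> str:
            return "\n".join(f"{name}: {count} calls" for name, count in self.counts.items())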
@@ -67,7 +85,7 @@ class PassThroughProfiler(BaseProfiler):
     """

     def __init__(self):
-        pass
+        super().__init__(output_streams=None)

     def start(self, action_name: str) -> None:
         pass
@@ -75,17 +93,31 @@ class PassThroughProfiler(BaseProfiler):
     def stop(self, action_name: str) -> None:
         pass

+    def summary(self) -> str:
+        return ""
+

-class Profiler(BaseProfiler):
+class SimpleProfiler(BaseProfiler):
     """
     This profiler simply records the duration of actions (in seconds) and reports
     the mean duration of each action and the total time spent over the entire training run.
     """

-    def __init__(self):
+    def __init__(self, output_filename: str = None):
+        """
+        Params:
+            output_filename (str): optionally save profile results to file instead of printing
+                to std out when training is finished.
+        """
         self.current_actions = {}
         self.recorded_durations = defaultdict(list)

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)
+
     def start(self, action_name: str) -> None:
         if action_name in self.current_actions:
             raise ValueError(
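The renamed SimpleProfiler can now stream its report to a file instead of the logger. A sketch of both ways of wiring it into the Trainer; the file path and `model` are placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import SimpleProfiler

    # default: the report goes through log.info when training finishes
    trainer = Trainer(profiler=SimpleProfiler())

    # alternative: write the report to a file
    profiler = SimpleProfiler(output_filename="simple_profile.txt")  # placeholder path
    trainer = Trainer(profiler=profiler)
    trainer.fit(model)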
@@ -103,20 +135,31 @@ class Profiler(BaseProfiler):
         duration = end_time - start_time
         self.recorded_durations[action_name].append(duration)

-    def describe(self) -> None:
+    def summary(self) -> str:
         output_string = "\n\nProfiler Report\n"

         def log_row(action, mean, total):
-            return f"\n{action:<20s}\t| {mean:<15}\t| {total:<15}"
+            return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"

         output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
-        output_string += f"\n{'-' * 65}"
+        output_string += f"{os.linesep}{'-' * 65}"
         for action, durations in self.recorded_durations.items():
             output_string += log_row(
                 action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}",
             )
-        output_string += "\n"
-        log.info(output_string)
+        output_string += os.linesep
+        return output_string
+
+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()


 class AdvancedProfiler(BaseProfiler):
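Since the report text is now produced by `summary()` and only emitted by `describe()`, the profiler can also be exercised on its own, for example with the `profile()` context manager used in the tests further down. A small sketch; the action name and the sleep are arbitrary:

    import time

    from pytorch_lightning.profiler import SimpleProfiler

    profiler = SimpleProfiler()
    with profiler.profile("data_loading"):
        time.sleep(0.1)

    report = profiler.summary()  # plain string, nothing is written yet
    profiler.describe()          # sends the same report to the configured streams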
@@ -136,9 +179,14 @@ class AdvancedProfiler(BaseProfiler):
                 or a decimal fraction between 0.0 and 1.0 inclusive (to select a percentage of lines)
         """
         self.profiled_actions = {}
-        self.output_filename = output_filename
         self.line_count_restriction = line_count_restriction

+        self.output_fname = output_filename
+        self.output_file = open(self.output_fname, 'w') if self.output_fname else None
+
+        streaming_out = [self.output_file.write] if self.output_file else [log.info]
+        super().__init__(output_streams=streaming_out)
+
     def start(self, action_name: str) -> None:
         if action_name not in self.profiled_actions:
             self.profiled_actions[action_name] = cProfile.Profile()
@@ -152,22 +200,28 @@ class AdvancedProfiler(BaseProfiler):
             )
         pr.disable()

-    def describe(self) -> None:
-        self.recorded_stats = {}
+    def summary(self) -> str:
+        recorded_stats = {}
         for action_name, pr in self.profiled_actions.items():
             s = io.StringIO()
             ps = pstats.Stats(pr, stream=s).strip_dirs().sort_stats('cumulative')
             ps.print_stats(self.line_count_restriction)
-            self.recorded_stats[action_name] = s.getvalue()
-        if self.output_filename is not None:
-            # save to file
-            with open(self.output_filename, "w") as f:
-                for action, stats in self.recorded_stats.items():
-                    f.write(f"Profile stats for: {action}")
-                    f.write(stats)
-        else:
-            # log to standard out
-            output_string = "\nProfiler Report\n"
-            for action, stats in self.recorded_stats.items():
-                output_string += f"\nProfile stats for: {action}\n{stats}"
-            log.info(output_string)
+            recorded_stats[action_name] = s.getvalue()
+
+        # log to standard out
+        output_string = f"{os.linesep}Profiler Report{os.linesep}"
+        for action, stats in recorded_stats.items():
+            output_string += f"{os.linesep}Profile stats for: {action}{os.linesep}{stats}"
+
+        return output_string
+
+    def describe(self):
+        """Logs a profile report after the conclusion of the training run."""
+        super().describe()
+        if self.output_file:
+            self.output_file.flush()
+
+    def __del__(self):
+        """Close profiler's stream."""
+        if self.output_file:
+            self.output_file.close()
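AdvancedProfiler follows the same pattern: `summary()` assembles the cProfile/pstats output per action and, when `output_filename` is set, `describe()` writes and flushes it to that file (mirroring the updated test fixture further down). A sketch with a placeholder path:

    from pytorch_lightning import Trainer
    from pytorch_lightning.profiler import AdvancedProfiler

    profiler = AdvancedProfiler(
        output_filename="advanced_profile.txt",  # placeholder path
        line_count_restriction=10,  # report only the top 10 lines per action
    )
    trainer = Trainer(profiler=profiler)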
@@ -6,7 +6,7 @@ from torch.utils.data import SequentialSampler, DataLoader
 from torch.utils.data.distributed import DistributedSampler

 from pytorch_lightning.core import LightningModule
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp
@@ -122,7 +122,7 @@ from typing import Union
 import torch
 from pytorch_lightning import _logger as log
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp
@@ -344,7 +344,7 @@ from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     from apex import amp
@@ -135,7 +135,7 @@ from tqdm.auto import tqdm

 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel, LightningDataParallel
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException

 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -18,8 +18,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.profiler import Profiler, PassThroughProfiler
-from pytorch_lightning.profiler.profiler import BaseProfiler
+from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler, BaseProfiler
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
@@ -33,7 +32,7 @@ from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
 from pytorch_lightning.trainer.training_io import TrainerIOMixin
 from pytorch_lightning.trainer.training_loop import TrainerTrainLoopMixin
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningMean

 try:
@@ -364,7 +363,7 @@ class Trainer(

         # configure profiler
         if profiler is True:
-            profiler = Profiler()
+            profiler = SimpleProfiler()
         self.profiler = profiler or PassThroughProfiler()

         # configure early stop callback
@@ -490,10 +489,10 @@ class Trainer(
          ('print_nan_grads', (<class 'bool'>,), False),
          ('process_position', (<class 'int'>,), 0),
          ('profiler',
-          (<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
+          (<class 'pytorch_lightning.profiler.profilers.BaseProfiler'>,
           <class 'NoneType'>),
          None),
          ...
         ...
        """
        trainer_default_params = inspect.signature(cls).parameters
        name_type_default = []
@@ -145,7 +145,7 @@ from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.trainer.supporters import TensorRunningMean

 try:
@@ -8,7 +8,7 @@ import torch
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import CometLogger
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import LightningTestModel

@@ -5,7 +5,7 @@ import torch

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import (
     LightningTestModel,
 )
@@ -11,7 +11,7 @@ from pytorch_lightning.trainer.distrib_parts import (
     parse_gpu_ids,
     determine_root_gpu_device,
 )
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import LightningTestModel

 PRETEND_N_OF_GPUS = 16
@@ -8,7 +8,7 @@ import torch
 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import (
     LightningTestModel,
     LightningTestModelWithoutHyperparametersArg,
@@ -57,6 +57,8 @@ def test_tbd_remove_in_v0_9_0_module_imports():
     from pytorch_lightning.logging.test_tube import TestTubeLogger  # noqa: F402
     from pytorch_lightning.logging.wandb import WandbLogger  # noqa: F402

+    from pytorch_lightning.profiler import SimpleProfiler, AdvancedProfiler  # noqa: F402
+

 class ModelVer0_6(LightTrainDataloader, LightEmptyTestStep, TestModelBase):

@@ -1,10 +1,10 @@
 import tempfile
 import os
 import time
 from pathlib import Path

 import numpy as np
 import pytest
-from pytorch_lightning.profiler import AdvancedProfiler, Profiler
+from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler

 PROFILER_OVERHEAD_MAX_TOLERANCE = 0.0001
@@ -25,13 +25,13 @@ def _sleep_generator(durations):

 @pytest.fixture
 def simple_profiler():
-    profiler = Profiler()
+    profiler = SimpleProfiler()
     return profiler


 @pytest.fixture
-def advanced_profiler():
-    profiler = AdvancedProfiler()
+def advanced_profiler(tmpdir):
+    profiler = AdvancedProfiler(output_filename=os.path.join(tmpdir, "profiler.txt"))
     return profiler


@@ -168,12 +168,9 @@ def test_advanced_profiler_describe(tmpdir, advanced_profiler):
     # record at least one event
     with advanced_profiler.profile("test"):
         pass
-    # log to stdout
+    # log to stdout and print to file
     advanced_profiler.describe()
-    # print to file
-    advanced_profiler.output_filename = Path(tmpdir, "profiler.txt")
-    advanced_profiler.describe()
-    data = Path(advanced_profiler.output_filename).read_text()
+    data = Path(advanced_profiler.output_fname).read_text()
     assert len(data) > 0


@@ -2,7 +2,7 @@ import pytest

 import tests.base.utils as tutils
 from pytorch_lightning import Trainer
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import (
     TestModelBase,
     LightningTestModel,
@@ -14,7 +14,7 @@ from pytorch_lightning.callbacks import (
 )
 from pytorch_lightning.core.lightning import load_hparams_from_tags_csv
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
-from pytorch_lightning.utilities.debugging import MisconfigurationException
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import (
     TestModelBase,
     DictHparamsModel,