better simple profiler

This commit is contained in:
tchaton 2020-11-27 16:34:51 +00:00
parent dee968f20b
commit bf573607f8
1 changed files with 37 additions and 9 deletions

View File

@ -26,6 +26,7 @@ from typing import Optional, Union
import fsspec
import numpy as np
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.cloud_io import get_filesystem
@ -121,7 +122,7 @@ class SimpleProfiler(BaseProfiler):
the mean duration of each action and the total time spent over the entire training run.
"""
def __init__(self, output_filename: Optional[str] = None):
def __init__(self, output_filename: Optional[str] = None, extended=True):
"""
Args:
output_filename: optionally save profile results to file instead of printing
@ -129,6 +130,7 @@ class SimpleProfiler(BaseProfiler):
"""
self.current_actions = {}
self.recorded_durations = defaultdict(list)
self.extended = extended
self.output_fname = output_filename
self.output_file = None
@ -137,6 +139,7 @@ class SimpleProfiler(BaseProfiler):
self.output_file = fs.open(self.output_fname, "w")
streaming_out = [self.output_file.write] if self.output_file else [log.info]
self.start_time = time.monotonic()
super().__init__(output_streams=streaming_out)
def start(self, action_name: str) -> None:
@ -156,18 +159,43 @@ class SimpleProfiler(BaseProfiler):
duration = end_time - start_time
self.recorded_durations[action_name].append(duration)
def make_report(self):
total_duration = time.monotonic() - self.start_time
report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()]
report.sort(key=lambda x: x[2], reverse=True)
return report, total_duration
def summary(self) -> str:
output_string = "\n\nProfiler Report\n"
def log_row(action, mean, total):
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
if self.extended:
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"{os.linesep}{'-' * 65}"
for action, durations in self.recorded_durations.items():
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
)
max_key = np.max([len(k) for k in self.recorded_durations.keys()])
def log_row(action, mean, num_calls, total, per):
return f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t| {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
output_string_len = len(output_string)
output_string += f"{os.linesep}{'-' * output_string_len}"
report, total_duration = self.make_report()
output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
output_string += f"{os.linesep}{'-' * output_string_len}"
for action, durations, duration_per in report:
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{len(durations):}", f"{np.sum(durations):.5}", f"{duration_per:.5}"
)
else:
def log_row(action, mean, total):
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
output_string += f"{os.linesep}{'-' * 65}"
for action, durations in self.recorded_durations.items():
output_string += log_row(
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
)
output_string += os.linesep
return output_string