better simple profiler
This commit is contained in:
parent
dee968f20b
commit
bf573607f8
|
@ -26,6 +26,7 @@ from typing import Optional, Union
|
|||
|
||||
import fsspec
|
||||
import numpy as np
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
|
||||
|
@ -121,7 +122,7 @@ class SimpleProfiler(BaseProfiler):
|
|||
the mean duration of each action and the total time spent over the entire training run.
|
||||
"""
|
||||
|
||||
def __init__(self, output_filename: Optional[str] = None):
|
||||
def __init__(self, output_filename: Optional[str] = None, extended=True):
|
||||
"""
|
||||
Args:
|
||||
output_filename: optionally save profile results to file instead of printing
|
||||
|
@ -129,6 +130,7 @@ class SimpleProfiler(BaseProfiler):
|
|||
"""
|
||||
self.current_actions = {}
|
||||
self.recorded_durations = defaultdict(list)
|
||||
self.extended = extended
|
||||
|
||||
self.output_fname = output_filename
|
||||
self.output_file = None
|
||||
|
@ -137,6 +139,7 @@ class SimpleProfiler(BaseProfiler):
|
|||
self.output_file = fs.open(self.output_fname, "w")
|
||||
|
||||
streaming_out = [self.output_file.write] if self.output_file else [log.info]
|
||||
self.start_time = time.monotonic()
|
||||
super().__init__(output_streams=streaming_out)
|
||||
|
||||
def start(self, action_name: str) -> None:
|
||||
|
@ -156,18 +159,43 @@ class SimpleProfiler(BaseProfiler):
|
|||
duration = end_time - start_time
|
||||
self.recorded_durations[action_name].append(duration)
|
||||
|
||||
def make_report(self):
|
||||
total_duration = time.monotonic() - self.start_time
|
||||
report = [[a, d, 100. * np.sum(d) / total_duration] for a, d in self.recorded_durations.items()]
|
||||
report.sort(key=lambda x: x[2], reverse=True)
|
||||
return report, total_duration
|
||||
|
||||
def summary(self) -> str:
|
||||
output_string = "\n\nProfiler Report\n"
|
||||
|
||||
def log_row(action, mean, total):
|
||||
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
|
||||
if self.extended:
|
||||
|
||||
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
|
||||
output_string += f"{os.linesep}{'-' * 65}"
|
||||
for action, durations in self.recorded_durations.items():
|
||||
output_string += log_row(
|
||||
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
|
||||
)
|
||||
max_key = np.max([len(k) for k in self.recorded_durations.keys()])
|
||||
|
||||
def log_row(action, mean, num_calls, total, per):
|
||||
return f"{os.linesep}{action:<{max_key}s}\t| {mean:<15}\t| {num_calls:<15}\t| {total:<15}\t| {per:<15}\t|"
|
||||
|
||||
output_string += log_row("Action", "Mean duration (s)", "Num calls", "Total time (s)", "Percentage %")
|
||||
output_string_len = len(output_string)
|
||||
output_string += f"{os.linesep}{'-' * output_string_len}"
|
||||
report, total_duration = self.make_report()
|
||||
output_string += log_row("Total", "-", "_", f"{total_duration:.5}", "100 %")
|
||||
output_string += f"{os.linesep}{'-' * output_string_len}"
|
||||
for action, durations, duration_per in report:
|
||||
output_string += log_row(
|
||||
action, f"{np.mean(durations):.5}", f"{len(durations):}", f"{np.sum(durations):.5}", f"{duration_per:.5}"
|
||||
)
|
||||
else:
|
||||
def log_row(action, mean, total):
|
||||
return f"{os.linesep}{action:<20s}\t| {mean:<15}\t| {total:<15}"
|
||||
|
||||
output_string += log_row("Action", "Mean duration (s)", "Total time (s)")
|
||||
output_string += f"{os.linesep}{'-' * 65}"
|
||||
|
||||
for action, durations in self.recorded_durations.items():
|
||||
output_string += log_row(
|
||||
action, f"{np.mean(durations):.5}", f"{np.sum(durations):.5}"
|
||||
)
|
||||
output_string += os.linesep
|
||||
return output_string
|
||||
|
||||
|
|
Loading…
Reference in New Issue