lightning/tests/parity_pytorch/measure.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

36 lines
1.1 KiB
Python
Raw Normal View History

import gc
import time
from typing import Callable
import torch
from tqdm import tqdm
def measure_loops(cls_model, kind: str, loop: Callable, num_runs: int = 10, num_epochs: int = 10) -> dict:
    """Run ``loop`` repeatedly and record the loss, duration and memory of every run.

    Args:
        cls_model: Model class forwarded to ``loop``; its ``__name__`` labels the progress bar.
        kind: Free-form label shown in the progress-bar description.
        loop: Callable invoked as ``loop(cls_model, idx=..., device_type=..., num_epochs=...)``.
            It must return a ``(final_loss, used_memory)`` pair.
        num_runs: How many times ``loop`` is executed.
        num_epochs: Forwarded unchanged to ``loop``.

    Returns:
        A dict with the keys ``"losses"``, ``"durations"`` and ``"memory"``. Each value is a
        list holding one entry per run; durations are ``time.perf_counter`` differences,
        i.e. seconds.
    """
    hist_losses = []
    hist_durations = []
    hist_memory = []

    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE: this is a process-wide flag and it is left enabled after the function returns.
    torch.backends.cudnn.deterministic = True

    for i in tqdm(range(num_runs), desc=f"{kind} with {cls_model.__name__}"):
        # Start every run from a clean slate so the memory readings are per-run.
        gc.collect()
        if device_type == "cuda":
            torch.cuda.empty_cache()
            torch.cuda.reset_accumulated_memory_stats()
            torch.cuda.reset_peak_memory_stats()
        # Pause between runs (presumably to let the system settle after the cleanup above).
        time.sleep(1)

        time_start = time.perf_counter()
        final_loss, used_memory = loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)
        if device_type == "cuda":
            # CUDA kernels are launched asynchronously; wait for all queued work to finish so
            # the measured duration covers the whole run and not only the kernel launches.
            torch.cuda.synchronize()
        time_end = time.perf_counter()

        hist_losses.append(final_loss)
        hist_durations.append(time_end - time_start)
        hist_memory.append(used_memory)

    return {"losses": hist_losses, "durations": hist_durations, "memory": hist_memory}