lightning/tests/parity_pytorch/measure.py

36 lines
1.1 KiB
Python

import gc
import time
from typing import Callable
import torch
from tqdm import tqdm
def measure_loops(cls_model, kind: str, loop: Callable, num_runs: int = 10, num_epochs: int = 10):
"""Returns an array with the last loss from each epoch for each run."""
hist_losses = []
hist_durations = []
hist_memory = []
device_type = "cuda" if torch.cuda.is_available() else "cpu"
torch.backends.cudnn.deterministic = True
for i in tqdm(range(num_runs), desc=f"{kind} with {cls_model.__name__}"):
gc.collect()
if device_type == "cuda":
torch.cuda.empty_cache()
torch.cuda.reset_accumulated_memory_stats()
torch.cuda.reset_peak_memory_stats()
time.sleep(1)
time_start = time.perf_counter()
final_loss, used_memory = loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)
time_end = time.perf_counter()
hist_losses.append(final_loss)
hist_durations.append(time_end - time_start)
hist_memory.append(used_memory)
return {"losses": hist_losses, "durations": hist_durations, "memory": hist_memory}