36 lines
1.1 KiB
Python
36 lines
1.1 KiB
Python
import gc
|
|
import time
|
|
from typing import Callable
|
|
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
|
|
def measure_loops(cls_model, kind: str, loop: Callable, num_runs: int = 10, num_epochs: int = 10) -> dict:
    """Run ``loop`` several times and record loss, duration and memory for each run.

    Args:
        cls_model: Model class handed through to ``loop``; its ``__name__`` is
            used in the progress-bar description.
        kind: Free-form label shown in the progress-bar description.
        loop: Callable invoked as
            ``loop(cls_model, idx=i, device_type=..., num_epochs=...)``.
            It must return a ``(final_loss, used_memory)`` pair.
        num_runs: Number of runs to execute and measure.
        num_epochs: Forwarded unchanged to ``loop``.

    Returns:
        A dict with the keys ``"losses"``, ``"durations"`` and ``"memory"``.
        Each value is a list holding one entry per run, in run order.
        Durations are ``time.perf_counter`` differences (seconds).

    Note:
        ``torch.backends.cudnn.deterministic`` is set to ``True`` process-wide
        and is not restored afterwards.
    """
    hist_losses = []
    hist_durations = []
    hist_memory = []

    use_cuda = torch.cuda.is_available()
    device_type = "cuda" if use_cuda else "cpu"
    # Process-wide switch; intentionally left enabled after this function returns.
    torch.backends.cudnn.deterministic = True

    for i in tqdm(range(num_runs), desc=f"{kind} with {cls_model.__name__}"):
        # Start every run from a clean slate: drop Python garbage and, on GPU,
        # release cached blocks and reset the allocator statistics.
        gc.collect()
        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_accumulated_memory_stats()
            torch.cuda.reset_peak_memory_stats()
        # NOTE(review): pause between runs, presumably to let the system settle.
        # Assumed to apply on every device, not only CUDA - confirm.
        time.sleep(1)

        # CUDA kernels are launched asynchronously. Drain any queued work before
        # starting the clock so it is not attributed to this run.
        if use_cuda:
            torch.cuda.synchronize()
        time_start = time.perf_counter()

        final_loss, used_memory = loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)

        # Wait for outstanding GPU work so the duration covers the whole run,
        # not just the time needed to enqueue it.
        if use_cuda:
            torch.cuda.synchronize()
        time_end = time.perf_counter()

        hist_losses.append(final_loss)
        hist_durations.append(time_end - time_start)
        hist_memory.append(used_memory)

    return {"losses": hist_losses, "durations": hist_durations, "memory": hist_memory}
|