170 lines
6.0 KiB
Python
170 lines
6.0 KiB
Python
# Copyright The PyTorch Lightning team.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
import gc
|
|
import os
|
|
import time
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from pytorch_lightning import LightningModule, seed_everything, Trainer
|
|
from tests.helpers.advanced_models import ParityModuleCIFAR, ParityModuleMNIST, ParityModuleRNN
|
|
|
|
_EXTEND_BENCHMARKS = os.getenv("PL_RUNNING_BENCHMARKS", "0") == "1"
|
|
_SHORT_BENCHMARKS = not _EXTEND_BENCHMARKS
|
|
_MARK_SHORT_BM = pytest.mark.skipif(_SHORT_BENCHMARKS, reason="Only run during Benchmarking")
|
|
|
|
|
|
def assert_parity_relative(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.1):
|
|
# assert speeds
|
|
diffs = np.asarray(pl_values) - np.mean(pt_values)
|
|
# norm by vanilla time
|
|
diffs = diffs / norm_by
|
|
# relative to mean reference value
|
|
diffs = diffs / np.mean(pt_values)
|
|
assert np.mean(diffs) < max_diff, f"Lightning diff {diffs} was worse than vanilla PT (threshold {max_diff})"
|
|
|
|
|
|
def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.55):
|
|
# assert speeds
|
|
diffs = np.asarray(pl_values) - np.mean(pt_values)
|
|
# norm by event count
|
|
diffs = diffs / norm_by
|
|
assert np.mean(diffs) < max_diff, f"Lightning {diffs} was worse than vanilla PT (threshold {max_diff})"
|
|
|
|
|
|
# ParityModuleMNIST runs with num_workers=1
|
|
@pytest.mark.parametrize(
|
|
"cls_model,max_diff_speed,max_diff_memory,num_epochs,num_runs",
|
|
[
|
|
(ParityModuleRNN, 0.05, 0.001, 4, 3),
|
|
(ParityModuleMNIST, 0.3, 0.001, 4, 3), # todo: lower this thr
|
|
pytest.param(ParityModuleCIFAR, 4.0, 0.0002, 2, 2, marks=_MARK_SHORT_BM),
|
|
],
|
|
)
|
|
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
|
def test_pytorch_parity(
|
|
tmpdir, cls_model: LightningModule, max_diff_speed: float, max_diff_memory: float, num_epochs: int, num_runs: int
|
|
):
|
|
"""Verify that the same pytorch and lightning models achieve the same results."""
|
|
lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs)
|
|
vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs)
|
|
|
|
# make sure the losses match exactly to 5 decimal places
|
|
print(f"Losses are for... \n vanilla: {vanilla['losses']} \n lightning: {lightning['losses']}")
|
|
for pl_out, pt_out in zip(lightning["losses"], vanilla["losses"]):
|
|
np.testing.assert_almost_equal(pl_out, pt_out, 5)
|
|
|
|
# drop the first run for initialize dataset (download & filter)
|
|
assert_parity_absolute(
|
|
lightning["durations"][1:], vanilla["durations"][1:], norm_by=num_epochs, max_diff=max_diff_speed
|
|
)
|
|
|
|
assert_parity_relative(lightning["memory"], vanilla["memory"], max_diff=max_diff_memory)
|
|
|
|
|
|
def _hook_memory():
|
|
if torch.cuda.is_available():
|
|
torch.cuda.synchronize()
|
|
used_memory = torch.cuda.max_memory_allocated()
|
|
else:
|
|
used_memory = np.nan
|
|
return used_memory
|
|
|
|
|
|
def measure_loops(cls_model, kind, num_runs=10, num_epochs=10):
|
|
"""Returns an array with the last loss from each epoch for each run."""
|
|
hist_losses = []
|
|
hist_durations = []
|
|
hist_memory = []
|
|
|
|
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
|
torch.backends.cudnn.deterministic = True
|
|
for i in tqdm(range(num_runs), desc=f"{kind} with {cls_model.__name__}"):
|
|
gc.collect()
|
|
if device_type == "cuda":
|
|
torch.cuda.empty_cache()
|
|
torch.cuda.reset_accumulated_memory_stats()
|
|
torch.cuda.reset_peak_memory_stats()
|
|
time.sleep(1)
|
|
|
|
time_start = time.perf_counter()
|
|
|
|
_loop = lightning_loop if kind == "PT Lightning" else vanilla_loop
|
|
final_loss, used_memory = _loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)
|
|
|
|
time_end = time.perf_counter()
|
|
|
|
hist_losses.append(final_loss)
|
|
hist_durations.append(time_end - time_start)
|
|
hist_memory.append(used_memory)
|
|
|
|
return {"losses": hist_losses, "durations": hist_durations, "memory": hist_memory}
|
|
|
|
|
|
def vanilla_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
|
|
device = torch.device(device_type)
|
|
# set seed
|
|
seed_everything(idx)
|
|
|
|
# init model parts
|
|
model = cls_model()
|
|
dl = model.train_dataloader()
|
|
optimizer = model.configure_optimizers()
|
|
|
|
# model to GPU
|
|
model = model.to(device)
|
|
|
|
epoch_losses = []
|
|
# as the first run is skipped, no need to run it long
|
|
for epoch in range(num_epochs if idx > 0 else 1):
|
|
|
|
# run through full training set
|
|
for j, batch in enumerate(dl):
|
|
batch = [x.to(device) for x in batch]
|
|
loss_dict = model.training_step(batch, j)
|
|
loss = loss_dict["loss"]
|
|
loss.backward()
|
|
optimizer.step()
|
|
optimizer.zero_grad()
|
|
|
|
# track last epoch loss
|
|
epoch_losses.append(loss.item())
|
|
|
|
return epoch_losses[-1], _hook_memory()
|
|
|
|
|
|
def lightning_loop(cls_model, idx, device_type: str = "cuda", num_epochs=10):
|
|
seed_everything(idx)
|
|
torch.backends.cudnn.deterministic = True
|
|
|
|
model = cls_model()
|
|
# init model parts
|
|
trainer = Trainer(
|
|
# as the first run is skipped, no need to run it long
|
|
max_epochs=num_epochs if idx > 0 else 1,
|
|
enable_progress_bar=False,
|
|
enable_model_summary=False,
|
|
enable_checkpointing=False,
|
|
accelerator="gpu" if device_type == "cuda" else "cpu",
|
|
devices=1,
|
|
logger=False,
|
|
replace_sampler_ddp=False,
|
|
)
|
|
trainer.fit(model)
|
|
|
|
return trainer.fit_loop.running_loss.last().item(), _hook_memory()
|