testing new speed (#1587)
* fixed new amp bugs
* fixed new amp bugs
* fixed new amp bugs
* try exit
* larger dataset
* full mnist
* full mnist
* trainer
* assert
* .05
* .10, #4
* #5
* #5
* #5
* refactor
* abs diff
* speed
* speed
* speed
* speed

Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Jirka <jirka@pytorchlightning.ai>
parent 4234992302
commit d96df75d6a
benchmarks/parity_modules.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+from pytorch_lightning import LightningModule
+from tests.base.datasets import MNIST
+
+
+class AverageDataset(Dataset):
+
+    def __init__(self, dataset_len=300, sequence_len=100):
+        self.dataset_len = dataset_len
+        self.sequence_len = sequence_len
+        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+        top, bottom = self.input_seq.chunk(2, -1)
+        self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+    def __len__(self):
+        return self.dataset_len
+
+    def __getitem__(self, item):
+        return self.input_seq[item], self.output_seq[item]
+
+
+class ParityModuleRNN(LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.rnn = nn.LSTM(10, 20, batch_first=True)
+        self.linear_out = nn.Linear(in_features=20, out_features=5)
+
+    def forward(self, x):
+        seq, last = self.rnn(x)
+        return self.linear_out(seq)
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.mse_loss(y_hat, y)
+        return {'loss': loss}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=30)
+
+
+class ParityModuleMNIST(LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
+        self.c_d1_bn = nn.BatchNorm1d(128)
+        self.c_d1_drop = nn.Dropout(0.3)
+        self.c_d2 = nn.Linear(in_features=128, out_features=10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = self.c_d1(x)
+        x = torch.tanh(x)
+        x = self.c_d1_bn(x)
+        x = self.c_d1_drop(x)
+        x = self.c_d2(x)
+        return x
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        return {'loss': loss}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    def train_dataloader(self):
+        return DataLoader(MNIST(train=True, download=True),
+                          batch_size=128)
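A minimal usage sketch for the new parity modules (not part of the commit), assuming the file above is importable as benchmarks.parity_modules, the path used by the test diff below; it runs on CPU and only checks shapes and one training step:

    # smoke-check the RNN parity module: one batch, one training step
    from benchmarks.parity_modules import ParityModuleRNN

    model = ParityModuleRNN()
    batch = next(iter(model.train_dataloader()))
    x, y = batch                     # x: (30, 100, 10), y: (30, 100, 5)
    out = model(x)                   # LSTM features projected to 5 outputs
    assert out.shape == y.shape
    loss = model.training_step(batch, 0)['loss']
    print(f"one-step MSE loss: {loss.item():.4f}")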
@@ -1,75 +1,38 @@
-import os
 import time

 import numpy as np
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torchvision import transforms

 import tests.base.utils as tutils
-
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-from tests.base.datasets import TrialMNIST
-
-
-class ParityMNIST(LightningModule):
-
-    def __init__(self):
-        super(ParityMNIST, self).__init__()
-        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
-        self.c_d1_bn = nn.BatchNorm1d(128)
-        self.c_d1_drop = nn.Dropout(0.3)
-        self.c_d2 = nn.Linear(in_features=128, out_features=10)
-
-    def forward(self, x):
-        x = x.view(x.size(0), -1)
-        x = self.c_d1(x)
-        x = torch.tanh(x)
-        x = self.c_d1_bn(x)
-        x = self.c_d1_drop(x)
-        x = self.c_d2(x)
-        return x
-
-    def training_step(self, batch, batch_nb):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        return {'loss': loss}
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=0.02)
-
-    def train_dataloader(self):
-        return DataLoader(TrialMNIST(train=True,
-                                     download=True,
-                                     num_samples=500,
-                                     digits=list(range(5))),
-                          batch_size=128)
+from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST
+from pytorch_lightning import Trainer, seed_everything


+@pytest.mark.parametrize('cls_model,max_diff', [
+    (ParityModuleRNN, 0.05),
+    (ParityModuleMNIST, 0.5)
+])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_pytorch_parity(tmpdir):
+def test_pytorch_parity(tmpdir, cls_model, max_diff):
     """
     Verify that the same PyTorch and Lightning models achieve the same results.
     """
-    num_epochs = 2
+    num_epochs = 4
     num_runs = 3
-    lightning_outs, pl_times = lightning_loop(ParityMNIST, num_runs, num_epochs)
-    manual_outs, pt_times = vanilla_loop(ParityMNIST, num_runs, num_epochs)
+    lightning_outs, pl_times = lightning_loop(cls_model, num_runs, num_epochs)
+    manual_outs, pt_times = vanilla_loop(cls_model, num_runs, num_epochs)

     # make sure the losses match exactly to 5 decimal places
     for pl_out, pt_out in zip(lightning_outs, manual_outs):
         np.testing.assert_almost_equal(pl_out, pt_out, 5)

-    tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)
+    # the first run initializes the dataset (download & filter), so skip it
+    tutils.assert_speed_parity_absolute(pl_times[1:], pt_times[1:],
+                                        nb_epochs=num_epochs, max_diff=max_diff)


-def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
+def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
     """
     Returns an array with the last loss from each epoch for each run
     """
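The loss-parity check above is just numpy's almost-equal applied to the per-run final losses. A toy sketch of the same tolerance with made-up numbers (real values come from GPU training):

    import numpy as np

    # hypothetical final losses per run; stand-ins for the outputs of
    # lightning_loop() and vanilla_loop()
    lightning_outs = [0.412345678, 0.398765432]
    manual_outs = [0.412345001, 0.398765999]

    # the same tolerance as the test: agreement to 5 decimal places
    for pl_out, pt_out in zip(lightning_outs, manual_outs):
        np.testing.assert_almost_equal(pl_out, pt_out, 5)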
@@ -86,7 +49,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
         seed_everything(seed)

         # init model parts
-        model = MODEL()
+        model = cls_model()
         dl = model.train_dataloader()
         optimizer = model.configure_optimizers()

@@ -94,15 +57,12 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
         model = model.to(device)

         epoch_losses = []
-        for epoch in range(num_epochs):
+        # as the first run is skipped, no need to run it long
+        for epoch in range(num_epochs if i > 0 else 1):

             # run through full training set
             for j, batch in enumerate(dl):
-                x, y = batch
-                x = x.cuda(0)
-                y = y.cuda(0)
-                batch = (x, y)
-
+                batch = [x.to(device) for x in batch]
                 loss_dict = model.training_step(batch, j)
                 loss = loss_dict['loss']
                 loss.backward()

@@ -120,7 +80,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     return errors, times


-def lightning_loop(MODEL, num_runs=10, num_epochs=10):
+def lightning_loop(cls_model, num_runs=10, num_epochs=10):
     errors = []
     times = []

@@ -131,16 +91,19 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
         seed = i
         seed_everything(seed)

-        model = MODEL()
+        model = cls_model()
         # init model parts
         trainer = Trainer(
-            max_epochs=num_epochs,
+            # as the first run is skipped, no need to run it long
+            max_epochs=num_epochs if i > 0 else 1,
             progress_bar_refresh_rate=0,
             weights_summary=None,
             gpus=1,
             early_stop_callback=False,
             checkpoint_callback=False,
             deterministic=True,
+            logger=False,
+            replace_sampler_ddp=False,
         )
         trainer.fit(model)
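Both loops share the same measurement pattern: seed, build the model, train, and wrap the whole run in time.perf_counter. A condensed CPU-only sketch of that pattern with a toy model (not the benchmark itself):

    import time
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    def timed_runs(num_runs=3, num_steps=100):
        errors, times = [], []
        for i in range(num_runs):
            torch.manual_seed(i)          # plays the role of seed_everything(seed)
            model = nn.Linear(10, 1)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
            x, y = torch.randn(64, 10), torch.randn(64, 1)

            time_start = time.perf_counter()
            for _ in range(num_steps):    # stand-in for the epoch/batch loops
                loss = F.mse_loss(model(x), y)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
            times.append(time.perf_counter() - time_start)
            errors.append(loss.item())    # track the last loss, as the loops above do
        return errors, times

    errors, times = timed_runs()
    print(errors, times)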
@@ -1,153 +0,0 @@
-import time
-
-import numpy as np
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-
-import tests.base.utils as tutils
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-
-
-class AverageDataset(Dataset):
-    def __init__(self, dataset_len=300, sequence_len=100):
-        self.dataset_len = dataset_len
-        self.sequence_len = sequence_len
-        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
-        top, bottom = self.input_seq.chunk(2, -1)
-        self.output_seq = top + bottom.roll(shifts=1, dims=-1)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, item):
-        return self.input_seq[item], self.output_seq[item]
-
-
-class ParityRNN(LightningModule):
-    def __init__(self):
-        super(ParityRNN, self).__init__()
-        self.rnn = nn.LSTM(10, 20, batch_first=True)
-        self.linear_out = nn.Linear(in_features=20, out_features=5)
-
-    def forward(self, x):
-        seq, last = self.rnn(x)
-        return self.linear_out(seq)
-
-    def training_step(self, batch, batch_nb):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.mse_loss(y_hat, y)
-        return {'loss': loss}
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=0.02)
-
-    def train_dataloader(self):
-        return DataLoader(AverageDataset(), batch_size=30)
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_pytorch_parity(tmpdir):
-    """
-    Verify that the same pytorch and lightning models achieve the same results
-    :param tmpdir:
-    :return:
-    """
-    num_epochs = 2
-    num_rums = 3
-
-    lightning_outs, pl_times = lightning_loop(ParityRNN, num_rums, num_epochs)
-    manual_outs, pt_times = vanilla_loop(ParityRNN, num_rums, num_epochs)
-    # make sure the losses match exactly to 5 decimal places
-    for pl_out, pt_out in zip(lightning_outs, manual_outs):
-        np.testing.assert_almost_equal(pl_out, pt_out, 8)
-
-    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
-
-
-def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
-    """
-    Returns an array with the last loss from each epoch for each run
-    """
-    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
-    errors = []
-    times = []
-
-    torch.backends.cudnn.deterministic = True
-    for i in range(num_runs):
-        time_start = time.perf_counter()
-
-        # set seed
-        seed = i
-        seed_everything(seed)
-
-        # init model parts
-        model = MODEL()
-        dl = model.train_dataloader()
-        optimizer = model.configure_optimizers()
-
-        # model to GPU
-        model = model.to(device)
-
-        epoch_losses = []
-        for epoch in range(num_epochs):
-
-            # run through full training set
-            for j, batch in enumerate(dl):
-                x, y = batch
-                x = x.cuda(0)
-                y = y.cuda(0)
-                batch = (x, y)
-
-                loss_dict = model.training_step(batch, j)
-                loss = loss_dict['loss']
-                loss.backward()
-                optimizer.step()
-                optimizer.zero_grad()
-
-            # track last epoch loss
-            epoch_losses.append(loss.item())
-
-        time_end = time.perf_counter()
-        times.append(time_end - time_start)
-
-        errors.append(epoch_losses[-1])
-
-    return errors, times
-
-
-def lightning_loop(MODEL, num_runs=10, num_epochs=10):
-    errors = []
-    times = []
-
-    for i in range(num_runs):
-        time_start = time.perf_counter()
-
-        # set seed
-        seed = i
-        seed_everything(seed)
-        model = MODEL()
-
-        # init model parts
-        trainer = Trainer(
-            max_epochs=num_epochs,
-            progress_bar_refresh_rate=0,
-            weights_summary=None,
-            gpus=1,
-            early_stop_callback=False,
-            checkpoint_callback=False,
-            distributed_backend='dp',
-            deterministic=True,
-        )
-        trainer.fit(model)
-
-        final_loss = trainer.running_loss.last().item()
-        errors.append(final_loss)
-
-        time_end = time.perf_counter()
-        times.append(time_end - time_start)
-
-    return errors, times
@@ -12,20 +12,25 @@ from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
 from tests.base.model_template import EvalModelTemplate


-def assert_speed_parity(pl_times, pt_times, num_epochs):
+def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1):
     # assert speeds
-    max_diff_per_epoch = 0.65
-    pl_times = np.asarray(pl_times)
-    pt_times = np.asarray(pt_times)
-    diffs = pl_times - pt_times
-    diffs = diffs / num_epochs
-
-    assert np.alltrue(diffs < max_diff_per_epoch), \
-        f"lightning was slower than PT (threshold {max_diff_per_epoch})"
+    diffs = np.asarray(pl_times) - np.asarray(pt_times)
+    # norm by vanilla time
+    diffs = diffs / np.asarray(pt_times)
+    assert np.alltrue(diffs < max_diff), \
+        f"lightning {diffs} was slower than PT (threshold {max_diff})"
+
+
+def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.6):
+    # assert speeds
+    diffs = np.asarray(pl_times) - np.asarray(pt_times)
+    # norm by the number of epochs
+    diffs = diffs / nb_epochs
+    assert np.alltrue(diffs < max_diff), \
+        f"lightning {diffs} was slower than PT (threshold {max_diff})"


-def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
+def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
     reset_seed()

     # fit model
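To make the two new assertions concrete, a worked example with made-up timings: the relative check normalizes the Lightning-vs-PyTorch gap by the vanilla time, the absolute check by the number of epochs.

    import numpy as np

    # hypothetical warm-run wall times in seconds
    pl_times = np.asarray([4.2, 4.3])   # Lightning
    pt_times = np.asarray([4.0, 4.0])   # vanilla PyTorch

    # relative: (4.2 - 4.0) / 4.0 = 0.050, (4.3 - 4.0) / 4.0 = 0.075 -> both < 0.1
    rel_diffs = (pl_times - pt_times) / pt_times

    # absolute: per-epoch gap over nb_epochs=4 -> 0.050 s and 0.075 s -> both < 0.6
    abs_diffs = (pl_times - pt_times) / 4

    assert np.all(rel_diffs < 0.1) and np.all(abs_diffs < 0.6)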
@@ -54,7 +59,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
     trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()


-def run_model_test(trainer_options, model, on_gpu=True, version=None, with_hpc=True):
+def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
     reset_seed()
     save_dir = trainer_options['default_root_dir']