From d96df75d6a1f3cc6c26168405f4780e30c10de6e Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Thu, 4 Jun 2020 11:20:12 -0400
Subject: [PATCH] testing new speed (#1587)

* fixed new amp bugs
* fixed new amp bugs
* fixed new amp bugs
* try exit
* larger dataset
* full mnist
* full mnist
* trainer
* assert
* .05
* .10, #4
* #5
* #5
* #5
* refactor
* abs diff
* speed
* speed
* speed
* speed

Co-authored-by: J. Borovec
Co-authored-by: Jirka
---
 benchmarks/parity_modules.py                  |  77 +++++++++
 ...{test_trainer_parity.py => test_parity.py} |  85 +++-------
 benchmarks/test_rnn_parity.py                 | 153 ------------------
 tests/base/utils.py                           |  29 ++--
 4 files changed, 118 insertions(+), 226 deletions(-)
 create mode 100644 benchmarks/parity_modules.py
 rename benchmarks/{test_trainer_parity.py => test_parity.py} (53%)
 delete mode 100644 benchmarks/test_rnn_parity.py

diff --git a/benchmarks/parity_modules.py b/benchmarks/parity_modules.py
new file mode 100644
index 0000000000..344debe46a
--- /dev/null
+++ b/benchmarks/parity_modules.py
@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+
+from pytorch_lightning import LightningModule
+from tests.base.datasets import MNIST
+
+
+class AverageDataset(Dataset):
+    def __init__(self, dataset_len=300, sequence_len=100):
+        self.dataset_len = dataset_len
+        self.sequence_len = sequence_len
+        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
+        top, bottom = self.input_seq.chunk(2, -1)
+        self.output_seq = top + bottom.roll(shifts=1, dims=-1)
+
+    def __len__(self):
+        return self.dataset_len
+
+    def __getitem__(self, item):
+        return self.input_seq[item], self.output_seq[item]
+
+
+class ParityModuleRNN(LightningModule):
+    def __init__(self):
+        super().__init__()
+        self.rnn = nn.LSTM(10, 20, batch_first=True)
+        self.linear_out = nn.Linear(in_features=20, out_features=5)
+
+    def forward(self, x):
+        seq, last = self.rnn(x)
+        return self.linear_out(seq)
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.mse_loss(y_hat, y)
+        return {'loss': loss}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    def train_dataloader(self):
+        return DataLoader(AverageDataset(), batch_size=30)
+
+
+class ParityModuleMNIST(LightningModule):
+
+    def __init__(self):
+        super().__init__()
+        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
+        self.c_d1_bn = nn.BatchNorm1d(128)
+        self.c_d1_drop = nn.Dropout(0.3)
+        self.c_d2 = nn.Linear(in_features=128, out_features=10)
+
+    def forward(self, x):
+        x = x.view(x.size(0), -1)
+        x = self.c_d1(x)
+        x = torch.tanh(x)
+        x = self.c_d1_bn(x)
+        x = self.c_d1_drop(x)
+        x = self.c_d2(x)
+        return x
+
+    def training_step(self, batch, batch_nb):
+        x, y = batch
+        y_hat = self(x)
+        loss = F.cross_entropy(y_hat, y)
+        return {'loss': loss}
+
+    def configure_optimizers(self):
+        return torch.optim.Adam(self.parameters(), lr=0.02)
+
+    def train_dataloader(self):
+        return DataLoader(MNIST(train=True, download=True,),
+                          batch_size=128)
diff --git a/benchmarks/test_trainer_parity.py b/benchmarks/test_parity.py
similarity index 53%
rename from benchmarks/test_trainer_parity.py
rename to benchmarks/test_parity.py
index d97eb07c5c..186dc57dac 100644
--- a/benchmarks/test_trainer_parity.py
+++ b/benchmarks/test_parity.py
@@ -1,75 +1,38 @@
-import os
 import time
 
 import numpy as np
 import pytest
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import DataLoader
-from torchvision import transforms
+
 import tests.base.utils as tutils
-
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-from tests.base.datasets import TrialMNIST
-
-
-class ParityMNIST(LightningModule):
-
-    def __init__(self):
-        super(ParityMNIST, self).__init__()
-        self.c_d1 = nn.Linear(in_features=28 * 28, out_features=128)
-        self.c_d1_bn = nn.BatchNorm1d(128)
-        self.c_d1_drop = nn.Dropout(0.3)
-        self.c_d2 = nn.Linear(in_features=128, out_features=10)
-
-    def forward(self, x):
-        x = x.view(x.size(0), -1)
-        x = self.c_d1(x)
-        x = torch.tanh(x)
-        x = self.c_d1_bn(x)
-        x = self.c_d1_drop(x)
-        x = self.c_d2(x)
-        return x
-
-    def training_step(self, batch, batch_nb):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.cross_entropy(y_hat, y)
-        return {'loss': loss}
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=0.02)
-
-    def train_dataloader(self):
-        return DataLoader(TrialMNIST(train=True,
-                                     download=True,
-                                     num_samples=500,
-                                     digits=list(range(5))),
-                          batch_size=128)
+from benchmarks.parity_modules import ParityModuleRNN, ParityModuleMNIST
+from pytorch_lightning import Trainer, seed_everything
 
 
+@pytest.mark.parametrize('cls_model,max_diff', [
+    (ParityModuleRNN, 0.05),
+    (ParityModuleMNIST, 0.5)
+])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_pytorch_parity(tmpdir):
+def test_pytorch_parity(tmpdir, cls_model, max_diff):
     """
     Verify that the same pytorch and lightning models achieve the same results
-    :param tmpdir:
-    :return:
     """
-    num_epochs = 2
+    num_epochs = 4
     num_rums = 3
 
-    lightning_outs, pl_times = lightning_loop(ParityMNIST, num_rums, num_epochs)
-    manual_outs, pt_times = vanilla_loop(ParityMNIST, num_rums, num_epochs)
+    lightning_outs, pl_times = lightning_loop(cls_model, num_rums, num_epochs)
+    manual_outs, pt_times = vanilla_loop(cls_model, num_rums, num_epochs)
     # make sure the losses match exactly to 5 decimal places
     for pl_out, pt_out in zip(lightning_outs, manual_outs):
         np.testing.assert_almost_equal(pl_out, pt_out, 5)
 
     # the fist run initialize dataset (download & filter)
-    tutils.assert_speed_parity(pl_times[1:], pt_times[1:], num_epochs)
+    tutils.assert_speed_parity_absolute(pl_times[1:], pt_times[1:],
+                                        nb_epochs=num_epochs, max_diff=max_diff)
 
 
-def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
+def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
     """
     Returns an array with the last loss from each epoch for each run
     """
@@ -86,7 +49,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
         seed_everything(seed)
 
         # init model parts
-        model = MODEL()
+        model = cls_model()
         dl = model.train_dataloader()
         optimizer = model.configure_optimizers()
 
@@ -94,15 +57,12 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
         model = model.to(device)
 
         epoch_losses = []
-        for epoch in range(num_epochs):
+        # as the first run is skipped, no need to run it long
+        for epoch in range(num_epochs if i > 0 else 1):
 
             # run through full training set
             for j, batch in enumerate(dl):
-                x, y = batch
-                x = x.cuda(0)
-                y = y.cuda(0)
-                batch = (x, y)
-
+                batch = [x.to(device) for x in batch]
                 loss_dict = model.training_step(batch, j)
                 loss = loss_dict['loss']
                 loss.backward()
@@ -120,7 +80,7 @@ def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
     return errors, times
 
 
-def lightning_loop(MODEL, num_runs=10, num_epochs=10):
+def lightning_loop(cls_model, num_runs=10, num_epochs=10):
     errors = []
     times = []
 
@@ -131,16 +91,19 @@ def lightning_loop(MODEL, num_runs=10, num_epochs=10):
         seed = i
         seed_everything(seed)
-        model = MODEL()
+        model = cls_model()
 
         # init model parts
         trainer = Trainer(
-            max_epochs=num_epochs,
+            # as the first run is skipped, no need to run it long
+            max_epochs=num_epochs if i > 0 else 1,
             progress_bar_refresh_rate=0,
             weights_summary=None,
             gpus=1,
             early_stop_callback=False,
             checkpoint_callback=False,
             deterministic=True,
+            logger=False,
+            replace_sampler_ddp=False,
         )
         trainer.fit(model)
 
diff --git a/benchmarks/test_rnn_parity.py b/benchmarks/test_rnn_parity.py
deleted file mode 100644
index b535a9ddb5..0000000000
--- a/benchmarks/test_rnn_parity.py
+++ /dev/null
@@ -1,153 +0,0 @@
-import time
-
-import numpy as np
-import pytest
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-import tests.base.utils as tutils
-
-from pytorch_lightning import Trainer, LightningModule, seed_everything
-
-
-class AverageDataset(Dataset):
-    def __init__(self, dataset_len=300, sequence_len=100):
-        self.dataset_len = dataset_len
-        self.sequence_len = sequence_len
-        self.input_seq = torch.randn(dataset_len, sequence_len, 10)
-        top, bottom = self.input_seq.chunk(2, -1)
-        self.output_seq = top + bottom.roll(shifts=1, dims=-1)
-
-    def __len__(self):
-        return self.dataset_len
-
-    def __getitem__(self, item):
-        return self.input_seq[item], self.output_seq[item]
-
-
-class ParityRNN(LightningModule):
-    def __init__(self):
-        super(ParityRNN, self).__init__()
-        self.rnn = nn.LSTM(10, 20, batch_first=True)
-        self.linear_out = nn.Linear(in_features=20, out_features=5)
-
-    def forward(self, x):
-        seq, last = self.rnn(x)
-        return self.linear_out(seq)
-
-    def training_step(self, batch, batch_nb):
-        x, y = batch
-        y_hat = self(x)
-        loss = F.mse_loss(y_hat, y)
-        return {'loss': loss}
-
-    def configure_optimizers(self):
-        return torch.optim.Adam(self.parameters(), lr=0.02)
-
-    def train_dataloader(self):
-        return DataLoader(AverageDataset(), batch_size=30)
-
-
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_pytorch_parity(tmpdir):
-    """
-    Verify that the same pytorch and lightning models achieve the same results
-    :param tmpdir:
-    :return:
-    """
-    num_epochs = 2
-    num_rums = 3
-
-    lightning_outs, pl_times = lightning_loop(ParityRNN, num_rums, num_epochs)
-    manual_outs, pt_times = vanilla_loop(ParityRNN, num_rums, num_epochs)
-    # make sure the losses match exactly to 5 decimal places
-    for pl_out, pt_out in zip(lightning_outs, manual_outs):
-        np.testing.assert_almost_equal(pl_out, pt_out, 8)
-
-    tutils.assert_speed_parity(pl_times, pt_times, num_epochs)
-
-
-def vanilla_loop(MODEL, num_runs=10, num_epochs=10):
-    """
-    Returns an array with the last loss from each epoch for each run
-    """
-    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
-    errors = []
-    times = []
-
-    torch.backends.cudnn.deterministic = True
-    for i in range(num_runs):
-        time_start = time.perf_counter()
-
-        # set seed
-        seed = i
-        seed_everything(seed)
-
-        # init model parts
-        model = MODEL()
-        dl = model.train_dataloader()
-        optimizer = model.configure_optimizers()
-
-        # model to GPU
-        model = model.to(device)
-
-        epoch_losses = []
-        for epoch in range(num_epochs):
-
-            # run through full training set
-            for j, batch in enumerate(dl):
-                x, y = batch
-                x = x.cuda(0)
-                y = y.cuda(0)
-                batch = (x, y)
-
-                loss_dict = model.training_step(batch, j)
-                loss = loss_dict['loss']
-                loss.backward()
-                optimizer.step()
-                optimizer.zero_grad()
-
-            # track last epoch loss
-            epoch_losses.append(loss.item())
-
-        time_end = time.perf_counter()
-        times.append(time_end - time_start)
-
-        errors.append(epoch_losses[-1])
-
-    return errors, times
-
-
-def lightning_loop(MODEL, num_runs=10, num_epochs=10):
-    errors = []
-    times = []
-
-    for i in range(num_runs):
-        time_start = time.perf_counter()
-
-        # set seed
-        seed = i
-        seed_everything(seed)
-        model = MODEL()
-
-        # init model parts
-        trainer = Trainer(
-            max_epochs=num_epochs,
-            progress_bar_refresh_rate=0,
-            weights_summary=None,
-            gpus=1,
-            early_stop_callback=False,
-            checkpoint_callback=False,
-            distributed_backend='dp',
-            deterministic=True,
-        )
-        trainer.fit(model)
-
-        final_loss = trainer.running_loss.last().item()
-        errors.append(final_loss)
-
-        time_end = time.perf_counter()
-        times.append(time_end - time_start)
-
-    return errors, times
diff --git a/tests/base/utils.py b/tests/base/utils.py
index 4cb7e4cd67..467b5cb4fc 100644
--- a/tests/base/utils.py
+++ b/tests/base/utils.py
@@ -12,20 +12,25 @@ from tests import TEMP_PATH, RANDOM_PORTS, RANDOM_SEEDS
 from tests.base.model_template import EvalModelTemplate
 
 
-def assert_speed_parity(pl_times, pt_times, num_epochs):
-
+def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1):
     # assert speeds
-    max_diff_per_epoch = 0.65
-    pl_times = np.asarray(pl_times)
-    pt_times = np.asarray(pt_times)
-    diffs = pl_times - pt_times
-    diffs = diffs / num_epochs
-
-    assert np.alltrue(diffs < max_diff_per_epoch), \
-        f"lightning was slower than PT (threshold {max_diff_per_epoch})"
+    diffs = np.asarray(pl_times) - np.asarray(pt_times)
+    # norm by vanilla time
+    diffs = diffs / np.asarray(pt_times)
+    assert np.alltrue(diffs < max_diff), \
+        f"lightning {diffs} was slower than PT (threshold {max_diff})"
 
 
-def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
+def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.6):
+    # assert speeds
+    diffs = np.asarray(pl_times) - np.asarray(pt_times)
+    # norm by the number of epochs
+    diffs = diffs / nb_epochs
+    assert np.alltrue(diffs < max_diff), \
+        f"lightning {diffs} was slower than PT (threshold {max_diff})"
+
+
+def run_model_test_without_loggers(trainer_options, model, min_acc: float = 0.50):
     reset_seed()
 
     # fit model
@@ -54,7 +59,7 @@ def run_model_test_without_loggers(trainer_options, model, min_acc=0.50):
     trainer.optimizers, trainer.lr_schedulers = pretrained_model.configure_optimizers()
 
 
-def run_model_test(trainer_options, model, on_gpu=True, version=None, with_hpc=True):
+def run_model_test(trainer_options, model, on_gpu: bool = True, version=None, with_hpc: bool = True):
     reset_seed()
 
     save_dir = trainer_options['default_root_dir']
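
A minimal usage sketch of the two speed-parity helpers this patch adds to tests/base/utils.py (illustration only: the timing values, nb_epochs, and max_diff numbers below are made up, and the import assumes the snippet runs inside the repository's test environment):

    from tests.base.utils import assert_speed_parity_absolute, assert_speed_parity_relative

    # hypothetical per-run wall-clock times in seconds, one entry per benchmark run
    pl_times = [10.2, 10.4, 10.3]   # Lightning runs
    pt_times = [10.0, 10.1, 10.0]   # vanilla PyTorch runs

    # relative check: (pl - pt) / pt must stay below max_diff (10% by default)
    assert_speed_parity_relative(pl_times, pt_times, max_diff=0.1)

    # absolute check: (pl - pt) / nb_epochs must stay below max_diff seconds per epoch
    assert_speed_parity_absolute(pl_times, pt_times, nb_epochs=4, max_diff=0.5)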