diff --git a/benchmarks/generate_comparison.py b/benchmarks/generate_comparison.py
index 69eb47cb7e..6b5a0680a6 100644
--- a/benchmarks/generate_comparison.py
+++ b/benchmarks/generate_comparison.py
@@ -16,7 +16,7 @@ import os
 import matplotlib.pylab as plt
 import pandas as pd
 
-from benchmarks.test_basic_parity import lightning_loop, vanilla_loop
+from benchmarks.test_basic_parity import measure_loops
 from tests.base.models import ParityModuleMNIST, ParityModuleRNN
 
 NUM_EPOCHS = 20
@@ -34,8 +34,9 @@ def _main():
         if os.path.isfile(path_csv):
             df_time = pd.read_csv(path_csv, index_col=0)
         else:
-            vanilla = vanilla_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
-            lightning = lightning_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
+            # todo: kind="Vanilla PT" -> use_lightning=False
+            vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
+            lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
 
             df_time = pd.DataFrame({'vanilla PT': vanilla['durations'][1:], 'PT Lightning': lightning['durations'][1:]})
             df_time /= NUM_RUNS
diff --git a/benchmarks/test_basic_parity.py b/benchmarks/test_basic_parity.py
index c85984b092..ce3d831f09 100644
--- a/benchmarks/test_basic_parity.py
+++ b/benchmarks/test_basic_parity.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import gc
 import time
 
 import numpy as np
@@ -19,118 +19,156 @@ import pytest
 import torch
 from tqdm import tqdm
 
-from pytorch_lightning import seed_everything, Trainer
-import tests.base.develop_utils as tutils
+from pytorch_lightning import LightningModule, seed_everything, Trainer
 from tests.base.models import ParityModuleMNIST, ParityModuleRNN
 
 
+def assert_parity_relative(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.1):
+    # assert speeds
+    diffs = np.asarray(pl_values) - np.mean(pt_values)
+    # norm by vanilla time
+    diffs = diffs / norm_by
+    # relative to mean reference value
+    diffs = diffs / np.mean(pt_values)
+    assert np.mean(diffs) < max_diff, f"Lightning diff {diffs} was worse than vanilla PT (threshold {max_diff})"
+
+
+def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.55):
+    # assert speeds
+    diffs = np.asarray(pl_values) - np.mean(pt_values)
+    # norm by event count
+    diffs = diffs / norm_by
+    assert np.mean(diffs) < max_diff, f"Lightning {diffs} was worse than vanilla PT (threshold {max_diff})"
+
+
 # ParityModuleMNIST runs with num_workers=1
-@pytest.mark.parametrize('cls_model,max_diff', [
-    (ParityModuleRNN, 0.05),
-    (ParityModuleMNIST, 0.25),  # todo: lower this thr
+@pytest.mark.parametrize('cls_model,max_diff_speed,max_diff_memory', [
+    (ParityModuleRNN, 0.05, 0.0),
+    (ParityModuleMNIST, 0.25, 0.0),  # todo: lower this thr
 ])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
-def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
+def test_pytorch_parity(
+    tmpdir,
+    cls_model: LightningModule,
+    max_diff_speed: float,
+    max_diff_memory: float,
+    num_epochs: int = 4,
+    num_runs: int = 3,
+):
     """
     Verify that the same pytorch and lightning models achieve the same results
     """
-    lightning = lightning_loop(cls_model, num_runs, num_epochs)
-    vanilla = vanilla_loop(cls_model, num_runs, num_epochs)
+    lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs)
+    vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs)
 
     # make sure the losses match exactly to 5 decimal places
+    print(f"Losses are for... \n vanilla: {vanilla['losses']} \n lightning: {lightning['losses']}")
     for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
         np.testing.assert_almost_equal(pl_out, pt_out, 5)
 
-    # the fist run initialize dataset (download & filter)
-    tutils.assert_speed_parity_absolute(
-        lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
+    # drop the first run, which initializes the dataset (download & filter)
+    assert_parity_absolute(
+        lightning['durations'][1:], vanilla['durations'][1:], norm_by=num_epochs, max_diff=max_diff_speed
     )
 
+    assert_parity_relative(lightning['memory'], vanilla['memory'], max_diff=max_diff_memory)
 
-def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
+
+def _hook_memory():
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+        used_memory = torch.cuda.max_memory_allocated()
+    else:
+        used_memory = np.nan
+    return used_memory
+
+
+def measure_loops(cls_model, kind, num_runs=10, num_epochs=10):
     """
     Returns an array with the last loss from each epoch for each run
     """
     hist_losses = []
     hist_durations = []
+    hist_memory = []
 
-    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
+    device_type = "cuda" if torch.cuda.is_available() else "cpu"
     torch.backends.cudnn.deterministic = True
-    for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
+    for i in tqdm(range(num_runs), desc=f'{kind} with {cls_model.__name__}'):
+        gc.collect()
+        if device_type == 'cuda':
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_cached()
+            torch.cuda.reset_max_memory_allocated()
+            torch.cuda.reset_accumulated_memory_stats()
+            torch.cuda.reset_peak_memory_stats()
+        time.sleep(1)
+
         time_start = time.perf_counter()
 
-        # set seed
-        seed = i
-        seed_everything(seed)
-
-        # init model parts
-        model = cls_model()
-        dl = model.train_dataloader()
-        optimizer = model.configure_optimizers()
-
-        # model to GPU
-        model = model.to(device)
-
-        epoch_losses = []
-        # as the first run is skipped, no need to run it long
-        for epoch in range(num_epochs if i > 0 else 1):
-
-            # run through full training set
-            for j, batch in enumerate(dl):
-                batch = [x.to(device) for x in batch]
-                loss_dict = model.training_step(batch, j)
-                loss = loss_dict['loss']
-                loss.backward()
-                optimizer.step()
-                optimizer.zero_grad()
-
-            # track last epoch loss
-            epoch_losses.append(loss.item())
+        _loop = lightning_loop if kind == "PT Lightning" else vanilla_loop
+        final_loss, used_memory = _loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)
 
         time_end = time.perf_counter()
-        hist_durations.append(time_end - time_start)
 
-        hist_losses.append(epoch_losses[-1])
-
-    return {
-        'losses': hist_losses,
-        'durations': hist_durations,
-    }
-
-
-def lightning_loop(cls_model, num_runs=10, num_epochs=10):
-    hist_losses = []
-    hist_durations = []
-
-    for i in tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}'):
-        time_start = time.perf_counter()
-
-        # set seed
-        seed = i
-        seed_everything(seed)
-
-        model = cls_model()
-        # init model parts
-        trainer = Trainer(
-            # as the first run is skipped, no need to run it long
-            max_epochs=num_epochs if i > 0 else 1,
-            progress_bar_refresh_rate=0,
-            weights_summary=None,
-            gpus=1,
-            checkpoint_callback=False,
-            deterministic=True,
-            logger=False,
-            replace_sampler_ddp=False,
-        )
-        trainer.fit(model)
-
-        final_loss = trainer.train_loop.running_loss.last().item()
         hist_losses.append(final_loss)
-
-        time_end = time.perf_counter()
         hist_durations.append(time_end - time_start)
+        hist_memory.append(used_memory)
 
     return {
        'losses': hist_losses,
        'durations': hist_durations,
+       'memory': hist_memory,
     }
+
+
+def vanilla_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
+    device = torch.device(device_type)
+    # set seed
+    seed_everything(idx)
+
+    # init model parts
+    model = cls_model()
+    dl = model.train_dataloader()
+    optimizer = model.configure_optimizers()
+
+    # model to GPU
+    model = model.to(device)
+
+    epoch_losses = []
+    # as the first run is skipped, no need to run it long
+    for epoch in range(num_epochs if idx > 0 else 1):
+
+        # run through full training set
+        for j, batch in enumerate(dl):
+            batch = [x.to(device) for x in batch]
+            loss_dict = model.training_step(batch, j)
+            loss = loss_dict['loss']
+            loss.backward()
+            optimizer.step()
+            optimizer.zero_grad()
+
+        # track last epoch loss
+        epoch_losses.append(loss.item())
+
+    return epoch_losses[-1], _hook_memory()
+
+
+def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
+    seed_everything(idx)
+
+    model = cls_model()
+    # init model parts
+    trainer = Trainer(
+        # as the first run is skipped, no need to run it long
+        max_epochs=num_epochs if idx > 0 else 1,
+        progress_bar_refresh_rate=0,
+        weights_summary=None,
+        gpus=1 if device_type == 'cuda' else 0,
+        checkpoint_callback=False,
+        deterministic=True,
+        logger=False,
+        replace_sampler_ddp=False,
+    )
+    trainer.fit(model)
+
+    return trainer.train_loop.running_loss.last().item(), _hook_memory()
diff --git a/benchmarks/test_sharded_parity.py b/benchmarks/test_sharded_parity.py
index 7bb29ab31b..fae343d921 100644
--- a/benchmarks/test_sharded_parity.py
+++ b/benchmarks/test_sharded_parity.py
@@ -28,35 +28,32 @@ from tests.backends import DDPLauncher
 from tests.base.boring_model import BoringModel, RandomDataset
 
 
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_device():
     plugin_parity_test(
         accelerator='ddp_cpu',
-        max_percent_speed_diff=0.15,  # slower speed due to one CPU doing additional sequential memory saving calls
         plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
+        model_cls=SeedTrainLoaderModel,
+        max_percent_speed_diff=0.15,  # todo: slower speed due to one CPU doing additional sequential memory saving calls
     )
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_one_gpu():
     plugin_parity_test(
         gpus=1,
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
+        model_cls=SeedTrainLoaderModel,
     )
 
 
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_one_gpu():
     plugin_parity_test(
@@ -64,14 +61,13 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
         precision=16,
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
-        model_cls=SeedTrainLoaderModel
+        model_cls=SeedTrainLoaderModel,
     )
 
 
 @pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu():
     plugin_parity_test(
@@ -79,13 +75,12 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
 
 
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
@@ -95,13 +90,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
         accelerator='ddp_spawn',
         plugin=DDPShardedPlugin(),
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
 
 
 @pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
@@ -111,7 +105,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
         accelerator='ddp_spawn',
         plugin='ddp_sharded',
         model_cls=SeedTrainLoaderModel,
-        max_percent_speed_diff=0.25
+        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
@@ -147,8 +141,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
     """
@@ -159,14 +152,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
         gpus=2,
         accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderMultipleOptimizersModel,
-        max_percent_speed_diff=0.25  # Increase speed diff since only 2 GPUs sharding 2 optimizers
+        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
 
 
 @pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
-@pytest.mark.skipif(platform.system() == "Windows",
-                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
 @pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
 def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
     """
@@ -177,7 +169,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
         gpus=2,
         accelerator='ddp_spawn',
         model_cls=SeedTrainLoaderManualModel,
-        max_percent_speed_diff=0.25  # Increase speed diff since only 2 GPUs sharding 2 optimizers
+        max_percent_speed_diff=0.25,  # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
     )
diff --git a/tests/base/develop_utils.py b/tests/base/develop_utils.py
index 9c88ba1b7e..7b40ba4f39 100644
--- a/tests/base/develop_utils.py
+++ b/tests/base/develop_utils.py
@@ -14,8 +14,6 @@
 import functools
 import os
 
-import numpy as np
-
 from pytorch_lightning import seed_everything
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
@@ -23,24 +21,6 @@ from tests import TEMP_PATH, RANDOM_PORTS
 from tests.base.model_template import EvalModelTemplate
 
 
-def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1):
-    # assert speeds
-    diffs = np.asarray(pl_times) - np.asarray(pt_times)
-    # norm by vanila time
-    diffs = diffs / np.asarray(pt_times)
-    assert np.alltrue(diffs < max_diff), \
-        f"lightning {diffs} was slower than PT (threshold {max_diff})"
-
-
-def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.55):
-    # assert speeds
-    diffs = np.asarray(pl_times) - np.asarray(pt_times)
-    # norm by vanila time
-    diffs = diffs / nb_epochs
-    assert np.alltrue(diffs < max_diff), \
-        f"lightning {diffs} was slower than PT (threshold {max_diff})"
-
-
 def get_default_logger(save_dir, version=None):
     # set up logger object without actually saving logs
     logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version)
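
Note (not part of the patch above): after this refactor both benchmark loops are driven by measure_loops(cls_model, kind="Vanilla PT" or "PT Lightning", ...), which returns a dict with 'losses', 'durations' and 'memory' lists, and the two new helpers in benchmarks/test_basic_parity.py compare those lists between Lightning and vanilla PyTorch. The sketch below is a minimal, standalone illustration of the normalization the helpers apply; the timing numbers are made up, and only the formulas mirror assert_parity_absolute / assert_parity_relative (with norm_by left at its default of 1 for the relative check).

# Standalone sketch with made-up numbers; only the formulas mirror the new helpers.
import numpy as np

pt_durations = [10.0, 10.2, 9.8]   # hypothetical vanilla PyTorch run durations (seconds)
pl_durations = [10.5, 10.4, 10.6]  # hypothetical PT Lightning durations for the same runs
num_epochs = 4

# assert_parity_absolute: mean overhead normalized by the epoch count, checked against max_diff_speed
abs_diffs = (np.asarray(pl_durations) - np.mean(pt_durations)) / num_epochs
print(np.mean(abs_diffs))  # 0.125 -> must stay below e.g. 0.25 for ParityModuleMNIST

# assert_parity_relative (norm_by=1): mean overhead as a fraction of the vanilla mean,
# used in the test for the 'memory' lists with max_diff_memory
rel_diffs = (np.asarray(pl_durations) - np.mean(pt_durations)) / np.mean(pt_durations)
print(np.mean(rel_diffs))  # 0.05, i.e. about 5 % above the vanilla baseline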