add memory parity for PL vs Vanilla (#5170)
* refactor * memory * show * clean * clean * try * device * reset * fix * fix * mean * hook * format * add todo Co-authored-by: chaton <thomas@grid.ai> Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
176735097a
commit
6adc1b32bd
|
@ -16,7 +16,7 @@ import os
|
|||
import matplotlib.pylab as plt
|
||||
import pandas as pd
|
||||
|
||||
from benchmarks.test_basic_parity import lightning_loop, vanilla_loop
|
||||
from benchmarks.test_basic_parity import measure_loops
|
||||
from tests.base.models import ParityModuleMNIST, ParityModuleRNN
|
||||
|
||||
NUM_EPOCHS = 20
|
||||
|
@ -34,8 +34,9 @@ def _main():
|
|||
if os.path.isfile(path_csv):
|
||||
df_time = pd.read_csv(path_csv, index_col=0)
|
||||
else:
|
||||
vanilla = vanilla_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
|
||||
lightning = lightning_loop(cls_model, num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
|
||||
# todo: kind="Vanilla PT" -> use_lightning=False
|
||||
vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
|
||||
lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=NUM_EPOCHS, num_runs=NUM_RUNS)
|
||||
|
||||
df_time = pd.DataFrame({'vanilla PT': vanilla['durations'][1:], 'PT Lightning': lightning['durations'][1:]})
|
||||
df_time /= NUM_RUNS
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
|
@ -19,118 +19,156 @@ import pytest
|
|||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from pytorch_lightning import seed_everything, Trainer
|
||||
import tests.base.develop_utils as tutils
|
||||
from pytorch_lightning import LightningModule, seed_everything, Trainer
|
||||
from tests.base.models import ParityModuleMNIST, ParityModuleRNN
|
||||
|
||||
|
||||
def assert_parity_relative(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.1):
    """
    Check that the Lightning measurements stay close to the vanilla reference, relative to that reference.

    Args:
        pl_values: measurements collected for the Lightning runs (anything ``np.asarray`` accepts).
        pt_values: measurements collected for the vanilla PyTorch runs; only their mean is used.
        norm_by: divisor applied to the differences before they are made relative.
        max_diff: the mean normalized difference has to be strictly below this value.

    Raises:
        AssertionError: if the mean normalized difference is not below ``max_diff``.
    """
    # the vanilla mean is both the baseline that gets subtracted and the scale of the relative value
    reference = np.mean(pt_values)
    diffs = (np.asarray(pl_values) - reference) / norm_by / reference
    assert np.mean(diffs) < max_diff, f"Lightning diff {diffs} was worse than vanilla PT (threshold {max_diff})"
|
||||
|
||||
|
||||
def assert_parity_absolute(pl_values, pt_values, norm_by: float = 1, max_diff: float = 0.55):
    """
    Check that the Lightning measurements do not exceed the vanilla reference by too much, in absolute terms.

    Args:
        pl_values: measurements collected for the Lightning runs (anything ``np.asarray`` accepts).
        pt_values: measurements collected for the vanilla PyTorch runs; only their mean is used.
        norm_by: divisor applied to the differences (e.g. an event count such as the number of epochs).
        max_diff: the mean normalized difference has to be strictly below this value.

    Raises:
        AssertionError: if the mean normalized difference is not below ``max_diff``.
    """
    # difference to the mean vanilla value, scaled down by the event count
    diffs = (np.asarray(pl_values) - np.mean(pt_values)) / norm_by
    assert np.mean(diffs) < max_diff, f"Lightning {diffs} was worse than vanilla PT (threshold {max_diff})"
|
||||
|
||||
|
||||
# ParityModuleMNIST runs with num_workers=1
|
||||
@pytest.mark.parametrize('cls_model,max_diff', [
|
||||
(ParityModuleRNN, 0.05),
|
||||
(ParityModuleMNIST, 0.25), # todo: lower this thr
|
||||
@pytest.mark.parametrize('cls_model,max_diff_speed,max_diff_memory', [
|
||||
(ParityModuleRNN, 0.05, 0.0),
|
||||
(ParityModuleMNIST, 0.25, 0.0), # todo: lower this thr
|
||||
])
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
|
||||
def test_pytorch_parity(tmpdir, cls_model, max_diff: float, num_epochs: int = 4, num_runs: int = 3):
|
||||
def test_pytorch_parity(
|
||||
tmpdir,
|
||||
cls_model: LightningModule,
|
||||
max_diff_speed: float,
|
||||
max_diff_memory: float,
|
||||
num_epochs: int = 4,
|
||||
num_runs: int = 3,
|
||||
):
|
||||
"""
|
||||
Verify that the same pytorch and lightning models achieve the same results
|
||||
"""
|
||||
lightning = lightning_loop(cls_model, num_runs, num_epochs)
|
||||
vanilla = vanilla_loop(cls_model, num_runs, num_epochs)
|
||||
lightning = measure_loops(cls_model, kind="PT Lightning", num_epochs=num_epochs, num_runs=num_runs)
|
||||
vanilla = measure_loops(cls_model, kind="Vanilla PT", num_epochs=num_epochs, num_runs=num_runs)
|
||||
|
||||
# make sure the losses match exactly to 5 decimal places
|
||||
print(f"Losses are for... \n vanilla: {vanilla['losses']} \n lightning: {lightning['losses']}")
|
||||
for pl_out, pt_out in zip(lightning['losses'], vanilla['losses']):
|
||||
np.testing.assert_almost_equal(pl_out, pt_out, 5)
|
||||
|
||||
# the first run initializes the dataset (download & filter)
|
||||
tutils.assert_speed_parity_absolute(
|
||||
lightning['durations'][1:], vanilla['durations'][1:], nb_epochs=num_epochs, max_diff=max_diff
|
||||
# drop the first run, which initializes the dataset (download & filter)
|
||||
assert_parity_absolute(
|
||||
lightning['durations'][1:], vanilla['durations'][1:], norm_by=num_epochs, max_diff=max_diff_speed
|
||||
)
|
||||
|
||||
assert_parity_relative(lightning['memory'], vanilla['memory'], max_diff=max_diff_memory)
|
||||
|
||||
def vanilla_loop(cls_model, num_runs=10, num_epochs=10):
|
||||
|
||||
def _hook_memory():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.synchronize()
|
||||
used_memory = torch.cuda.max_memory_allocated()
|
||||
else:
|
||||
used_memory = np.nan
|
||||
return used_memory
|
||||
|
||||
|
||||
def measure_loops(cls_model, kind, num_runs=10, num_epochs=10):
|
||||
"""
|
||||
Returns a dict with the final loss, the duration and the memory usage recorded for each run
|
||||
"""
|
||||
hist_losses = []
|
||||
hist_durations = []
|
||||
hist_memory = []
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
|
||||
device_type = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
torch.backends.cudnn.deterministic = True
|
||||
for i in tqdm(range(num_runs), desc=f'Vanilla PT with {cls_model.__name__}'):
|
||||
for i in tqdm(range(num_runs), desc=f'{kind} with {cls_model.__name__}'):
|
||||
gc.collect()
|
||||
if device_type == 'cuda':
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_cached()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_accumulated_memory_stats()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
time.sleep(1)
|
||||
|
||||
time_start = time.perf_counter()
|
||||
|
||||
# set seed
|
||||
seed = i
|
||||
seed_everything(seed)
|
||||
|
||||
# init model parts
|
||||
model = cls_model()
|
||||
dl = model.train_dataloader()
|
||||
optimizer = model.configure_optimizers()
|
||||
|
||||
# model to GPU
|
||||
model = model.to(device)
|
||||
|
||||
epoch_losses = []
|
||||
# as the first run is skipped, no need to run it long
|
||||
for epoch in range(num_epochs if i > 0 else 1):
|
||||
|
||||
# run through full training set
|
||||
for j, batch in enumerate(dl):
|
||||
batch = [x.to(device) for x in batch]
|
||||
loss_dict = model.training_step(batch, j)
|
||||
loss = loss_dict['loss']
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# track last epoch loss
|
||||
epoch_losses.append(loss.item())
|
||||
_loop = lightning_loop if kind == "PT Lightning" else vanilla_loop
|
||||
final_loss, used_memory = _loop(cls_model, idx=i, device_type=device_type, num_epochs=num_epochs)
|
||||
|
||||
time_end = time.perf_counter()
|
||||
hist_durations.append(time_end - time_start)
|
||||
|
||||
hist_losses.append(epoch_losses[-1])
|
||||
|
||||
return {
|
||||
'losses': hist_losses,
|
||||
'durations': hist_durations,
|
||||
}
|
||||
|
||||
|
||||
def lightning_loop(cls_model, num_runs=10, num_epochs=10):
|
||||
hist_losses = []
|
||||
hist_durations = []
|
||||
|
||||
for i in tqdm(range(num_runs), desc=f'PT Lightning with {cls_model.__name__}'):
|
||||
time_start = time.perf_counter()
|
||||
|
||||
# set seed
|
||||
seed = i
|
||||
seed_everything(seed)
|
||||
|
||||
model = cls_model()
|
||||
# init model parts
|
||||
trainer = Trainer(
|
||||
# as the first run is skipped, no need to run it long
|
||||
max_epochs=num_epochs if i > 0 else 1,
|
||||
progress_bar_refresh_rate=0,
|
||||
weights_summary=None,
|
||||
gpus=1,
|
||||
checkpoint_callback=False,
|
||||
deterministic=True,
|
||||
logger=False,
|
||||
replace_sampler_ddp=False,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
final_loss = trainer.train_loop.running_loss.last().item()
|
||||
hist_losses.append(final_loss)
|
||||
|
||||
time_end = time.perf_counter()
|
||||
hist_durations.append(time_end - time_start)
|
||||
hist_memory.append(used_memory)
|
||||
|
||||
return {
|
||||
'losses': hist_losses,
|
||||
'durations': hist_durations,
|
||||
'memory': hist_memory,
|
||||
}
|
||||
|
||||
|
||||
def vanilla_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
    """
    Train ``cls_model`` with a hand-written PyTorch loop and report the last loss and the memory reading.

    Args:
        cls_model: model class; instantiated without arguments, it has to provide ``train_dataloader``,
            ``configure_optimizers`` and ``training_step``.
        idx: index of the run; also used as the random seed.
        device_type: device string handed to ``torch.device``.
        num_epochs: number of epochs to train for every run except the first one.

    Returns:
        Tuple of the loss value of the last processed batch and the result of ``_hook_memory()``.
    """
    target_device = torch.device(device_type)
    seed_everything(idx)

    # build the model together with its data and optimizer, then move the model to the device
    net = cls_model()
    train_loader = net.train_dataloader()
    optim = net.configure_optimizers()
    net = net.to(target_device)

    # the first run (idx == 0) is skipped in the evaluation, so a single epoch is enough for it
    epochs_to_run = num_epochs if idx > 0 else 1

    losses_per_epoch = []
    for _ in range(epochs_to_run):
        # one full pass over the training set
        for batch_idx, batch in enumerate(train_loader):
            on_device = [item.to(target_device) for item in batch]
            step_loss = net.training_step(on_device, batch_idx)['loss']
            step_loss.backward()
            optim.step()
            optim.zero_grad()

        # remember the loss of the last batch of this epoch
        losses_per_epoch.append(step_loss.item())

    final_loss = losses_per_epoch[-1]
    return final_loss, _hook_memory()
|
||||
|
||||
|
||||
def lightning_loop(cls_model, idx, device_type: str = 'cuda', num_epochs=10):
    """
    Train ``cls_model`` with the Lightning ``Trainer`` and report the last running loss and the memory reading.

    Args:
        cls_model: model class; instantiated without arguments and passed to ``Trainer.fit``.
        idx: index of the run; also used as the random seed.
        device_type: ``'cuda'`` requests one GPU from the trainer, anything else requests none.
        num_epochs: number of epochs to train for every run except the first one.

    Returns:
        Tuple of the last running training loss and the result of ``_hook_memory()``.
    """
    seed_everything(idx)
    module = cls_model()

    trainer_options = dict(
        # the first run (idx == 0) is skipped in the evaluation, so a single epoch is enough for it
        max_epochs=num_epochs if idx > 0 else 1,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        gpus=1 if device_type == 'cuda' else 0,
        checkpoint_callback=False,
        deterministic=True,
        logger=False,
        replace_sampler_ddp=False,
    )
    trainer = Trainer(**trainer_options)
    trainer.fit(module)

    last_loss = trainer.train_loop.running_loss.last().item()
    return last_loss, _hook_memory()
|
||||
|
|
|
@ -28,35 +28,32 @@ from tests.backends import DDPLauncher
|
|||
from tests.base.boring_model import BoringModel, RandomDataset
|
||||
|
||||
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_one_device():
|
||||
plugin_parity_test(
|
||||
accelerator='ddp_cpu',
|
||||
max_percent_speed_diff=0.15, # slower speed due to one CPU doing additional sequential memory saving calls
|
||||
plugin=DDPShardedPlugin(),
|
||||
model_cls=SeedTrainLoaderModel
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
max_percent_speed_diff=0.15, # todo: slower speed due to one CPU doing additional sequential memory saving calls
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_one_gpu():
|
||||
plugin_parity_test(
|
||||
gpus=1,
|
||||
accelerator='ddp_spawn',
|
||||
plugin=DDPShardedPlugin(),
|
||||
model_cls=SeedTrainLoaderModel
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
|
||||
plugin_parity_test(
|
||||
|
@ -64,14 +61,13 @@ def test_ddp_sharded_plugin_correctness_amp_one_gpu():
|
|||
precision=16,
|
||||
accelerator='ddp_spawn',
|
||||
plugin=DDPShardedPlugin(),
|
||||
model_cls=SeedTrainLoaderModel
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Not a critical test, skip till drone CI performance improves.")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_multi_gpu():
|
||||
plugin_parity_test(
|
||||
|
@ -79,13 +75,12 @@ def test_ddp_sharded_plugin_correctness_multi_gpu():
|
|||
accelerator='ddp_spawn',
|
||||
plugin=DDPShardedPlugin(),
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
max_percent_speed_diff=0.25
|
||||
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
|
||||
|
@ -95,13 +90,12 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
|
|||
accelerator='ddp_spawn',
|
||||
plugin=DDPShardedPlugin(),
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
max_percent_speed_diff=0.25
|
||||
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not NATIVE_AMP_AVAILABLE, reason="Requires native AMP")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
|
||||
|
@ -111,7 +105,7 @@ def test_ddp_string_sharded_plugin_correctness_amp_multi_gpu():
|
|||
accelerator='ddp_spawn',
|
||||
plugin='ddp_sharded',
|
||||
model_cls=SeedTrainLoaderModel,
|
||||
max_percent_speed_diff=0.25
|
||||
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
)
|
||||
|
||||
|
||||
|
@ -147,8 +141,7 @@ def test_ddp_sharded_plugin_correctness_amp_multi_gpu_ddp(tmpdir, args=None):
|
|||
|
||||
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
|
||||
"""
|
||||
|
@ -159,14 +152,13 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim():
|
|||
gpus=2,
|
||||
accelerator='ddp_spawn',
|
||||
model_cls=SeedTrainLoaderMultipleOptimizersModel,
|
||||
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Current issue with multiple optimizers and FairScale.")
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
|
||||
@pytest.mark.skipif(platform.system() == "Windows",
|
||||
reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(platform.system() == "Windows", reason="Distributed training is not supported on Windows")
|
||||
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
|
||||
def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
|
||||
"""
|
||||
|
@ -177,7 +169,7 @@ def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
|
|||
gpus=2,
|
||||
accelerator='ddp_spawn',
|
||||
model_cls=SeedTrainLoaderManualModel,
|
||||
max_percent_speed_diff=0.25 # Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
max_percent_speed_diff=0.25, # todo: Increase speed diff since only 2 GPUs sharding 2 optimizers
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -14,8 +14,6 @@
|
|||
import functools
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from pytorch_lightning import seed_everything
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, TestTubeLogger
|
||||
|
@ -23,24 +21,6 @@ from tests import TEMP_PATH, RANDOM_PORTS
|
|||
from tests.base.model_template import EvalModelTemplate
|
||||
|
||||
|
||||
def assert_speed_parity_relative(pl_times, pt_times, max_diff: float = 0.1):
    """
    Check that every Lightning run is at most ``max_diff`` slower than its vanilla counterpart, relatively.

    Args:
        pl_times: durations of the Lightning runs.
        pt_times: durations of the vanilla PyTorch runs, paired element-wise with ``pl_times``.
        max_diff: every relative slow-down has to be strictly below this value.

    Raises:
        AssertionError: if any run exceeds the threshold.
    """
    reference = np.asarray(pt_times)
    # slow-down of each run, expressed as a fraction of the vanilla time
    diffs = (np.asarray(pl_times) - reference) / reference
    # ``np.all`` replaces ``np.alltrue``, which was removed in NumPy 2.0
    assert np.all(diffs < max_diff), \
        f"lightning {diffs} was slower than PT (threshold {max_diff})"
|
||||
|
||||
|
||||
def assert_speed_parity_absolute(pl_times, pt_times, nb_epochs, max_diff: float = 0.55):
    """
    Check that every Lightning run is at most ``max_diff`` slower per epoch than its vanilla counterpart.

    Args:
        pl_times: durations of the Lightning runs.
        pt_times: durations of the vanilla PyTorch runs, paired element-wise with ``pl_times``.
        nb_epochs: number of epochs the durations are divided by.
        max_diff: every per-epoch slow-down has to be strictly below this value.

    Raises:
        AssertionError: if any run exceeds the threshold.
    """
    # absolute slow-down of each run, normalized by the number of epochs
    diffs = (np.asarray(pl_times) - np.asarray(pt_times)) / nb_epochs
    # ``np.all`` replaces ``np.alltrue``, which was removed in NumPy 2.0
    assert np.all(diffs < max_diff), \
        f"lightning {diffs} was slower than PT (threshold {max_diff})"
|
||||
|
||||
|
||||
def get_default_logger(save_dir, version=None):
|
||||
# set up logger object without actually saving logs
|
||||
logger = TensorBoardLogger(save_dir, name='lightning_logs', version=version)
|
||||
|
|
Loading…
Reference in New Issue