lightning/tests/plugins/test_sharded_plugin.py

258 lines
8.5 KiB
Python
Raw Normal View History

2020-11-24 21:12:18 +00:00
import glob
import os
import platform
import time
from distutils.version import LooseVersion
from unittest import mock
import pytest
import torch
from torch.utils.data.distributed import DistributedSampler
from pytorch_lightning import Trainer, seed_everything
2020-11-24 21:12:18 +00:00
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE
from tests.base.boring_model import BoringModel, RandomDataset
@mock.patch.dict(
os.environ,
{
"CUDA_VISIBLE_DEVICES": "0,1",
"SLURM_NTASKS": "2",
"SLURM_JOB_NAME": "SOME_NAME",
"SLURM_NODEID": "0",
"LOCAL_RANK": "0",
"SLURM_LOCALID": "0",
},
)
@mock.patch("torch.cuda.device_count", return_value=2)
@pytest.mark.parametrize(
["ddp_backend", "gpus", "num_processes"],
[("ddp_cpu", None, None), ("ddp", 2, 0), ("ddp2", 2, 0), ("ddp_spawn", 2, 0)],
)
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_choice_sharded_cpu(tmpdir, ddp_backend, gpus, num_processes):
class CB(Callback):
def on_fit_start(self, trainer, pl_module):
assert isinstance(trainer.accelerator_backend.ddp_plugin, DDPShardedPlugin)
raise SystemExit()
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
gpus=gpus,
num_processes=num_processes,
distributed_backend=ddp_backend,
plugins=[DDPShardedPlugin()],
callbacks=[CB()],
)
with pytest.raises(SystemExit):
trainer.fit(model)
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_cpu(tmpdir):
model = BoringModel()
trainer = Trainer(
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)],
accelerator='ddp_cpu',
plugins=[DDPShardedPlugin()],
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1
)
trainer.fit(model)
checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0]
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param, shard_param)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_checkpoint_multi_gpu(tmpdir):
model = BoringModel()
trainer = Trainer(
callbacks=[ModelCheckpoint(dirpath=tmpdir, save_last=True)],
gpus=2,
accelerator='ddp_spawn',
plugins=[DDPShardedPlugin()],
limit_train_batches=2,
limit_val_batches=2,
max_epochs=1
)
trainer.fit(model)
checkpoint_path = glob.glob(os.path.join(tmpdir, "*.ckpt"))[0]
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
# Assert model parameters are identical after loading
for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(ddp_param, shard_param)
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_device():
run_sharded_correctness(accelerator='ddp_cpu')
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_one_gpu():
run_sharded_correctness(gpus=1, accelerator='ddp_spawn')
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_one_gpu():
run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn')
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE,
reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_multi_gpu():
run_sharded_correctness(gpus=2, accelerator='ddp_spawn')
@pytest.mark.skipif(
LooseVersion(torch.__version__) < LooseVersion("1.6.0"),
reason="Minimal PT version is set to 1.6")
@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
2020-11-24 21:12:18 +00:00
@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
def test_ddp_sharded_plugin_correctness_amp_multi_gpu():
run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn')
class TestModel(BoringModel):
"""
Overrides training loader to ensure we enforce the same seed for all DDP processes.
"""
def train_dataloader(self):
seed_everything(42)
return torch.utils.data.DataLoader(RandomDataset(32, 64))
def record_ddp_fit_model_stats(trainer, model, gpus):
"""
Helper to calculate wall clock time for fit + max allocated memory.
Args:
trainer: The trainer object.
model: The LightningModule.
gpus: Number of GPUs in test.
Returns:
Max Memory if using GPUs, and total wall clock time.
"""
max_memory = None
if gpus > 0:
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
time_start = time.perf_counter()
trainer.fit(model)
if gpus > 0:
torch.cuda.synchronize()
max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
total_time = time.perf_counter() - time_start
return max_memory, total_time
def run_sharded_correctness(
accelerator='ddp_spawn',
gpus=0,
precision=32,
max_percent_speed_regression=0.1):
"""
Ensures that the trained model is identical to the standard DDP implementation.
Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate.
Args:
accelerator: Accelerator type for test.
gpus: Number of GPUS to enable.
precision: Whether to use AMP or normal FP32 training.
max_percent_speed_regression: The maximum speed regression compared to normal DDP training
"""
# Train normal DDP
seed_everything(42)
ddp_model = TestModel()
trainer = Trainer(
limit_val_batches=0.0,
fast_dev_run=False,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
)
max_ddp_memory, ddp_time = record_ddp_fit_model_stats(
trainer=trainer,
model=ddp_model,
gpus=gpus
)
# Reset and train sharded DDP
seed_everything(42)
sharded_model = TestModel()
trainer = Trainer(
limit_val_batches=0.0,
fast_dev_run=False,
max_epochs=1,
gpus=gpus,
precision=precision,
accelerator=accelerator,
plugins=[DDPShardedPlugin()],
)
max_sharded_memory, sharded_time = record_ddp_fit_model_stats(
trainer=trainer,
model=sharded_model,
gpus=gpus
)
# Assert model parameters are identical after fit
for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()):
assert torch.equal(ddp_param, shard_param)
# Assert speed parity
upper_bound_speed = ddp_time * (1 + max_percent_speed_regression)
assert sharded_time <= upper_bound_speed
if gpus > 0:
# Assert CUDA memory parity
assert max_sharded_memory <= max_ddp_memory