diff --git a/benchmarks/test_sharded_correctness.py b/benchmarks/test_sharded_correctness.py new file mode 100644 index 0000000000..0b986f4199 --- /dev/null +++ b/benchmarks/test_sharded_correctness.py @@ -0,0 +1,255 @@ +import os +import platform +import time +from unittest import mock + +import pytest +import torch +from torch.utils.data.distributed import DistributedSampler + +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE +from tests.base.boring_model import BoringModel, RandomDataset + + +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_one_device(): + # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls + run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_one_gpu(): + run_sharded_correctness(gpus=1, accelerator='ddp_spawn') + + +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_amp_one_gpu(): + run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_multi_gpu(): + run_sharded_correctness(gpus=2, accelerator='ddp_spawn') + + +@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): + run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") +@pytest.mark.skipif(platform.system() == "Windows", + reason="Distributed training is not supported on Windows") +@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, + reason="Fairscale is not available") +def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): + """ + Ensures same results using multiple optimizers across multiple GPUs + """ + run_sharded_correctness( + gpus=2, + accelerator='ddp_spawn', + model_cls=SeedTrainLoaderMultipleOptimizersModel, + max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers + ) + + 
+@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.")
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(platform.system() == "Windows",
+                    reason="Distributed training is not supported on Windows")
+@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir):
+    """
+    Ensures same results when using multiple optimizers across multiple GPUs with manual optimization.
+    """
+    run_sharded_correctness(
+        gpus=2,
+        accelerator='ddp_spawn',
+        model_cls=SeedTrainLoaderManualModel,
+    )
+
+
+class SeedTrainLoaderModel(BoringModel):
+    """
+    Overrides the training loader to enforce the same seed for all DDP processes.
+    """
+
+    def train_dataloader(self):
+        seed_everything(42)
+        return torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+
+class SeedTrainLoaderManualModel(SeedTrainLoaderModel):
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        # manual optimization
+        (opt_a, opt_b) = self.optimizers()
+        loss_1 = self.step(batch)
+
+        self.manual_backward(loss_1, opt_a)
+        self.manual_optimizer_step(opt_a)
+
+        # fake discriminator
+        loss_2 = self.step(batch[0])
+
+        # ensure we forward the correct params to the optimizer
+        # without retain_graph we can't do multiple backward passes
+        self.manual_backward(loss_2, opt_b, retain_graph=True)
+        self.manual_backward(loss_2, opt_a, retain_graph=True)
+        self.manual_optimizer_step(opt_b)
+
+        assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0)
+
+    def training_epoch_end(self, outputs) -> None:
+        # outputs should be a list with an entry per optimizer
+        assert len(outputs) == 2
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return optimizer, optimizer_2
+
+    @property
+    def automatic_optimization(self) -> bool:
+        return False
+
+
+class SeedTrainLoaderMultipleOptimizersModel(SeedTrainLoaderModel):
+    def training_step(self, batch, batch_idx, optimizer_idx):
+        output = self.layer(batch)
+        loss = self.loss(batch, output)
+        return {"loss": loss}
+
+    def training_epoch_end(self, outputs) -> None:
+        # outputs should be a list with an entry per optimizer
+        assert len(outputs) == 2
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return optimizer, optimizer_2
+
+
+def record_ddp_fit_model_stats(trainer, model, gpus):
+    """
+    Helper to calculate wall clock time for fit + max allocated memory.
+
+    Args:
+        trainer: The trainer object.
+        model: The LightningModule.
+        gpus: Number of GPUs in test.
+
+    Returns:
+        Max memory (if using GPUs) and total wall clock time.
+
+    """
+    max_memory = None
+
+    time_start = time.perf_counter()
+    if gpus > 0:
+        torch.cuda.reset_peak_memory_stats()
+        torch.cuda.synchronize()
+
+    trainer.fit(model)
+
+    if gpus > 0:
+        torch.cuda.synchronize()
+        max_memory = torch.cuda.max_memory_allocated() / 2 ** 20
+
+    total_time = time.perf_counter() - time_start
+
+    return max_memory, total_time
+
+
+def run_sharded_correctness(
+        accelerator='ddp_spawn',
+        gpus=0,
+        precision=32,
+        max_percent_speed_diff=0.25,
+        model_cls=SeedTrainLoaderModel):
+    """
+    Ensures that the trained model is identical to the standard DDP implementation.
+    Also checks for speed/memory regressions; we expect sharding to always use less memory, while speed may fluctuate.
+
+    Args:
+        accelerator: Accelerator type for test.
+        gpus: Number of GPUs to enable.
+        precision: Whether to use AMP or normal FP32 training.
+        max_percent_speed_diff: The maximum allowed speed difference compared to normal DDP training.
+            This is more of a safety net against CI runners, whose speed can vary, than a benchmark.
+        model_cls: Model class to use for test.
+
+    """
+
+    # Train normal DDP
+    seed_everything(42)
+    ddp_model = model_cls()
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        max_epochs=1,
+        gpus=gpus,
+        precision=precision,
+        accelerator=accelerator,
+    )
+
+    max_ddp_memory, ddp_time = record_ddp_fit_model_stats(
+        trainer=trainer,
+        model=ddp_model,
+        gpus=gpus
+    )
+
+    # Reset and train sharded DDP
+    seed_everything(42)
+    sharded_model = model_cls()
+
+    trainer = Trainer(
+        fast_dev_run=True,
+        max_epochs=1,
+        gpus=gpus,
+        precision=precision,
+        accelerator=accelerator,
+        plugins=[DDPShardedPlugin()],
+    )
+
+    max_sharded_memory, sharded_time = record_ddp_fit_model_stats(
+        trainer=trainer,
+        model=sharded_model,
+        gpus=gpus
+    )
+
+    # Assert model parameters are identical after fit
+    for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()):
+        assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin'
+
+    # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold
+    percent_diff = (sharded_time - ddp_time) / sharded_time
+
+    assert percent_diff <= max_percent_speed_diff, \
+        f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}'
+
+    if gpus > 0:
+        # Assert CUDA memory parity
+        assert max_sharded_memory <= max_ddp_memory, \
+            f'Sharded plugin used too much memory compared to DDP, ' \
+            f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}'
diff --git a/pytorch_lightning/overrides/fairscale.py b/pytorch_lightning/overrides/fairscale.py
index e6bc5b3e25..df49199e2a 100644
--- a/pytorch_lightning/overrides/fairscale.py
+++ b/pytorch_lightning/overrides/fairscale.py
@@ -11,11 +11,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import platform
+from pytorch_lightning.utilities import FAIRSCALE_AVAILABLE
 
-from pytorch_lightning.utilities import _module_available
-
-if platform.system() != "Windows" and _module_available('fairscale.nn.data_parallel.sharded_ddp'):
+if FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
 
     class LightningShardedDataParallel(ShardedDataParallel):
@@ -32,6 +30,5 @@ if platform.system() != "Windows" and _module_available('fairscale.nn.data_paral
             outputs = self.module.validation_step(*inputs, **kwargs)
 
             return outputs
-    FAIRSCALE_SHARDED_AVAILABLE = True
 else:
-    FAIRSCALE_SHARDED_AVAILABLE = False
+    LightningShardedDataParallel = None
diff --git a/pytorch_lightning/plugins/sharded_native_amp_plugin.py b/pytorch_lightning/plugins/sharded_native_amp_plugin.py
index d244f68dcc..67aac2c3a2 100644
--- a/pytorch_lightning/plugins/sharded_native_amp_plugin.py
+++ b/pytorch_lightning/plugins/sharded_native_amp_plugin.py
@@ -13,16 +13,12 @@
 # limitations under the License.
from typing import cast -from pytorch_lightning.utilities import _module_available, NATIVE_AMP_AVALAIBLE +from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE, FAIRSCALE_AVAILABLE -if NATIVE_AMP_AVALAIBLE and _module_available('fairscale.optim'): +if NATIVE_AMP_AVALAIBLE and FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from fairscale.optim.grad_scaler import ShardedGradScaler - FAIRSCALE_AMP_AVAILABLE = True -else: - FAIRSCALE_AMP_AVAILABLE = False - from pytorch_lightning.plugins.native_amp import NativeAMPPlugin diff --git a/pytorch_lightning/plugins/sharded_plugin.py b/pytorch_lightning/plugins/sharded_plugin.py index 2eb58093dc..ba7b751fde 100644 --- a/pytorch_lightning/plugins/sharded_plugin.py +++ b/pytorch_lightning/plugins/sharded_plugin.py @@ -11,22 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import platform from typing import List, Optional, Union from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.plugins.ddp_plugin import DDPPlugin -from pytorch_lightning.utilities import rank_zero_only, _module_available +from pytorch_lightning.utilities import rank_zero_only, FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException -if platform.system() != "Windows" and _module_available('fairscale.optim'): # Distributed not supported on windows +if FAIRSCALE_AVAILABLE: from fairscale.optim import OSS from pytorch_lightning.overrides.fairscale import LightningShardedDataParallel - FAIRSCALE_AVAILABLE = True -else: - FAIRSCALE_AVAILABLE = False - class DDPShardedPlugin(DDPPlugin): diff --git a/pytorch_lightning/trainer/connectors/precision_connector.py b/pytorch_lightning/trainer/connectors/precision_connector.py index d5c82a396d..5961ac03b4 100644 --- a/pytorch_lightning/trainer/connectors/precision_connector.py +++ b/pytorch_lightning/trainer/connectors/precision_connector.py @@ -16,9 +16,10 @@ from typing import Optional from pytorch_lightning import _logger as log from pytorch_lightning.plugins.apex import ApexPlugin from pytorch_lightning.plugins.native_amp import NativeAMPPlugin -from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin, FAIRSCALE_AMP_AVAILABLE +from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin -from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn +from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, AMPType, rank_zero_warn, \ + FAIRSCALE_AVAILABLE from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -60,7 +61,7 @@ class PrecisionConnector: else: self.trainer.amp_backend = AMPType.NATIVE if plugins and self._sharded_in_plugins(plugins): - if not FAIRSCALE_AMP_AVAILABLE: + if not FAIRSCALE_AVAILABLE: raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.') log.info('Using Sharded 16bit plugin.') self.backend = ShardedNativeAMPPlugin(self.trainer) diff --git a/pytorch_lightning/utilities/__init__.py b/pytorch_lightning/utilities/__init__.py index 4d91ec8c4f..83a9144efa 100644 --- a/pytorch_lightning/utilities/__init__.py +++ b/pytorch_lightning/utilities/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""General utilities""" import importlib +import platform from enum import Enum import numpy @@ -43,6 +44,7 @@ def _module_available(module_path: str) -> bool: APEX_AVAILABLE = _module_available("apex.amp") NATIVE_AMP_AVALAIBLE = hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") +FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel') FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps diff --git a/tests/plugins/test_sharded_plugin.py b/tests/plugins/test_sharded_plugin.py index 76db75a21f..11ec6cbe73 100644 --- a/tests/plugins/test_sharded_plugin.py +++ b/tests/plugins/test_sharded_plugin.py @@ -1,18 +1,16 @@ import os import platform -import time from unittest import mock import pytest import torch -from torch.utils.data.distributed import DistributedSampler -from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning import Trainer from pytorch_lightning.callbacks import Callback from pytorch_lightning.plugins.sharded_native_amp_plugin import ShardedNativeAMPPlugin from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin, FAIRSCALE_AVAILABLE from pytorch_lightning.utilities import NATIVE_AMP_AVALAIBLE -from tests.base.boring_model import BoringModel, RandomDataset +from tests.base.boring_model import BoringModel @mock.patch.dict( @@ -280,246 +278,3 @@ def test_ddp_sharded_plugin_resume_from_checkpoint_gpu_to_cpu(tmpdir): trainer.fit(model) return 1 - - -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_one_device(): - # Allow slightly slower speed due to one CPU doing additional sequential memory saving calls - run_sharded_correctness(accelerator='ddp_cpu', max_percent_speed_diff=0.5) - - -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_one_gpu(): - run_sharded_correctness(gpus=1, accelerator='ddp_spawn') - - -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_amp_one_gpu(): - run_sharded_correctness(gpus=1, precision=16, accelerator='ddp_spawn') - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_multi_gpu(): - run_sharded_correctness(gpus=2, accelerator='ddp_spawn') - - -@pytest.mark.skipif(not NATIVE_AMP_AVALAIBLE, reason="Requires native AMP") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(not 
FAIRSCALE_AVAILABLE, reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_amp_multi_gpu(): - run_sharded_correctness(gpus=2, precision=16, accelerator='ddp_spawn') - - -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim(): - """ - Ensures same results using multiple optimizers across multiple GPUs - """ - run_sharded_correctness( - gpus=2, - accelerator='ddp_spawn', - model_cls=TestMultipleOptimizersModel, - max_percent_speed_diff=0.3 # Increase speed diff since only 2 GPUs sharding 2 optimizers - ) - - -@pytest.mark.skip(reason="Currently DDP manual optimization is broken due to no reduce within training step.") -@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine") -@pytest.mark.skipif(platform.system() == "Windows", - reason="Distributed training is not supported on Windows") -@pytest.mark.skipif(not FAIRSCALE_AVAILABLE, - reason="Fairscale is not available") -@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) -def test_ddp_sharded_plugin_correctness_multi_gpu_multi_optim_manual(tmpdir): - """ - Ensures using multiple optimizers across multiple GPUs with manual optimization - """ - run_sharded_correctness( - gpus=2, - accelerator='ddp_spawn', - model_cls=TestManualModel, - ) - - -class TestModel(BoringModel): - """ - Overrides training loader to ensure we enforce the same seed for all DDP processes. - """ - - def train_dataloader(self): - seed_everything(42) - return torch.utils.data.DataLoader(RandomDataset(32, 64)) - - -class TestManualModel(TestModel): - def training_step(self, batch, batch_idx, optimizer_idx): - # manual - (opt_a, opt_b) = self.optimizers() - loss_1 = self.step(batch) - - self.manual_backward(loss_1, opt_a) - self.manual_optimizer_step(opt_a) - - # fake discriminator - loss_2 = self.step(batch[0]) - - # ensure we forward the correct params to the optimizer - # without retain_graph we can't do multiple backward passes - self.manual_backward(loss_2, opt_b, retain_graph=True) - self.manual_backward(loss_2, opt_a, retain_graph=True) - self.manual_optimizer_step(opt_b) - - assert self.layer.weight.grad is None or torch.all(self.layer.weight.grad == 0) - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer, optimizer_2 - - @property - def automatic_optimization(self) -> bool: - return False - - -class TestMultipleOptimizersModel(TestModel): - def training_step(self, batch, batch_idx, optimizer_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def training_epoch_end(self, outputs) -> None: - # outputs should be an array with an entry per optimizer - assert len(outputs) == 2 - - def configure_optimizers(self): - optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = torch.optim.SGD(self.layer.parameters(), lr=0.1) - return optimizer, optimizer_2 - - -def record_ddp_fit_model_stats(trainer, model, gpus): - """ - Helper to calculate wall clock time for fit + max allocated memory. 
- - Args: - trainer: The trainer object. - model: The LightningModule. - gpus: Number of GPUs in test. - - Returns: - Max Memory if using GPUs, and total wall clock time. - - """ - max_memory = None - - time_start = time.perf_counter() - if gpus > 0: - torch.cuda.reset_peak_memory_stats() - torch.cuda.synchronize() - - trainer.fit(model) - - if gpus > 0: - torch.cuda.synchronize() - max_memory = torch.cuda.max_memory_allocated() / 2 ** 20 - - total_time = time.perf_counter() - time_start - - return max_memory, total_time - - -def run_sharded_correctness( - accelerator='ddp_spawn', - gpus=0, - precision=32, - max_percent_speed_diff=0.25, - model_cls=TestModel): - """ - Ensures that the trained model is identical to the standard DDP implementation. - Also checks for speed/memory regressions, we should expect always less memory but performance to fluctuate. - - Args: - accelerator: Accelerator type for test. - gpus: Number of GPUS to enable. - precision: Whether to use AMP or normal FP32 training. - max_percent_speed_diff: The maximum speed difference compared to normal DDP training. - This is more a safety net for variability in CI which can vary in speed, not for benchmarking. - model_cls: Model class to use for test. - - """ - - # Train normal DDP - seed_everything(42) - ddp_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - ) - - max_ddp_memory, ddp_time = record_ddp_fit_model_stats( - trainer=trainer, - model=ddp_model, - gpus=gpus - ) - - # Reset and train sharded DDP - seed_everything(42) - sharded_model = model_cls() - - trainer = Trainer( - fast_dev_run=True, - max_epochs=1, - gpus=gpus, - precision=precision, - accelerator=accelerator, - plugins=[DDPShardedPlugin()], - ) - - max_sharded_memory, sharded_time = record_ddp_fit_model_stats( - trainer=trainer, - model=sharded_model, - gpus=gpus - ) - - # Assert model parameters are identical after fit - for ddp_param, shard_param in zip(ddp_model.parameters(), sharded_model.parameters()): - assert torch.equal(ddp_param, shard_param), 'Model parameters are different between DDP and Sharded plugin' - - # Assert speed parity by ensuring percentage difference between sharded/ddp is below threshold - percent_diff = (sharded_time - ddp_time) / sharded_time - - assert percent_diff <= max_percent_speed_diff, \ - f'Sharded plugin was too slow compared to DDP, Sharded Time: {sharded_time}, DDP Time: {ddp_time}' - - if gpus > 0: - # Assert CUDA memory parity - assert max_sharded_memory <= max_ddp_memory, \ - f'Sharded plugin used too much memory compared to DDP,' \ - f'Sharded Mem: {max_sharded_memory}, DDP Mem: {max_ddp_memory}'
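For context, the sketch below shows roughly how the sharded plugin exercised by this patch is enabled from user code. It is a minimal illustration only, assuming fairscale is installed and reusing the import paths and the BoringModel/RandomDataset test helpers referenced in the diff above; it is not part of the patch itself, and the correctness/speed checks live in run_sharded_correctness.

    # Minimal usage sketch: enabling DDPShardedPlugin on a Trainer, mirroring
    # the benchmark harness above (assumes fairscale is installed).
    import torch

    from pytorch_lightning import Trainer, seed_everything
    from pytorch_lightning.plugins.sharded_plugin import DDPShardedPlugin
    from tests.base.boring_model import BoringModel, RandomDataset


    class SeededModel(BoringModel):
        """Seed the training loader so every DDP process draws identical batches."""

        def train_dataloader(self):
            seed_everything(42)
            return torch.utils.data.DataLoader(RandomDataset(32, 64))


    if __name__ == "__main__":
        seed_everything(42)
        trainer = Trainer(
            fast_dev_run=True,
            max_epochs=1,
            gpus=2,                        # use gpus=0 with accelerator='ddp_cpu' for a CPU smoke run
            accelerator='ddp_spawn',
            plugins=[DDPShardedPlugin()],  # the only change relative to plain DDP
        )
        trainer.fit(SeededModel())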