From 98b94b810c89d5e51a0ad0a2e6a87747aee6fbe9 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Fri, 7 May 2021 09:46:03 +0000 Subject: [PATCH] Fix DeepSpeedPlugin with IterableDataset (#7362) * deepspeed add train_micro_batch_size_per_gpu argument * Update naming and doc * Modify to use auto naming convention, add test * Add iterable tests * Fix tests, attempt by mocking * Import correct package * Fix comparison * Set as special test * Remove import * Add Changelog Co-authored-by: SeanNaren --- CHANGELOG.md | 2 + .../plugins/training_type/deepspeed.py | 27 ++++++++++-- tests/plugins/test_deepspeed_plugin.py | 41 ++++++++++++++++++- 3 files changed, 66 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 193ca633f2..af142fdba3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362)) + ## [1.3.0] - 2021-05-06 diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py index 54974739c1..fe3f51fa99 100644 --- a/pytorch_lightning/plugins/training_type/deepspeed.py +++ b/pytorch_lightning/plugins/training_type/deepspeed.py @@ -88,6 +88,7 @@ class DeepSpeedPlugin(DDPPlugin): allgather_bucket_size: int = 2e8, reduce_bucket_size: int = 2e8, zero_allow_untested_optimizer: bool = True, + logging_batch_size_per_gpu: Union[str, int] = "auto", config: Optional[Union[Path, str, dict]] = None, logging_level: int = logging.WARN, num_nodes: int = 1, @@ -148,6 +149,13 @@ class DeepSpeedPlugin(DDPPlugin): zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a DeepSpeed supported optimizer when using ZeRO (default: True) + logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging + on a per sample per second basis (only displayed if logging=logging.INFO). + If set to "auto", the plugin tries to infer this from + the train DataLoader's BatchSampler, else defaults to 1. + To obtain accurate logs when using datasets that do not support batch samplers, + set this to the actual per gpu batch size (trainer.batch_size). + config: Pass in a deepspeed formatted config dict, or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json. All defaults will be ignored if a config is passed in. (Default: ``None``) @@ -182,6 +190,7 @@ class DeepSpeedPlugin(DDPPlugin): when using ZeRO Stage 3. This allows a single weight file to contain the entire model, rather than individual sharded weight files. Disable to save sharded states individually. (Default: True) + """ if not _DEEPSPEED_AVAILABLE: raise MisconfigurationException( @@ -197,6 +206,7 @@ class DeepSpeedPlugin(DDPPlugin): self.config = self._create_default_config( zero_optimization, zero_allow_untested_optimizer, + logging_batch_size_per_gpu, partition_activations=partition_activations, cpu_checkpointing=cpu_checkpointing, contiguous_memory_optimization=contiguous_memory_optimization, @@ -409,14 +419,22 @@ class DeepSpeedPlugin(DDPPlugin): " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer." ) if "train_micro_batch_size_per_gpu" not in self.config: - # train_micro_batch_size_per_gpu is used for throughput logging purposes - # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed - batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size + batch_size = self._auto_select_batch_size() self.config["train_micro_batch_size_per_gpu"] = batch_size self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches if "gradient_clipping" not in self.config: self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val + def _auto_select_batch_size(self): + # train_micro_batch_size_per_gpu is used for throughput logging purposes + # by default we try to use the batch size of the loader + batch_size = 1 + if hasattr(self.lightning_module, 'train_dataloader'): + train_dataloader = self.lightning_module.train_dataloader() + if hasattr(train_dataloader, 'batch_sampler'): + batch_size = train_dataloader.batch_sampler.batch_size + return batch_size + def _format_precision_config(self): amp_type = self.lightning_module.trainer.accelerator_connector.amp_type amp_level = self.lightning_module.trainer.accelerator_connector.amp_level @@ -446,6 +464,7 @@ class DeepSpeedPlugin(DDPPlugin): self, zero_optimization: bool, zero_allow_untested_optimizer: bool, + logging_batch_size_per_gpu: Union[str, int], partition_activations: bool, cpu_checkpointing: bool, contiguous_memory_optimization: bool, @@ -466,6 +485,8 @@ class DeepSpeedPlugin(DDPPlugin): "zero_optimization": zero_kwargs, **cfg } + if logging_batch_size_per_gpu != 'auto': + cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg} return cfg def _filepath_to_dir(self, filepath: str) -> str: diff --git a/tests/plugins/test_deepspeed_plugin.py b/tests/plugins/test_deepspeed_plugin.py index c768a9aabf..056c28ffa2 100644 --- a/tests/plugins/test_deepspeed_plugin.py +++ b/tests/plugins/test_deepspeed_plugin.py @@ -7,6 +7,7 @@ import torch import torch.nn.functional as F from torch import nn, Tensor from torch.optim import Optimizer +from torch.utils.data import DataLoader from pytorch_lightning import LightningModule, seed_everything, Trainer from pytorch_lightning.callbacks import Callback, ModelCheckpoint @@ -14,7 +15,7 @@ from pytorch_lightning.metrics import Accuracy from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from tests.helpers.boring_model import BoringModel +from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset from tests.helpers.datamodules import ClassifDataModule from tests.helpers.runif import RunIf @@ -234,6 +235,44 @@ def test_warn_deepspeed_override_backward(tmpdir): trainer.fit(model) +@RunIf(min_gpus=1, deepspeed=True, special=True) +@pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10), + (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)]) +def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value): + """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes.""" + + class TestModel(BoringModel): + + def train_dataloader(self): + return DataLoader(dataset_cls(32, 64)) + + class AssertCallback(Callback): + + def on_train_start(self, trainer, pl_module) -> None: + assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin) + config = trainer.accelerator.training_type_plugin.config + + # int value overrides auto mode + expected_value = value if isinstance(value, int) else 1 + if dataset_cls == RandomDataset: + expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value + + assert config['train_micro_batch_size_per_gpu'] == expected_value + raise SystemExit + + ck = AssertCallback() + model = TestModel() + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + callbacks=ck, + gpus=1, + plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=value, zero_optimization=False), + ) + with pytest.raises(SystemExit): + trainer.fit(model) + + @RunIf(min_gpus=1, deepspeed=True, special=True) def test_deepspeed_run_configure_optimizers(tmpdir): """