Fix DeepSpeedPlugin with IterableDataset (#7362)
* deepspeed add train_micro_batch_size_per_gpu argument
* Update naming and doc
* Modify to use auto naming convention, add test
* Add iterable tests
* Fix tests, attempt by mocking
* Import correct package
* Fix comparison
* Set as special test
* Remove import
* Add Changelog

Co-authored-by: SeanNaren <sean@grid.ai>
parent 28103c67c2
commit 98b94b810c
CHANGELOG.md
@@ -22,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed DeepSpeed with IterableDatasets ([#7362](https://github.com/PyTorchLightning/pytorch-lightning/pull/7362))

## [1.3.0] - 2021-05-06
pytorch_lightning/plugins/training_type/deepspeed.py
@@ -88,6 +88,7 @@ class DeepSpeedPlugin(DDPPlugin):
        allgather_bucket_size: int = 2e8,
        reduce_bucket_size: int = 2e8,
        zero_allow_untested_optimizer: bool = True,
        logging_batch_size_per_gpu: Union[str, int] = "auto",
        config: Optional[Union[Path, str, dict]] = None,
        logging_level: int = logging.WARN,
        num_nodes: int = 1,
@@ -148,6 +149,13 @@ class DeepSpeedPlugin(DDPPlugin):
            zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                DeepSpeed supported optimizer when using ZeRO (default: True)

            logging_batch_size_per_gpu: Config used in DeepSpeed to calculate verbose timing for logging
                on a per sample per second basis (only displayed if logging=logging.INFO).
                If set to "auto", the plugin tries to infer this from
                the train DataLoader's BatchSampler, else defaults to 1.
                To obtain accurate logs when using datasets that do not support batch samplers,
                set this to the actual per gpu batch size (trainer.batch_size).

            config: Pass in a deepspeed formatted config dict,
                or path to a deepspeed config: https://www.deepspeed.ai/docs/config-json.
                All defaults will be ignored if a config is passed in. (Default: ``None``)
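For orientation, here is a minimal usage sketch of the new argument, based on the docstring above and the test added further down; the value 8 and the single-GPU setup are illustrative, not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# Explicitly set the per-GPU batch size that DeepSpeed uses for its throughput
# logging. Leaving it at "auto" lets the plugin infer the value from the train
# DataLoader's BatchSampler, falling back to 1 (e.g. for an IterableDataset).
trainer = Trainer(
    gpus=1,
    plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=8),
)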
@@ -182,6 +190,7 @@ class DeepSpeedPlugin(DDPPlugin):
                when using ZeRO Stage 3. This allows a single weight file to contain the entire model,
                rather than individual sharded weight files.
                Disable to save sharded states individually. (Default: True)

        """
        if not _DEEPSPEED_AVAILABLE:
            raise MisconfigurationException(
@@ -197,6 +206,7 @@ class DeepSpeedPlugin(DDPPlugin):
        self.config = self._create_default_config(
            zero_optimization,
            zero_allow_untested_optimizer,
            logging_batch_size_per_gpu,
            partition_activations=partition_activations,
            cpu_checkpointing=cpu_checkpointing,
            contiguous_memory_optimization=contiguous_memory_optimization,
@@ -409,14 +419,22 @@ class DeepSpeedPlugin(DDPPlugin):
                " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer."
            )
        if "train_micro_batch_size_per_gpu" not in self.config:
-            # train_micro_batch_size_per_gpu is used for throughput logging purposes
-            # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
-            batch_size = self.lightning_module.train_dataloader().batch_sampler.batch_size
+            batch_size = self._auto_select_batch_size()
            self.config["train_micro_batch_size_per_gpu"] = batch_size
        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
        if "gradient_clipping" not in self.config:
            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

    def _auto_select_batch_size(self):
        # train_micro_batch_size_per_gpu is used for throughput logging purposes
        # by default we try to use the batch size of the loader
        batch_size = 1
        if hasattr(self.lightning_module, 'train_dataloader'):
            train_dataloader = self.lightning_module.train_dataloader()
            if hasattr(train_dataloader, 'batch_sampler'):
                batch_size = train_dataloader.batch_sampler.batch_size
        return batch_size

    def _format_precision_config(self):
        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
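For reference, a sketch of what the three config keys touched above end up holding; the values shown are the Lightning Trainer defaults and are illustrative only:

# Illustrative only: the keys below are the ones written by the hunk above;
# actual values depend on the Trainer flags and on the train DataLoader.
config = {
    "train_micro_batch_size_per_gpu": 1,  # _auto_select_batch_size() fallback when no batch sampler is found
    "gradient_accumulation_steps": 1,     # mirrors Trainer(accumulate_grad_batches=...)
    "gradient_clipping": 0.0,             # mirrors Trainer(gradient_clip_val=...)
}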
@@ -446,6 +464,7 @@ class DeepSpeedPlugin(DDPPlugin):
        self,
        zero_optimization: bool,
        zero_allow_untested_optimizer: bool,
        logging_batch_size_per_gpu: Union[str, int],
        partition_activations: bool,
        cpu_checkpointing: bool,
        contiguous_memory_optimization: bool,
@@ -466,6 +485,8 @@ class DeepSpeedPlugin(DDPPlugin):
                "zero_optimization": zero_kwargs,
                **cfg
            }
        if logging_batch_size_per_gpu != 'auto':
            cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}
        return cfg

    def _filepath_to_dir(self, filepath: str) -> str:
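A small standalone sketch of the merge logic added here, with illustrative values. Writing an explicit (non-"auto") value into the config up front means the `if "train_micro_batch_size_per_gpu" not in self.config:` branch shown earlier does not overwrite it with an auto-selected value:

logging_batch_size_per_gpu = 10                  # any explicit int, i.e. not "auto"
cfg = {"zero_allow_untested_optimizer": True}    # rest of the generated config (illustrative)

if logging_batch_size_per_gpu != 'auto':
    cfg = {"train_micro_batch_size_per_gpu": logging_batch_size_per_gpu, **cfg}

print(cfg)
# {'train_micro_batch_size_per_gpu': 10, 'zero_allow_untested_optimizer': True}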
@@ -7,6 +7,7 @@ import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torch.utils.data import DataLoader

from pytorch_lightning import LightningModule, seed_everything, Trainer
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
@@ -14,7 +15,7 @@ from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.plugins import DeepSpeedPlugin, DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.training_type.deepspeed import LightningDeepSpeedModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from tests.helpers.boring_model import BoringModel
+from tests.helpers.boring_model import BoringModel, RandomDataset, RandomIterableDataset
from tests.helpers.datamodules import ClassifDataModule
from tests.helpers.runif import RunIf
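For readers who don't have the test helpers in front of them, here are rough sketches of the two dataset fixtures imported above, inferred from their usage as `dataset_cls(32, 64)` in the new test below; the real implementations live in tests/helpers/boring_model.py and may differ in detail:

import torch
from torch.utils.data import Dataset, IterableDataset


class RandomDataset(Dataset):
    """Map-style fixture: has __len__/__getitem__, so its DataLoader gets a BatchSampler."""

    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class RandomIterableDataset(IterableDataset):
    """Iterable-style fixture: only __iter__, the case the DeepSpeed plugin previously choked on."""

    def __init__(self, size, count):
        self.size = size
        self.count = count

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)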
@@ -234,6 +235,44 @@ def test_warn_deepspeed_override_backward(tmpdir):
        trainer.fit(model)


@RunIf(min_gpus=1, deepspeed=True, special=True)
@pytest.mark.parametrize(['dataset_cls', 'value'], [(RandomDataset, "auto"), (RandomDataset, 10),
                                                    (RandomIterableDataset, "auto"), (RandomIterableDataset, 10)])
def test_deepspeed_auto_batch_size_config_select(tmpdir, dataset_cls, value):
    """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""

    class TestModel(BoringModel):

        def train_dataloader(self):
            return DataLoader(dataset_cls(32, 64))

    class AssertCallback(Callback):

        def on_train_start(self, trainer, pl_module) -> None:
            assert isinstance(trainer.accelerator.training_type_plugin, DeepSpeedPlugin)
            config = trainer.accelerator.training_type_plugin.config

            # int value overrides auto mode
            expected_value = value if isinstance(value, int) else 1
            if dataset_cls == RandomDataset:
                expected_value = pl_module.train_dataloader().batch_size if value == "auto" else value

            assert config['train_micro_batch_size_per_gpu'] == expected_value
            raise SystemExit

    ck = AssertCallback()
    model = TestModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=True,
        callbacks=ck,
        gpus=1,
        plugins=DeepSpeedPlugin(logging_batch_size_per_gpu=value, zero_optimization=False),
    )
    with pytest.raises(SystemExit):
        trainer.fit(model)
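To make the parametrization above easier to follow: the test's DataLoader uses the default batch_size=1, so the expected train_micro_batch_size_per_gpu works out per case as follows (a summary of the assertions, not extra test code):

# (dataset_cls,           value)   -> expected config["train_micro_batch_size_per_gpu"]
# (RandomDataset,         "auto")  -> 1   # inferred from the DataLoader (default batch_size=1)
# (RandomDataset,         10)      -> 10  # explicit logging_batch_size_per_gpu wins
# (RandomIterableDataset, "auto")  -> 1   # no usable batch sampler, plugin falls back to 1
# (RandomIterableDataset, 10)      -> 10  # explicit logging_batch_size_per_gpu wins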

@RunIf(min_gpus=1, deepspeed=True, special=True)
def test_deepspeed_run_configure_optimizers(tmpdir):
    """