From dd33528e00c17289e1865746bb3743a7d6b3663f Mon Sep 17 00:00:00 2001
From: HELSON
Date: Fri, 14 Oct 2022 00:14:03 +0800
Subject: [PATCH] [docs] Docs for ColossalaiStrategy (#15093)

---
 .../advanced/model_parallel.rst             | 96 +++++++++++++++++++
 docs/source-pytorch/extensions/strategy.rst |  2 +-
 .../strategies/colossalai.py                | 10 +-
 .../strategies/test_colossalai.py           |  4 +-
 4 files changed, 107 insertions(+), 5 deletions(-)

diff --git a/docs/source-pytorch/advanced/model_parallel.rst b/docs/source-pytorch/advanced/model_parallel.rst
index 757b7dffa4..3a57f7b949 100644
--- a/docs/source-pytorch/advanced/model_parallel.rst
+++ b/docs/source-pytorch/advanced/model_parallel.rst
@@ -55,6 +55,102 @@ Sharding techniques help when model sizes are fairly large; roughly 500M+ parame
 
 ----------
 
+.. _colossalai:
+
+***********
+Colossal-AI
+***********
+
+:class:`~pytorch_lightning.strategies.colossalai.ColossalAIStrategy` implements ZeRO-DP with chunk-based memory management.
+With this chunk mechanism, very large models can be trained with a small number of GPUs.
+Compared to usual heterogeneous training, it supports a larger trainable model size and batch size by reducing CUDA memory fragmentation and CPU memory consumption.
+It also speeds up this kind of heterogeneous training by fully utilizing all available resources.
+
+When the chunk mechanism is enabled, a set of consecutive parameters is stored in a chunk, and the chunk is then sharded across different processes.
+This reduces communication and data transmission frequency and makes full use of communication and PCI-E bandwidth, which speeds up training.
+
+Unlike traditional implementations, which adopt static memory partitioning, we implemented a dynamic heterogeneous memory management system named Gemini.
+During the first training step, the warmup phase samples the maximum non-model data memory (memory usage excluding parameters, gradients, and optimizer states).
+In later training steps, it uses the collected memory usage information to evict chunks dynamically.
+Gemini allows you to fit much larger models within limited GPU memory.
+
+According to our benchmark results, we can train models with up to 24 billion parameters on a single GPU.
+You can install Colossal-AI by consulting `how to download colossalai `_.
+Then, run this benchmark in `Colossalai-PL/gpt `_.
+
+Here is an example showing how to use ColossalAI:
+
+.. code-block:: python
+
+    from colossalai.nn.optimizer import HybridAdam
+
+
+    class MyBert(LightningModule):
+        ...
+
+        def configure_sharded_model(self) -> None:
+            # create your model here
+            self.model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
+
+        def configure_optimizers(self):
+            # use the specified optimizer
+            optimizer = HybridAdam(self.model.parameters(), self.lr)
+            ...
+
+
+    model = MyBert()
+    trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy="colossalai")
+    trainer.fit(model)
+
+You can find more examples in the `Colossalai-PL `_ repository.
+
+.. note::
+
+    * The only accelerator supported by ColossalAI is ``"gpu"``. However, CPU resources are still used when the placement policy is set to "auto" or "cpu".
+
+    * The only precision supported by ColossalAI is 16 (FP16).
+
+    * Only a single optimizer is supported, and it must currently be either ``colossalai.nn.optimizer.CPUAdam`` or ``colossalai.nn.optimizer.HybridAdam``. You can set ``adamw_mode`` to False to use normal Adam (see the sketch after this note). Note that ``HybridAdam`` is highly optimized: it uses a fused CUDA kernel and a parallel CPU kernel. It is recommended to use ``HybridAdam``, since it updates parameters on both GPU and CPU.
+
+    * Your model must be created using the :meth:`~pytorch_lightning.core.module.LightningModule.configure_sharded_model` method.
+
+    * ``ColossalAIStrategy`` does not support gradient accumulation as of now.
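+
+If you want plain Adam update rules instead of AdamW, construct the optimizer with ``adamw_mode=False``.
+The snippet below is a minimal sketch of that variant; it assumes ``HybridAdam`` accepts ``adamw_mode`` as a keyword argument, as mentioned in the note above, and otherwise mirrors the earlier example:
+
+.. code-block:: python
+
+    from colossalai.nn.optimizer import HybridAdam
+
+
+    class MyBert(LightningModule):
+        ...
+
+        def configure_optimizers(self):
+            # adamw_mode=False switches HybridAdam from AdamW-style updates to plain Adam
+            optimizer = HybridAdam(self.model.parameters(), self.lr, adamw_mode=False)
+            return optimizer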
+
+.. _colossal_placement_policy:
+
+Placement Policy
+================
+
+Placement policies can help users fully exploit their GPU-CPU heterogeneous memory space for better training efficiency.
+There are three options for the placement policy: "cpu", "cuda", and "auto".
+
+When the placement policy is set to "cpu", all participating parameters are offloaded into CPU memory immediately at the end of every autograd operation.
+In this way, the "cpu" placement policy uses the least CUDA memory.
+It is the best choice for users who want to maximally enlarge their model size or training batch size.
+
+When using the "cuda" option, all parameters are placed in CUDA memory and no CPU resources are used during training.
+It is meant for users who have plenty of CUDA memory.
+
+The third option, "auto", enables Gemini.
+It monitors the consumption of CUDA memory during the warmup phase and collects the CUDA memory usage of all autograd operations.
+In later training steps, Gemini automatically manages data transmission between GPU and CPU according to the collected CUDA memory usage information.
+It is the fastest option when there is enough CUDA memory.
+
+Here is an example of changing the placement policy to "cpu":
+
+.. code-block:: python
+
+    from pytorch_lightning.strategies import ColossalAIStrategy
+
+    model = MyModel()
+    my_strategy = ColossalAIStrategy(placement_policy="cpu")
+    trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy)
+    trainer.fit(model)
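+
+Beyond the placement policy, the strategy constructor shown later in this patch also accepts chunk-search and loss-scaling parameters such as ``chunk_search_range``, ``min_chunk_size``, and ``initial_scale``.
+The sketch below combines the "auto" policy (Gemini) with an explicit initial loss scale; the values are illustrative only, and the defaults are usually a reasonable starting point:
+
+.. code-block:: python
+
+    from pytorch_lightning.strategies import ColossalAIStrategy
+
+    model = MyModel()
+    # enable Gemini and start FP16 loss scaling at 2**16
+    my_strategy = ColossalAIStrategy(placement_policy="auto", initial_scale=2**16)
+    trainer = Trainer(accelerator="gpu", devices=4, precision=16, strategy=my_strategy)
+    trainer.fit(model)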
+
 .. _sharded-training:
 
 **************************
diff --git a/docs/source-pytorch/extensions/strategy.rst b/docs/source-pytorch/extensions/strategy.rst
index 3f752a28ab..807de1b02e 100644
--- a/docs/source-pytorch/extensions/strategy.rst
+++ b/docs/source-pytorch/extensions/strategy.rst
@@ -77,7 +77,7 @@ The below table lists all relevant strategies available in Lightning with their
      - Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. `
    * - colossalai
      - :class:`~pytorch_lightning.strategies.ColossalAIStrategy`
-     - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. `__
+     - Colossal-AI provides a collection of parallel components for you. It aims to support you to write your distributed deep learning models just like how you write your model on your laptop. `Learn more. `__
    * - fsdp_native
      - :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
      - Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. `
diff --git a/src/pytorch_lightning/strategies/colossalai.py b/src/pytorch_lightning/strategies/colossalai.py
index 1f96023aa3..c32c042b56 100644
--- a/src/pytorch_lightning/strategies/colossalai.py
+++ b/src/pytorch_lightning/strategies/colossalai.py
@@ -131,7 +131,7 @@ class ColossalAIStrategy(DDPStrategy):
         chunk_search_range: int = 64 * 1024**2,
         chunk_search_n_grids: int = 1024,
         min_chunk_size: Optional[int] = None,
-        initial_scale: float = 2**32,
+        initial_scale: float = 2**16,
         min_scale: float = 1,
         growth_factor: float = 2,
         backoff_factor: float = 0.5,
@@ -463,8 +463,14 @@ class ColossalAIStrategy(DDPStrategy):
         with _patch_cuda_is_available():
             from colossalai.communication.collective import broadcast
             from colossalai.context import ParallelMode
+            from colossalai.core import global_context as gpc
 
-        return broadcast(obj, src=src, parallel_mode=ParallelMode.GLOBAL)
+        if isinstance(obj, Tensor):
+            return broadcast(obj, src=src, parallel_mode=ParallelMode.GLOBAL)
+        else:
+            obj_list = [obj]
+            torch.distributed.broadcast_object_list(obj_list, src, group=gpc.get_group(ParallelMode.GLOBAL))
+            return obj_list[0]
 
     def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
         """Perform a all_gather on all processes."""
diff --git a/tests/tests_pytorch/strategies/test_colossalai.py b/tests/tests_pytorch/strategies/test_colossalai.py
index 7f49ba69fe..cb8a3f3b4a 100644
--- a/tests/tests_pytorch/strategies/test_colossalai.py
+++ b/tests/tests_pytorch/strategies/test_colossalai.py
@@ -100,7 +100,7 @@ def test_gradient_clip_algorithm_error(tmpdir):
         trainer.fit(model)
 
 
-@RunIf(min_cuda_gpus=1, colossalai=True)
+@RunIf(min_cuda_gpus=1, standalone=True, colossalai=True)
 def test_gradient_accumulation_error(tmpdir):
     model = ModelParallelBoringModel()
     trainer = Trainer(
@@ -120,7 +120,7 @@ def test_gradient_accumulation_error(tmpdir):
         trainer.fit(model)
 
 
-@RunIf(min_cuda_gpus=1, colossalai=True)
+@RunIf(min_cuda_gpus=1, standalone=True, colossalai=True)
 def test_colossalai_optimizer(tmpdir):
     model = BoringModel()
     trainer = Trainer(