Add ColossalAI strategy (#14224)

Co-authored-by: HELSON <c2h214748@gmail.com>
Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
Co-authored-by: otaj <ota@lightning.ai>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
ver217 2022-10-11 19:59:09 +08:00 committed by GitHub
parent 6f16e46bdb
commit 2fef6d9403
16 changed files with 933 additions and 8 deletions


@@ -97,6 +97,7 @@ jobs:
set -e
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)"
PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION}
@@ -110,6 +111,11 @@ jobs:
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])")
pip install "bagua-cuda$CUDA_VERSION_BAGUA"
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])")
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))")
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])")
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org
pip list
env:
PACKAGE_NAME: pytorch
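The ColossalAI wheel selection above (repeated in the Dockerfiles below) boils down to the following illustrative Python sketch; it is not part of the CI script and assumes a CUDA build of torch, so that `torch.version.cuda` is set:

import torch

# e.g. "1.12" -> becomes the torch${PYTORCH_VERSION_COLOSSALAI} part of the wheel tag
torch_tag = torch.__version__.split("+")[0][:4]
# pick the newest ColossalAI CUDA build (11.3 or 11.1) that does not exceed the local CUDA version
cuda_version = float(torch.version.cuda)
cuda_tag = next(ver for ver in (11.3, 11.1) if cuda_version >= ver)
print(f"colossalai==0.1.10+torch{torch_tag}cu{cuda_tag}")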


@@ -54,9 +54,10 @@ RUN \
libopenmpi-dev \
openmpi-bin \
ssh \
ninja-build \
libnccl2=$TO_INSTALL_NCCL \
libnccl-dev=$TO_INSTALL_NCCL && \
# Install python
add-apt-repository ppa:deadsnakes/ppa && \
apt-get install -y \
python${PYTHON_VERSION} \
@@ -65,7 +66,7 @@ RUN \
&& \
update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \
# Cleaning
apt-get autoremove -y && \
apt-get clean && \
rm -rf /root/.cache && \
@@ -82,14 +83,15 @@ RUN \
rm get-pip.py && \
pip install -q fire && \
# Disable cache \
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
export CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
pip config set global.cache-dir false && \
# set particular PyTorch version
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION} && \
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt ${PYTORCH_VERSION} && \
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/examples.txt ${PYTORCH_VERSION} && \
# Install all requirements \
pip install -r requirements/pytorch/devel.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
# Install base requirements \
pip install -r requirements/pytorch/base.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
rm assistant.py
ENV \
@@ -108,7 +110,7 @@ RUN \
export HOROVOD_BUILD_CUDA_CC_LIST=${HOROVOD_BUILD_CUDA_CC_LIST//"."/""} && \
echo $HOROVOD_BUILD_CUDA_CC_LIST && \
cmake --version && \
pip install --no-cache-dir -r ./requirements/pytorch/strategies.txt && \
pip install --no-cache-dir horovod && \
horovodrun --check-build
RUN \
@@ -136,6 +138,28 @@ RUN \
if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()"; fi && \
python -c "import bagua; print(bagua.__version__)"
RUN \
# install ColossalAI
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "1" ]]; then \
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") ; \
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") ; \
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") ; \
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
python -c "import colossalai; print(colossalai.__version__)" ; \
fi
RUN \
# install the rest of the strategies
# remove colossalai from the requirements since it is installed separately
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
fi && \
echo "$SHOULD_INSTALL_COLOSSAL" && \
cat requirements/pytorch/strategies.txt && \
pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html
COPY requirements/pytorch/check-avail-extras.py check-avail-extras.py
COPY requirements/pytorch/check-avail-strategies.py check-avail-strategies.py


@@ -41,7 +41,11 @@ RUN \
fi && \
# otherwise there is a collision between the folder name and the package name on PyPI
cd lightning && \
pip install .["extra","loggers","strategies"] --no-cache-dir && \
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
fi && \
pip install .["extra","loggers","strategies"] --no-cache-dir --find-links https://release.colossalai.org && \
cd .. && \
rm -rf lightning


@@ -185,6 +185,7 @@ precision
:template: classtemplate.rst
ApexMixedPrecisionPlugin
ColossalAIPrecisionPlugin
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
@@ -285,7 +286,7 @@ strategies
:template: classtemplate.rst
BaguaStrategy
HivemindStrategy
ColossalAIStrategy
DDPFullyShardedNativeStrategy
DDPFullyShardedStrategy
DDPShardedStrategy
@@ -294,6 +295,7 @@ strategies
DDPStrategy
DataParallelStrategy
DeepSpeedStrategy
HivemindStrategy
HorovodStrategy
HPUParallelStrategy
IPUStrategy


@@ -53,6 +53,7 @@ The full list of built-in precision plugins is listed below.
:template: classtemplate.rst
ApexMixedPrecisionPlugin
ColossalAIPrecisionPlugin
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin


@@ -75,6 +75,9 @@ The below table lists all relevant strategies available in Lightning with their
* - collaborative
- :class:`~pytorch_lightning.strategies.HivemindStrategy`
- Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. <strategies/hivemind:Training on unreliable mixed GPUs across the internet>`
* - colossalai
- :class:`~pytorch_lightning.strategies.ColossalAIStrategy`
- Colossal-AI provides a collection of parallel components. It aims to let you write distributed deep learning models the same way you write models on your laptop. `Learn more. <https://www.colossalai.org/>`__
* - fsdp_native
- :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
- Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. <advanced/model_parallel:PyTorch Fully Sharded Training>`


@@ -1,6 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment
colossalai>=0.1.10
fairscale>=0.4.5, <=0.4.6
deepspeed>=0.6.0, <=0.7.0
# no need to install with [pytorch] as pytorch is already installed


@@ -5,6 +5,7 @@ from pytorch_lightning.plugins.io.async_plugin import AsyncCheckpointIO
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -27,6 +28,7 @@ __all__ = [
"XLACheckpointIO",
"HPUCheckpointIO",
"ApexMixedPrecisionPlugin",
"ColossalAIPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"IPUPrecisionPlugin",


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -26,6 +27,7 @@ from pytorch_lightning.plugins.precision.tpu_bf16 import TPUBf16PrecisionPlugin
__all__ = [
"ApexMixedPrecisionPlugin",
"ColossalAIPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FullyShardedNativeNativeMixedPrecisionPlugin",


@@ -0,0 +1,90 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union
from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer
import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType
warning_cache = WarningCache()
class ColossalAIPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for ColossalAI integration.
Args:
precision: Half precision (16).
Raises:
ValueError:
If precision is not 16.
"""
def __init__(self, precision: Union[str, int] = 16) -> None:
if not (precision == PrecisionType.HALF):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
)
super().__init__()
self.precision = precision
def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
optimizer: Optional[Steppable],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
assert optimizer is not None
optimizer.backward(tensor)
def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
optimizer.clip_grad_norm(None, clip_val)
def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
raise NotImplementedError("`clip_grad_by_value` is not supported by `ColossalAI`")
def optimizer_step( # type: ignore[override]
self,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
closure_result = closure()
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
raise ValueError(
"Skipping backward by returning `None` from your `training_step` is not supported by `ColossalAI`."
)
optimizer.step()
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
# the gradients are not available in the model due to gradient partitioning in zero stage >= 2
warning_cache.warn(
f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for ColossalAI."
" The setting will be ignored."
)
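A minimal usage sketch for the plugin above (assuming an environment with this commit installed); it only restates the behaviour already encoded in `__init__`:

from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin

plugin = ColossalAIPrecisionPlugin(precision=16)  # the only supported value
try:
    ColossalAIPrecisionPlugin(precision=32)
except ValueError as err:
    print(err)  # suggests setting `precision=16`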


@@ -13,6 +13,7 @@
# limitations under the License.
from lightning_lite.strategies.registry import _StrategyRegistry
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.colossalai import ColossalAIStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy # noqa: F401


@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Dict, List, Optional, Union


@@ -0,0 +1,476 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Dict, List, Mapping, Optional, OrderedDict, TYPE_CHECKING, Union
import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.nn import Module
from torch.optim.optimizer import Optimizer
import pytorch_lightning as pl
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.plugins.environments.cluster_environment import ClusterEnvironment
from lightning_lite.utilities.distributed import ReduceOp
from pytorch_lightning.accelerators.cuda import CUDAAccelerator
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase, _LightningPrecisionModuleWrapperBase
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
from pytorch_lightning.strategies.ddp import DDPStrategy
from pytorch_lightning.strategies.strategy import TBroadcast
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.enums import PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import STEP_OUTPUT
_COLOSSALAI_AVAILABLE = RequirementCache("colossalai")
if TYPE_CHECKING and _COLOSSALAI_AVAILABLE:
with _patch_cuda_is_available():
from colossalai.utils.model.colo_init_context import ColoInitContext
else:
ColoInitContext = Any
class ColossalAIStrategy(DDPStrategy):
"""ColossalAI strategy. It only supports a single optimizer, which must be
:class:`colossalai.nn.optimizer.CPUAdam` or :class:`colossalai.nn.optimizer.HybridAdam` now. Your model must
be created in the function ``LightningModule.configure_sharded_model()``. Thus, you should overwrite this function.
More details can be found in the below example.
It configures accelerator and precision, and you should not configure them when initializing ``Trainer``.
CUDA is essential for this strategy. Please make sure CUDA is available.
Example::
class GLUETransformer(LightningModule):
...
def configure_sharded_model(self) -> None:
self.model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
trainer = Trainer(..., accelerator="gpu", precision=16, strategy="colossalai")
Args:
use_chunk: Whether to use chunk-based memory management.
It can speed up training, but slightly more memory will be used.
chunk_size: The size of a chunk.
It will be ignored when ``use_chunk=False``.
If it's None, a best chunk size will be searched out based on ``chunk_search_range``,
``chunk_search_n_grids`` and ``min_chunk_size``.
enable_distributed_storage: Whether to store the model in a distributed manner.
It reduces memory usage to 1/N of the original, but it may slow down training.
placement_policy: It can be "cpu", "cuda" or "auto".
* If it's "cpu", parameters, gradients and optimizer states will be offloaded to CPU,
which means minimal CUDA memory will be used.
* If it's "cuda", they won't be offloaded, which means maximal CUDA memory will be used. It's the fastest.
* If it's "auto", they are moved dynamically based on CPU and CUDA memory usage,
so the heterogeneous memory space is utilized evenly.
Note that the "auto" policy only works well when no other processes use CUDA during your training.
force_outputs_fp32: Whether to cast outputs to fp32.
gpu_margin_mem_ratio: The ratio of the remaining GPU memory (after the first forward-backward pass)
that will be used by the optimizer.
This argument will be ignored when ``placement_policy`` is not "auto".
chunk_search_range: The range of chunk size to search.
The actual search range will be from
``max(min_chunk_size, max_param_size)`` to ``max(min_chunk_size, max_param_size) + chunk_search_range``.
chunk_search_n_grids: The number of intervals in the search range.
min_chunk_size: The minimum size for a chunk.
initial_scale: The initial dynamic loss scale value.
min_scale: The minimum dynamic loss scaling value.
growth_factor: The multiplication factor for increasing loss scale.
backoff_factor: The multiplication factor for decreasing loss scale.
growth_interval: The number of steps to increase loss scale when no overflow occurs.
hysteresis: The number of overflows before decreasing loss scale.
max_scale: The maximum dynamic loss scaling value.
.. _colossalai.nn.optimizer.CPUAdam:
https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html
.. _colossalai.nn.optimizer.HybridAdam:
https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html
"""
strategy_name = "colossalai"
def __init__(
self,
use_chunk: bool = True,
chunk_size: Optional[int] = None,
enable_distributed_storage: bool = True,
placement_policy: str = "auto",
force_outputs_fp32: bool = False,
gpu_margin_mem_ratio: float = 0.0,
chunk_search_range: int = 64 * 1024**2,
chunk_search_n_grids: int = 1024,
min_chunk_size: Optional[int] = None,
initial_scale: float = 2**32,
min_scale: float = 1,
growth_factor: float = 2,
backoff_factor: float = 0.5,
growth_interval: int = 1000,
hysteresis: int = 2,
max_scale: float = 2**32,
accelerator: Optional["pl.accelerators.Accelerator"] = None,
parallel_devices: Optional[List[torch.device]] = None,
cluster_environment: Optional[ClusterEnvironment] = None,
checkpoint_io: Optional[CheckpointIO] = None,
precision_plugin: Optional[ColossalAIPrecisionPlugin] = None,
) -> None:
if not _COLOSSALAI_AVAILABLE:
raise MisconfigurationException(
"To use the `ColossalAIStrategy`, please install `colossalai` first. "
"Download `colossalai` by consulting `https://colossalai.org/download`."
)
with _patch_cuda_is_available():
from colossalai.logging import get_dist_logger
super().__init__(
accelerator=accelerator,
parallel_devices=parallel_devices,
cluster_environment=cluster_environment,
checkpoint_io=checkpoint_io,
precision_plugin=precision_plugin,
)
self.use_chunk = use_chunk
self.chunk_size = chunk_size
self.enable_distributed_storage = enable_distributed_storage
self.placement_policy = placement_policy
self.force_outputs_fp32 = force_outputs_fp32
self.gpu_margin_mem_ratio = gpu_margin_mem_ratio
self.chunk_size_search_kwargs = {
"search_range": chunk_search_range,
"n_grids": chunk_search_n_grids,
"min_chunk_size": min_chunk_size,
}
self.amp_kwargs = {
"initial_scale": initial_scale,
"min_scale": min_scale,
"growth_factor": growth_factor,
"backoff_factor": backoff_factor,
"growth_interval": growth_interval,
"hysteresis": hysteresis,
"max_scale": max_scale,
}
self._num_nodes = 1
self._logger = get_dist_logger()
@property
def root_device(self) -> torch.device:
with _patch_cuda_is_available():
from colossalai.utils import get_current_device
if self.parallel_devices is not None:
return self.parallel_devices[self.local_rank]
return get_current_device()
@property
def handles_gradient_accumulation(self) -> bool:
"""Whether the plugin handles gradient accumulation internally."""
return True
@property
def restore_checkpoint_after_setup(self) -> bool:
"""Override to delay restoring from checkpoint till after pre-dispatch."""
return True
def setup_distributed(self) -> None:
with _patch_cuda_is_available():
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers
assert self.cluster_environment is not None
self.set_world_ranks()
if not gpc.is_initialized(ParallelMode.GLOBAL):
disable_existing_loggers()
gpc.init_global_dist(
rank=self.global_rank,
world_size=self.world_size,
backend="nccl",
host=self.cluster_environment.main_address,
port=self.cluster_environment.main_port,
)
gpc.set_device(self.local_rank)
def model_sharded_context(self) -> "ColoInitContext":
"""Provide hook to create modules in a distributed aware context. This is useful for when we'd like to
shard the model instantly, which is useful for extremely large models which can save memory and
initialization time.
Returns: Model parallel context.
"""
with _patch_cuda_is_available():
from colossalai.utils.model.colo_init_context import ColoInitContext
class ModelShardedContext(ColoInitContext):
def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any) -> None:
if getattr(module, "_colossalai_module", False) is True:
return
super()._post_init_method(module, *args, **kwargs)
module._colossalai_module = True # type: ignore[assignment]
return ModelShardedContext()
def setup_precision_plugin(self) -> None:
with _patch_cuda_is_available():
from colossalai.gemini import ChunkManager, GeminiManager
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.zero import ZeroOptimizer
super().setup_precision_plugin()
assert self.lightning_module is not None
is_training = self.lightning_module.trainer and self.lightning_module.trainer.training
if is_training:
if len(self.optimizers) > 1:
raise ValueError("`ColossalAIStrategy` only supports single Optimizer now.")
optimizer = self.optimizers[0]
if not isinstance(optimizer, (CPUAdam, HybridAdam)):
raise ValueError(
"`ColossalAIStrategy` only supports `colossalai.nn.optimizer.CPUAdam` "
"and `colossalai.nn.optimizer.HybridAdam` as its optimizer."
)
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
pl_module = self.model
process_group = ProcessGroup()
if not hasattr(pl_module, "_colossalai_zero"):
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
)
else:
chunk_size = None
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
assert self.model is not None
pl_module._colossalai_zero = [self.model] # type: ignore[assignment]
else:
self.model = pl_module._colossalai_zero[0] # type: ignore[index, assignment]
if is_training:
self.optimizers = [
ZeroOptimizer(optimizer, self.model, gpu_margin_mem_ratio=self.gpu_margin_mem_ratio, **self.amp_kwargs)
]
def setup(self, trainer: "pl.Trainer") -> None:
precision = self.precision_plugin.precision
if not (precision == PrecisionType.HALF):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
)
if not isinstance(self.accelerator, CUDAAccelerator):
raise ValueError(
"`ColossalAIStrategy` is only supported on `CUDAAccelerator`, "
f"but `{self.accelerator.__class__.__name__}` is used."
)
if trainer.state.fn == TrainerFn.FITTING:
if is_overridden("backward", trainer.lightning_module):
rank_zero_warn(
"You have overridden the `LightningModule.backward` hook"
" but it will be ignored since ColossalAI handles"
" the backward logic internally."
)
if trainer.accumulate_grad_batches > 1:
raise ValueError(
"ColossalAI does not support gradient accumulation now. Please set `accumulate_grad_batches` to 1."
)
accumulation_scheduler = trainer.accumulation_scheduler
if accumulation_scheduler.epochs != [0]:
raise ValueError(
"ColossalAI currently does not support different `accumulate_grad_batches` at different epochs."
)
if not isinstance(self.precision_plugin, ColossalAIPrecisionPlugin):
raise ValueError("`ColossalAIStrategy` is only compatible with `ColossalAIPrecisionPlugin`.")
self.accelerator.setup(trainer)
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self.model_to_device()
def model_to_device(self) -> None:
assert self.lightning_module is not None
pl_module = self.lightning_module
for child in pl_module.modules():
if child is not pl_module and not getattr(child, "_colossalai_module", False):
child.to(self.root_device)
def teardown(self) -> None:
optimizers = self.optimizers
self.optimizers = list()
zero_model = self.model
self.model = None
pl_module = self._lightning_module
self._lightning_module = None
super().teardown()
self.optimizers = optimizers
self.model = zero_model
self._lightning_module = pl_module
def optimizer_step(
self,
optimizer: Optimizer,
opt_idx: int,
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> Any:
model = model or self.lightning_module
# TODO(lite): remove assertion once strategy's optimizer_step typing is fixed
assert isinstance(model, pl.LightningModule)
return self.precision_plugin.optimizer_step(
optimizer, model=model, optimizer_idx=opt_idx, closure=closure, **kwargs
)
def lightning_module_state_dict(self, rank_zero_only: bool = False) -> Dict[str, Any]:
"""Returns a dictionary containing a whole state of the module. But all the tensors in the dictionary are
detached from their parameters and located in cpu memory.
Args:
rank_zero_only: If True, only process rank 0 gets the correct dictionary.
Otherwise, all processes get the same dictionary.
"""
with _patch_cuda_is_available():
from colossalai.nn.parallel import ZeroDDP
assert isinstance(self.model, ZeroDDP)
org_dict = self.model.state_dict(only_rank_0=rank_zero_only)
children = list(self.model.named_children())
assert len(children) == 1
prefix, child = children[0]
prefix += "."
assert child is self.lightning_module
mapping_dict = dict()
for key in org_dict.keys():
mapping_dict[key] = key.replace(prefix, "") # remove "_forward_module." from the key
return {mapping_dict[key]: value for key, value in org_dict.items()}
def load_model_state_dict(self, checkpoint: Mapping[str, Any]) -> None:
orig_dict = checkpoint["state_dict"]
assert self.model is not None
children = list(self.model.named_children())
assert len(children) == 1
prefix, child = children[0]
prefix += "."
assert child is self.lightning_module
mapping_dict = dict()
for key in orig_dict.keys():
mapping_dict[key] = prefix + key # add "_forward_module." to the key
load_dict = OrderedDict({mapping_dict[key]: value for key, value in orig_dict.items()})
self.model.load_state_dict(load_dict)
def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.val_step_context():
return self.model(*args, **kwargs)
def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
assert self.model is not None
with self.precision_plugin.test_step_context():
return self.model(*args, **kwargs)
def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
assert self.model is not None
with self.precision_plugin.predict_step_context():
return self.model(*args, **kwargs)
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
strategy_registry.register("colossalai", cls, description="Default ColossalAI Strategy")
def reduce(
self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = ReduceOp.SUM
) -> Tensor:
with _patch_cuda_is_available():
from colossalai.communication.collective import reduce
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
if not isinstance(tensor, Tensor):
return tensor
if isinstance(reduce_op, str):
if reduce_op.lower() in ("avg", "mean"):
reduce_op = ReduceOp.SUM
div_factor = gpc.get_world_size(parallel_mode=ParallelMode.GLOBAL)
with torch.no_grad():
tensor = tensor / div_factor
else:
reduce_op = getattr(ReduceOp, reduce_op.upper())
tensor = reduce(tensor, dst=0, parallel_mode=ParallelMode.GLOBAL, op=reduce_op)
return tensor
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
"""Broadcasts an object to all processes.
Args:
obj: the object to broadcast
src: source rank
"""
with _patch_cuda_is_available():
from colossalai.communication.collective import broadcast
from colossalai.context import ParallelMode
return broadcast(obj, src=src, parallel_mode=ParallelMode.GLOBAL)
def all_gather(self, tensor: Tensor, group: Optional[Any] = None, sync_grads: bool = False) -> Tensor:
"""Perform a all_gather on all processes."""
with _patch_cuda_is_available():
from colossalai.communication.collective import all_gather
from colossalai.context import ParallelMode
assert sync_grads is False
return all_gather(tensor, dim=0, parallel_mode=ParallelMode.GLOBAL)
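A hedged usage sketch for the strategy defined above, using only constructor arguments shown in `__init__` (assumes `colossalai` is installed and a CUDA GPU is available):

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import ColossalAIStrategy

strategy = ColossalAIStrategy(
    use_chunk=True,           # chunk-based memory management
    placement_policy="auto",  # move tensors between CPU and CUDA based on memory usage
    initial_scale=2**16,      # starting value for dynamic loss scaling
)
# The model itself must be created in `LightningModule.configure_sharded_model()`,
# as described in the class docstring above.
trainer = Trainer(accelerator="gpu", devices=1, precision=16, strategy=strategy)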


@@ -41,6 +41,7 @@ from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.plugins import (
ApexMixedPrecisionPlugin,
CheckpointIO,
ColossalAIPrecisionPlugin,
DeepSpeedPrecisionPlugin,
DoublePrecisionPlugin,
FullyShardedNativeMixedPrecisionPlugin,
@@ -57,6 +58,7 @@ from pytorch_lightning.plugins.environments import BaguaEnvironment
from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
from pytorch_lightning.strategies import (
ColossalAIStrategy,
DDPFullyShardedNativeStrategy,
DDPFullyShardedStrategy,
DDPShardedStrategy,
@@ -687,6 +689,9 @@ class AcceleratorConnector:
" is not supported with TPUs. Using `precision='bf16'` instead."
)
return TPUBf16PrecisionPlugin()
if isinstance(self.strategy, ColossalAIStrategy):
return ColossalAIPrecisionPlugin(self._precision_flag)
if isinstance(self.strategy, DeepSpeedStrategy):
return DeepSpeedPrecisionPlugin(self._precision_flag, self._amp_type_flag, self._amp_level_flag)
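In practice, the branch above means that selecting the strategy is enough to get the matching precision plugin; a small sketch mirroring `test_colossalai_strategy_with_trainer_by_string` further below (assumes `colossalai` is installed):

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin

trainer = Trainer(precision=16, strategy="colossalai")
assert isinstance(trainer.strategy.precision_plugin, ColossalAIPrecisionPlugin)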


@@ -26,6 +26,7 @@ from pytorch_lightning.accelerators.mps import MPSAccelerator
from pytorch_lightning.accelerators.tpu import TPUAccelerator
from pytorch_lightning.callbacks.progress.rich_progress import _RICH_AVAILABLE
from pytorch_lightning.strategies.bagua import _BAGUA_AVAILABLE
from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE
from pytorch_lightning.utilities.imports import (
_APEX_AVAILABLE,
@@ -86,6 +87,7 @@ class RunIf:
omegaconf: bool = False,
slow: bool = False,
bagua: bool = False,
colossalai: bool = False,
psutil: bool = False,
hivemind: bool = False,
**kwargs,
@@ -242,6 +244,10 @@ class RunIf:
conditions.append(not _BAGUA_AVAILABLE or sys.platform in ("win32", "darwin"))
reasons.append("Bagua")
if colossalai:
conditions.append(not _COLOSSALAI_AVAILABLE)
reasons.append("ColossalAI")
if psutil:
conditions.append(not _PSUTIL_AVAILABLE)
reasons.append("psutil")


@@ -0,0 +1,289 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pytest
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torch.optim import Optimizer
from torchmetrics import Accuracy
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.plugins.precision import ColossalAIPrecisionPlugin
from pytorch_lightning.strategies import ColossalAIStrategy
from pytorch_lightning.strategies.colossalai import _COLOSSALAI_AVAILABLE
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests_pytorch.helpers.datamodules import ClassifDataModule
from tests_pytorch.helpers.runif import RunIf
if _COLOSSALAI_AVAILABLE:
from colossalai.nn.optimizer import HybridAdam
def test_invalid_colossalai(monkeypatch):
import pytorch_lightning.strategies.colossalai as colossal_strategy
monkeypatch.setattr(colossal_strategy, "_COLOSSALAI_AVAILABLE", False)
with pytest.raises(
MisconfigurationException,
match="To use the `ColossalAIStrategy`, please install `colossalai` first. "
"Download `colossalai` by consulting `https://colossalai.org/download`.",
):
ColossalAIStrategy()
@RunIf(colossalai=True)
def test_colossalai_strategy_with_trainer_by_instance():
trainer = Trainer(precision=16, strategy=ColossalAIStrategy())
assert isinstance(trainer.strategy, ColossalAIStrategy)
assert isinstance(trainer.strategy.precision_plugin, ColossalAIPrecisionPlugin)
@RunIf(colossalai=True)
def test_colossalai_strategy_with_trainer_by_string():
trainer = Trainer(precision=16, strategy="colossalai")
assert isinstance(trainer.strategy, ColossalAIStrategy)
assert isinstance(trainer.strategy.precision_plugin, ColossalAIPrecisionPlugin)
class ModelParallelBoringModel(BoringModel):
def __init__(self):
super().__init__()
self.layer = None
def configure_sharded_model(self) -> None:
self.layer = torch.nn.Linear(32, 2)
def configure_optimizers(self):
optimizer = HybridAdam(self.layer.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
return [optimizer], [lr_scheduler]
class ModelParallelBoringModelNoSchedulers(ModelParallelBoringModel):
def configure_optimizers(self):
return HybridAdam(self.layer.parameters(), lr=1e-3)
@RunIf(min_cuda_gpus=1, colossalai=True)
def test_gradient_clip_algorithm_error(tmpdir):
model = ModelParallelBoringModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
accelerator="gpu",
devices=1,
precision=16,
strategy="colossalai",
enable_progress_bar=False,
enable_model_summary=False,
gradient_clip_val=1.0,
gradient_clip_algorithm="value",
)
with pytest.raises(NotImplementedError, match="`clip_grad_by_value` is not supported by `ColossalAI`"):
trainer.fit(model)
@RunIf(min_cuda_gpus=1, colossalai=True)
def test_gradient_accumulation_error(tmpdir):
model = ModelParallelBoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=1,
precision=16,
strategy="colossalai",
max_epochs=1,
accumulate_grad_batches={0: 1, 4: 2, 8: 3},
)
with pytest.raises(
ValueError,
match="ColossalAI currently does not support different `accumulate_grad_batches` at different epochs.",
):
trainer.fit(model)
@RunIf(min_cuda_gpus=1, colossalai=True)
def test_colossalai_optimizer(tmpdir):
model = BoringModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
accelerator="gpu",
devices=1,
precision=16,
strategy="colossalai",
enable_progress_bar=False,
enable_model_summary=False,
)
with pytest.raises(
ValueError,
match="`ColossalAIStrategy` only supports `colossalai.nn.optimizer.CPUAdam` "
"and `colossalai.nn.optimizer.HybridAdam` as its optimizer.",
):
trainer.fit(model)
@RunIf(min_cuda_gpus=1, standalone=True, colossalai=True)
def test_warn_colossalai_ignored(tmpdir):
class TestModel(ModelParallelBoringModel):
def backward(self, loss: Tensor, optimizer: Optimizer, optimizer_idx: int, *args, **kwargs) -> None:
return loss.backward()
model = TestModel()
trainer = Trainer(
fast_dev_run=True,
default_root_dir=tmpdir,
accelerator="gpu",
devices=1,
precision=16,
strategy="colossalai",
track_grad_norm=2,
enable_progress_bar=False,
enable_model_summary=False,
)
from pytorch_lightning.plugins.precision.colossalai import warning_cache
with pytest.warns(UserWarning, match="will be ignored since ColossalAI handles the backward"):
trainer.fit(model)
assert any("track_grad_norm=2.0)' but this is not supported" in w for w in warning_cache)
def _assert_save_model_is_equal(model, tmpdir, trainer):
checkpoint_path = os.path.join(tmpdir, "model.pt")
checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
trainer.strategy.barrier()
# carry out the check only on rank 0
if trainer.is_global_zero:
state_dict = torch.load(checkpoint_path)
# Assert model parameters are identical after loading
for orig_param, saved_model_param in zip(model.parameters(), state_dict.values()):
saved_model_param = saved_model_param.to(dtype=orig_param.dtype, device=orig_param.device)
assert torch.equal(orig_param, saved_model_param)
class ModelParallelClassificationModel(LightningModule):
def __init__(self, lr=0.01):
super().__init__()
self.lr = lr
self.layers = None
self.train_acc = Accuracy()
self.valid_acc = Accuracy()
self.test_acc = Accuracy()
def build_layers(self) -> nn.Module:
layers = []
for _ in range(3):
layers.append(nn.Linear(32, 32))
layers.append(nn.ReLU())
layers.append(nn.Linear(32, 3))
return nn.Sequential(*layers)
def configure_sharded_model(self) -> None:
if self.layers is None:
self.layers = self.build_layers()
def forward(self, x):
x = self.layers(x)
logits = F.softmax(x, dim=1)
return logits
def configure_optimizers(self):
optimizer = HybridAdam(self.parameters(), lr=self.lr)
return [optimizer], []
def training_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
loss = F.cross_entropy(logits, y)
self.log("train_loss", loss, prog_bar=True, sync_dist=True)
self.log("train_acc", self.train_acc(logits, y), prog_bar=True, sync_dist=True)
return {"loss": loss}
def validation_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
self.log("val_loss", F.cross_entropy(logits, y), prog_bar=False, sync_dist=True)
self.log("val_acc", self.valid_acc(logits, y), prog_bar=True, sync_dist=True)
def test_step(self, batch, batch_idx):
x, y = batch
logits = self.forward(x)
self.log("test_loss", F.cross_entropy(logits, y), prog_bar=False, sync_dist=True)
self.log("test_acc", self.test_acc(logits, y), prog_bar=True, sync_dist=True)
def predict_step(self, batch, batch_idx):
x, _ = batch
return self.forward(x)
@RunIf(min_cuda_gpus=2, standalone=True, colossalai=True)
def test_multi_gpu_checkpointing(tmpdir):
dm = ClassifDataModule()
model = ModelParallelClassificationModel()
ck = ModelCheckpoint(monitor="val_acc", mode="max", save_last=True, save_top_k=-1)
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=1,
accelerator="gpu",
devices=2,
precision=16,
strategy="colossalai",
callbacks=[ck],
)
trainer.fit(model, datamodule=dm)
results = trainer.test(datamodule=dm)
saved_results = trainer.test(ckpt_path=ck.best_model_path, datamodule=dm)
assert saved_results == results
# here, we test whether restore_checkpoint_after_setup works as expected
model = ModelParallelClassificationModel()
trainer = Trainer(default_root_dir=tmpdir, accelerator="gpu", devices=2, precision=16, strategy="colossalai")
saved_results = trainer.test(model, datamodule=dm, ckpt_path=ck.best_model_path)
assert saved_results == results
@RunIf(min_cuda_gpus=2, standalone=True, colossalai=True)
def test_multi_gpu_model_colossalai_fit_test(tmpdir):
dm = ClassifDataModule()
model = ModelParallelClassificationModel()
trainer = Trainer(
default_root_dir=tmpdir,
accelerator="gpu",
devices=2,
precision=16,
strategy=ColossalAIStrategy(initial_scale=32),
max_epochs=1,
)
trainer.fit(model, datamodule=dm)
out_metrics = trainer.callback_metrics
assert out_metrics["train_acc"] > 0.7
assert out_metrics["val_acc"] > 0.7
result = trainer.test(model, datamodule=dm)
for out in result:
assert out["test_acc"] > 0.7