# lightning/tests/strategies/test_bagua_strategy.py
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
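"""Tests for PyTorch Lightning's ``BaguaStrategy`` distributed-training integration."""
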
from unittest import mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf


class BoringModel4QAdam(BoringModel):
    def configure_optimizers(self):
        from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

        optimizer = QAdamOptimizer(self.layer.parameters(), lr=0.05, warmup_steps=20)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
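

# Selecting the strategy by name with ``strategy="bagua"`` should resolve to a
# ``BaguaStrategy`` instance.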
@RunIf(bagua=True, min_gpus=1)
def test_bagua_default(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy="bagua",
        accelerator="gpu",
        devices=1,
    )
    assert isinstance(trainer.strategy, BaguaStrategy)
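

# Run a single fast-dev batch with Bagua's asynchronous model averaging algorithm on two
# GPUs; the norm check is a loose sanity bound that training did not diverge.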
@RunIf(bagua=True, min_gpus=2, standalone=True)
def test_async_algorithm(tmpdir):
    model = BoringModel()
    bagua_strategy = BaguaStrategy(algorithm="async")
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=2,
    )
    trainer.fit(model)

    for param in model.parameters():
        assert torch.norm(param) < 3
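

# Exercise ``configure_ddp()`` for each supported Bagua algorithm, mocking out the wrapped
# ``BaguaDistributedDataParallel`` so no real process group is needed. The "qadam" case must
# raise, because ``BoringModel`` does not configure the ``QAdamOptimizer`` that algorithm requires.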
@RunIf(bagua=True, min_gpus=1)
@pytest.mark.parametrize(
    "algorithm", ["gradient_allreduce", "bytegrad", "qadam", "decentralized", "low_precision_decentralized"]
)
def test_configuration(algorithm, tmpdir):
    model = BoringModel()
    bagua_strategy = BaguaStrategy(algorithm=algorithm)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=1,
    )
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    with mock.patch(
        "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
    ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
        if algorithm == "qadam":
            with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
                trainer.strategy.configure_ddp()
        else:
            trainer.strategy.configure_ddp()
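

# Same as above, but with a model that configures the ``QAdamOptimizer`` Bagua expects,
# so ``configure_ddp()`` should succeed for the "qadam" algorithm.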
@RunIf(bagua=True, min_gpus=1)
def test_qadam_configuration(tmpdir):
    model = BoringModel4QAdam()
    bagua_strategy = BaguaStrategy(algorithm="qadam")
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=1,
    )
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    trainer.strategy.setup_optimizers(trainer)

    with mock.patch(
        "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
    ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
        trainer.strategy.configure_ddp()
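

# Requesting the Bagua strategy without the ``bagua`` package installed should raise a
# ``MisconfigurationException`` pointing the user at the missing dependency.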
def test_bagua_not_available(monkeypatch):
    import pytorch_lightning.strategies.bagua as imports

    monkeypatch.setattr(imports, "_BAGUA_AVAILABLE", False)
    with mock.patch("torch.cuda.device_count", return_value=1):
        with pytest.raises(MisconfigurationException, match="you must have `Bagua` installed"):
            Trainer(strategy="bagua", gpus=1)