# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
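"""Tests for the Bagua distributed training strategy."""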
from unittest import mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.strategies import BaguaStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf
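

# ``BoringModel`` subclass that configures Bagua's ``QAdamOptimizer`` (plus a
# matching LR scheduler), which the ``qadam`` algorithm requires in place of a
# plain ``torch.optim`` optimizer.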
class BoringModel4QAdam(BoringModel):
    def configure_optimizers(self):
        from bagua.torch_api.algorithms.q_adam import QAdamOptimizer

        optimizer = QAdamOptimizer(self.layer.parameters(), lr=0.05, warmup_steps=20)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]
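

# Requesting the strategy by its ``"bagua"`` alias should resolve to a
# ``BaguaStrategy`` instance with default settings.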
@RunIf(bagua=True, min_gpus=1)
def test_bagua_default(tmpdir):
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy="bagua",
        accelerator="gpu",
        devices=1,
    )
    assert isinstance(trainer.strategy, BaguaStrategy)
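

# The asynchronous algorithm needs two GPUs in a standalone run; after a single
# ``fast_dev_run`` step, every parameter norm should stay below a loose bound.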
@RunIf(bagua=True, min_gpus=2, standalone=True)
def test_async_algorithm(tmpdir):
    model = BoringModel()
    bagua_strategy = BaguaStrategy(algorithm="async")
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=2,
    )
    trainer.fit(model)

    for param in model.parameters():
        assert torch.norm(param) < 3
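

# Every supported algorithm should pass ``configure_ddp`` except ``qadam``,
# which demands a ``QAdamOptimizer`` that plain ``BoringModel`` does not provide.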
@RunIf(bagua=True, min_gpus=1)
@pytest.mark.parametrize(
    "algorithm", ["gradient_allreduce", "bytegrad", "qadam", "decentralized", "low_precision_decentralized"]
)
def test_configuration(algorithm, tmpdir):
    model = BoringModel()
    bagua_strategy = BaguaStrategy(algorithm=algorithm)
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=1,
    )
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer

    # Stub out Bagua's model wrapper and process-group check so
    # ``configure_ddp`` can run in a single, non-distributed process.
    with mock.patch(
        "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
    ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
        if algorithm == "qadam":
            with pytest.raises(MisconfigurationException, match="Bagua QAdam can only accept one QAdamOptimizer"):
                trainer.strategy.configure_ddp()
        else:
            trainer.strategy.configure_ddp()
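

# With ``BoringModel4QAdam`` supplying the required ``QAdamOptimizer``, the
# ``qadam`` algorithm should configure without error.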
@RunIf(bagua=True, min_gpus=1)
def test_qadam_configuration(tmpdir):
    model = BoringModel4QAdam()
    bagua_strategy = BaguaStrategy(algorithm="qadam")
    trainer = Trainer(
        default_root_dir=tmpdir,
        fast_dev_run=1,
        strategy=bagua_strategy,
        accelerator="gpu",
        devices=1,
    )
    trainer.state.fn = TrainerFn.FITTING
    trainer.strategy.connect(model)
    trainer.lightning_module.trainer = trainer
    trainer.strategy.setup_optimizers(trainer)

    # Same stubbing as in ``test_configuration``: avoid a real distributed setup.
    with mock.patch(
        "bagua.torch_api.data_parallel.bagua_distributed.BaguaDistributedDataParallel.__init__", return_value=None
    ), mock.patch("bagua.torch_api.communication.is_initialized", return_value=True):
        trainer.strategy.configure_ddp()
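

# Selecting the Bagua strategy when the package is not installed should raise a
# ``MisconfigurationException`` with an actionable message.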
def test_bagua_not_available(monkeypatch):
    import pytorch_lightning.strategies.bagua as imports

    monkeypatch.setattr(imports, "_BAGUA_AVAILABLE", False)
    # Simulate a visible GPU while Bagua itself is unavailable.
    with mock.patch("torch.cuda.device_count", return_value=1):
        with pytest.raises(MisconfigurationException, match="you must have `Bagua` installed"):
            Trainer(strategy="bagua", gpus=1)