# tests_pytorch/strategies/test_sharded_strategy.py


import os
from copy import deepcopy
from typing import Mapping
from unittest import mock
from unittest.mock import Mock

import pytest
import torch
from torch import Tensor

from lightning_fabric.strategies.fairscale import _FAIRSCALE_AVAILABLE
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.demos.boring_classes import BoringModel, ManualOptimBoringModel
from pytorch_lightning.plugins import MixedPrecisionPlugin
from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
from pytorch_lightning.trainer.states import TrainerFn
from tests_pytorch.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS


class ModelWithAdamOptimizer(BoringModel):
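    """BoringModel configured with an ``Adam`` optimizer, which keeps per-parameter state."""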
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.layer.parameters(), lr=0.1)
        return optimizer


class CheckModelRestore(ModelWithAdamOptimizer):
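    """Asserts at ``on_train_start`` that the restored weights and optimizer states match the saved ones."""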
    def __init__(self, old_model_state_dict, old_optimizer_states):
        super().__init__()
        self.old_model_state_dict = old_model_state_dict
        self.old_optimizer_states = old_optimizer_states

    def on_train_start(self):
        # the restored parameter values must match the ones saved in the checkpoint
        assert all(
            self._is_equal(actual, expected)
            for actual, expected in zip(self.state_dict().values(), self.old_model_state_dict.values())
        )
        for optimizer, state in zip(self.trainer.optimizers, self.old_optimizer_states):
            optimizer_state = self.trainer.strategy.optimizer_state(optimizer)
            assert self._is_equal(optimizer_state, state)

    def _is_equal(self, a, b):
        if isinstance(a, Tensor):
            return torch.allclose(a, b)
        if isinstance(a, Mapping):
            return all(self._is_equal(a.get(k, None), b.get(k, None)) for k in b.keys())
        return a == b


@pytest.mark.parametrize("clip_val", [0, 10])
@RunIf(min_cuda_gpus=1, fairscale=True)
@mock.patch("fairscale.optim.oss.OSS.clip_grad_norm")
def test_ddp_sharded_precision_16_clip_gradients(mock_oss_clip_grad_norm, clip_val, tmpdir):
"""Ensure that clip gradients is only called if the value is greater than 0."""
model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
strategy="ddp_sharded",
accelerator="gpu",
devices=1,
precision=16,
fast_dev_run=True,
gradient_clip_val=clip_val,
)
trainer.fit(model)
if clip_val > 0:
mock_oss_clip_grad_norm.assert_called()
else:
mock_oss_clip_grad_norm.assert_not_called()
@RunIf(fairscale=True)
@pytest.mark.parametrize(
    "strategy,expected", [("ddp_sharded", DDPShardedStrategy), ("ddp_sharded_spawn", DDPSpawnShardedStrategy)]
)
def test_sharded_ddp_choice(strategy, expected):
    """Test to ensure that strategy is correctly chosen."""
    trainer = Trainer(fast_dev_run=True, strategy=strategy)
    assert isinstance(trainer.strategy, expected)


@RunIf(min_cuda_gpus=1, fairscale=True)
@pytest.mark.parametrize(
    "strategy,expected", [("ddp_sharded", DDPShardedStrategy), ("ddp_sharded_spawn", DDPSpawnShardedStrategy)]
)
def test_ddp_choice_sharded_amp(strategy, expected):
    """Test to ensure that the native AMP precision plugin is correctly chosen when using sharded."""
    trainer = Trainer(fast_dev_run=True, accelerator="gpu", devices=1, precision=16, strategy=strategy)
    assert isinstance(trainer.strategy, expected)
    assert isinstance(trainer.precision_plugin, MixedPrecisionPlugin)


@RunIf(fairscale=True)
def test_ddp_sharded_strategy_checkpoint_cpu(tmpdir):
"""Test to ensure that checkpoint is saved correctly."""
model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="cpu", devices=2, fast_dev_run=True)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)
@RunIf(min_cuda_gpus=2, fairscale=True)
def test_ddp_sharded_strategy_checkpoint_multi_gpu(tmpdir):
"""Test to ensure that checkpoint is saved correctly when using multiple GPUs."""
model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_sharded_spawn", fast_dev_run=True)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)
@RunIf(min_cuda_gpus=2, fairscale=True)
def test_ddp_sharded_strategy_finetune(tmpdir):
"""Test to ensure that we can save and restart training (simulate fine-tuning)"""
model = BoringModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp_sharded_spawn", fast_dev_run=True)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
trainer = Trainer(fast_dev_run=True)
trainer.fit(saved_model)
@RunIf(fairscale=True)
def test_ddp_sharded_strategy_fit_ckpt_path(tmpdir):
"""Test to ensure that resuming from checkpoint works."""
model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="cpu", devices=2, fast_dev_run=True)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="cpu", devices=2, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
@RunIf(min_cuda_gpus=1, fairscale=True)
def test_ddp_sharded_strategy_fit_ckpt_path_gpu_to_cpu(tmpdir):
"""Test to ensure that resuming from checkpoint works when going from GPUs- > CPU."""
model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="gpu", devices=1, fast_dev_run=True)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
trainer.save_checkpoint(checkpoint_path)
model = BoringModel()
trainer = Trainer(strategy="ddp_sharded_spawn", accelerator="cpu", devices=2, fast_dev_run=True)
trainer.fit(model, ckpt_path=checkpoint_path)
@RunIf(standalone=True, fairscale=True)
@pytest.mark.parametrize(
    "trainer_kwargs",
    (
        dict(accelerator="cpu", devices=2),
        pytest.param(dict(accelerator="gpu", devices=2), marks=RunIf(min_cuda_gpus=2)),
    ),
)
def test_ddp_sharded_strategy_test_multigpu(trainer_kwargs):
"""Test to ensure we can use validate and test without fit."""
model = BoringModel()
trainer = Trainer(
strategy="ddp_sharded_spawn",
fast_dev_run=True,
enable_progress_bar=False,
enable_model_summary=False,
**trainer_kwargs,
)
trainer.validate(model)
trainer.test(model)
@RunIf(min_cuda_gpus=2, standalone=True, fairscale=True)
@pytest.mark.parametrize("strategy", ("ddp_sharded", "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_manual_optimization(tmpdir, strategy):
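    """Test that manual optimization runs with the sharded strategies on multiple GPUs."""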
    model = ManualOptimBoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        strategy=strategy,
        fast_dev_run=2,
        accelerator="gpu",
        devices=2,
        enable_progress_bar=False,
        enable_model_summary=False,
    )
    trainer.fit(model)


class BoringModelSharded(BoringModel):
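    """Asserts in each stage whether the module is wrapped in ``ShardedDataParallel`` or left as a LightningModule."""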
    def on_train_start(self) -> None:
        """Check if trainer module is wrapped as ShardedDataParallel during training stage."""
        assert isinstance(self.trainer.model, ShardedDataParallel)

    def on_test_start(self) -> None:
        """Check if trainer module remains as LightningModule during test stage."""
        assert isinstance(self.trainer.model, LightningModule)

    def on_validation_start(self) -> None:
        """Check that the module is wrapped during validation within fit, and remains a LightningModule otherwise."""
        if self.trainer.state.fn == TrainerFn.FITTING:
            assert isinstance(self.trainer.model, ShardedDataParallel)
        else:
            assert isinstance(self.trainer.model, LightningModule)

    def on_predict_start(self) -> None:
        """Check if trainer module remains as LightningModule during prediction stage."""
        assert isinstance(self.trainer.model, LightningModule)


@RunIf(fairscale=True)
def test_configure_ddp(tmpdir):
"""Tests with ddp sharded strategy."""
trainer = Trainer(default_root_dir=tmpdir, strategy="ddp_sharded", fast_dev_run=True)
model = BoringModelSharded()
trainer.fit(model)
trainer.test(model, dataloaders=model.test_dataloader())
trainer.validate(model, dataloaders=model.val_dataloader())
trainer.predict(model, dataloaders=model.predict_dataloader())
@RunIf(fairscale=True)
@mock.patch("pytorch_lightning.strategies.DDPShardedStrategy._wrap_optimizers", autospec=True)
@pytest.mark.parametrize("cls", [DDPShardedStrategy, DDPSpawnShardedStrategy])
def test_custom_kwargs_sharded(_, cls):
"""Tests to ensure that if custom kwargs are passed, they are set correctly."""
strategy = cls(reduce_fp16=True)
strategy._lightning_module = Mock(spec=LightningModule)
strategy._lightning_module.trainer = Mock()
strategy.parallel_devices = [Mock()]
class_name = "sharded" if isinstance(strategy, DDPShardedStrategy) else "sharded_spawn"
with mock.patch(f"pytorch_lightning.strategies.{class_name}.ShardedDataParallel", autospec=True) as mock_sharded:
strategy.configure_ddp()
args, kwargs = mock_sharded.call_args
assert "reduce_fp16" in kwargs
assert kwargs["reduce_fp16"]
@RunIf(fairscale=True)
@mock.patch("pytorch_lightning.strategies.DDPShardedStrategy._wrap_optimizers", autospec=True)
@pytest.mark.parametrize(["params", "expected_buffer_size"], [(dict(), 0), (dict(reduce_buffer_size=128), 128)])
@pytest.mark.parametrize("num_nodes", [1, 2])
def test_custom_kwargs_sharded_reduce_buffer_size(_, params, expected_buffer_size, num_nodes):
"""Tests to ensure that ``reduce_buffer_size`` is correctly set based on user kwargs."""
strategy = DDPShardedStrategy(**params)
strategy.num_nodes = num_nodes
strategy._lightning_module = Mock(spec=LightningModule)
strategy._lightning_module.trainer = Mock()
strategy.parallel_devices = [Mock()]
with mock.patch("pytorch_lightning.strategies.sharded.ShardedDataParallel", autospec=True) as mock_sharded:
strategy.configure_ddp()
args, kwargs = mock_sharded.call_args
assert "reduce_buffer_size" in kwargs
if num_nodes > 1 and len(params) == 0:
# If user has not specified a buffer size and we're using multiple nodes, check to see if default is set
assert kwargs["reduce_buffer_size"] == DDPShardedStrategy._REDUCE_BUFFER_SIZE_DEFAULT
else:
assert kwargs["reduce_buffer_size"] == expected_buffer_size
@RunIf(fairscale=True)
def test_block_backward_sync():
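    """Test that ``block_backward_sync`` delegates to the wrapped model's ``no_sync`` context."""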
    strategy = DDPShardedStrategy()
    model = mock.MagicMock(spec=ShardedDataParallel)
    with mock.patch.object(strategy, "_model", model):
        with strategy.block_backward_sync():
            pass
    model.no_sync.assert_called_once()


@pytest.mark.parametrize(
"strategy_name,expected_ddp_kwargs",
[
("ddp_sharded", {}),
("ddp_sharded_find_unused_parameters_false", {"find_unused_parameters": False}),
("ddp_sharded_spawn", {}),
("ddp_sharded_spawn_find_unused_parameters_false", {"find_unused_parameters": False}),
],
)
def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
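    """Test that DDP kwargs such as ``find_unused_parameters`` are set from the strategy registry name."""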
    trainer = Trainer(strategy=strategy_name)
    assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs


class BoringFairScaleOptimizerModel(BoringModel):
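    """BoringModel that returns a fairscale ``OSS`` optimizer wrapping SGD."""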
    def configure_optimizers(self):
        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_cuda_gpus=2, fairscale=True)
@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
"""Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
model = BoringFairScaleOptimizerModel()
trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
trainer.fit(model)
checkpoint_path = os.path.join(tmpdir, "model.pt")
# need to broadcast because tmpdir is different on each process
checkpoint_path = trainer.strategy.broadcast(checkpoint_path)
trainer.save_checkpoint(checkpoint_path)
trainer.strategy.barrier() # ensure the checkpoint is saved before load
saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
# Assert model parameters are identical after loading
for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
assert torch.equal(trained_param.to("cpu"), loaded_param)
@RunIf(min_cuda_gpus=2, fairscale=True)
def test_ddp_sharded_strategy_fit_ckpt_path_downsize_gpus(tmpdir):
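    """Test that training can resume on fewer GPUs than the checkpoint was trained with."""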
    model = ModelWithAdamOptimizer()
    trainer = Trainer(
        strategy="ddp_sharded_spawn",
        max_epochs=1,
        limit_train_batches=1,
        limit_val_batches=0,
        accelerator="gpu",
        devices=2,
    )
    trainer.fit(model)

    checkpoint_path = trainer.checkpoint_callback.best_model_path
    ckpt = torch.load(checkpoint_path)
    old_model_state_dict = deepcopy(ckpt["state_dict"])
    old_optimizer_states = deepcopy(ckpt["optimizer_states"])

    model = CheckModelRestore(old_model_state_dict, old_optimizer_states)
    trainer = Trainer(
        strategy="ddp_sharded_spawn",
        max_epochs=2,
        limit_train_batches=1,
        limit_val_batches=0,
        accelerator="gpu",
        devices=1,
    )
    trainer.fit(model, ckpt_path=checkpoint_path)