refactor Fabric tests to use launch method (#17648)
Co-authored-by: bas <bas.krahmer@talentflyxpert.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
parent 3a68493d0a
commit ca9e006681

tests/tests_fabric
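In short, every test that previously subclassed Fabric and overrode run() now defines a plain function and hands it to Fabric.launch(). A minimal sketch of the before/after pattern follows; the model, optimizer, and constructor arguments are placeholders and the lightning.fabric import path is assumed, so this is an illustration of the pattern rather than the exact test code from the diff below.

import torch
from lightning.fabric import Fabric

# Before (the pattern removed in this commit): subclass Fabric and override run().
#
#     class MyFabric(Fabric):
#         def run(self):
#             ...
#
#     MyFabric(accelerator="cpu", devices=1).run()

# After (the pattern the tests migrate to): pass a callable to launch().
# launch() calls the function with the Fabric instance as its first argument,
# forwards any extra positional/keyword arguments, and returns the function's
# return value (exercised by test_run_input_output below).
def run(fabric):
    model = torch.nn.Linear(3, 3)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    model, optimizer = fabric.setup(model, optimizer)
    return "done"

fabric = Fabric(accelerator="cpu", devices=1)
result = fabric.launch(run)  # result == "done"

With a single device, launch() simply invokes the function in the current process; with multi-process strategies it takes care of setting up the worker processes before invoking it.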
@@ -31,78 +31,78 @@ from tests_fabric.test_fabric import BoringModel
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multiple_models():
-    class RunFabric(Fabric):
-        def run(self):
-            model = BoringModel()
-            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
-            model, optimizer = self.setup(model, optimizer)
-
-            for i in range(2):
-                optimizer.zero_grad()
-                x = model(torch.randn(1, 32).to(self.device))
-                loss = x.sum()
-                if i == 0:
-                    # the weights are not initialized with stage 3 until backward is run once
-                    assert all(w.nelement() == 0 for w in model.state_dict().values())
-                self.backward(loss, model=model)
-                if i == 0:
-                    # save for later to check that the weights were updated
-                    state_dict = deepcopy(model.state_dict())
-                optimizer.step()
-
-            # check that the model trained, the weights from step 1 do not match the weights from step 2
-            for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
-                assert not torch.allclose(mw_b, mw_a)
-
-            self.seed_everything(42)
-            model_1 = BoringModel()
-            optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001)
-
-            self.seed_everything(42)
-            model_2 = BoringModel()
-            optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001)
-
-            for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
-                assert torch.allclose(mw_1, mw_2)
-
-            model_1, optimizer_1 = self.setup(model_1, optimizer_1)
-            model_2, optimizer_2 = self.setup(model_2, optimizer_2)
-
-            # train model_1 first
-            self.seed_everything(42)
-            data_list = []
-            for _ in range(2):
-                optimizer_1.zero_grad()
-                data = torch.randn(1, 32).to(self.device)
-                data_list.append(data)
-                x = model_1(data)
-                loss = x.sum()
-                self.backward(loss, model=model_1)
-                optimizer_1.step()
-
-            # the weights do not match
-            assert all(w.nelement() > 1 for w in model_1.state_dict().values())
-            assert all(w.nelement() == 0 for w in model_2.state_dict().values())
-
-            # now train model_2 with the same data
-            for data in data_list:
-                optimizer_2.zero_grad()
-                x = model_2(data)
-                loss = x.sum()
-                self.backward(loss, model=model_2)
-                optimizer_2.step()
-
-            # the weights should match
-            for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
-                assert torch.allclose(mw_1, mw_2)
-
-            # Verify collectives works as expected
-            ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device))
-            assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]]))
-            assert self.broadcast(True)
-            assert self.is_global_zero == (self.local_rank == 0)
-
-    RunFabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+    def run(fabric_obj):
+        model = BoringModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+        model, optimizer = fabric_obj.setup(model, optimizer)
+
+        for i in range(2):
+            optimizer.zero_grad()
+            x = model(torch.randn(1, 32).to(fabric_obj.device))
+            loss = x.sum()
+            if i == 0:
+                # the weights are not initialized with stage 3 until backward is run once
+                assert all(w.nelement() == 0 for w in model.state_dict().values())
+            fabric_obj.backward(loss, model=model)
+            if i == 0:
+                # save for later to check that the weights were updated
+                state_dict = deepcopy(model.state_dict())
+            optimizer.step()
+
+        # check that the model trained, the weights from step 1 do not match the weights from step 2
+        for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
+            assert not torch.allclose(mw_b, mw_a)
+
+        fabric_obj.seed_everything(42)
+        model_1 = BoringModel()
+        optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001)
+
+        fabric_obj.seed_everything(42)
+        model_2 = BoringModel()
+        optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001)
+
+        for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
+            assert torch.allclose(mw_1, mw_2)
+
+        model_1, optimizer_1 = fabric_obj.setup(model_1, optimizer_1)
+        model_2, optimizer_2 = fabric_obj.setup(model_2, optimizer_2)
+
+        # train model_1 first
+        fabric_obj.seed_everything(42)
+        data_list = []
+        for _ in range(2):
+            optimizer_1.zero_grad()
+            data = torch.randn(1, 32).to(fabric_obj.device)
+            data_list.append(data)
+            x = model_1(data)
+            loss = x.sum()
+            fabric_obj.backward(loss, model=model_1)
+            optimizer_1.step()
+
+        # the weights do not match
+        assert all(w.nelement() > 1 for w in model_1.state_dict().values())
+        assert all(w.nelement() == 0 for w in model_2.state_dict().values())
+
+        # now train model_2 with the same data
+        for data in data_list:
+            optimizer_2.zero_grad()
+            x = model_2(data)
+            loss = x.sum()
+            fabric_obj.backward(loss, model=model_2)
+            optimizer_2.step()
+
+        # the weights should match
+        for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
+            assert torch.allclose(mw_1, mw_2)
+
+        # Verify collectives works as expected
+        ranks = fabric_obj.all_gather(torch.tensor([fabric_obj.local_rank]).to(fabric_obj.device))
+        assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]]))
+        assert fabric_obj.broadcast(True)
+        assert fabric_obj.is_global_zero == (fabric_obj.local_rank == 0)
+
+    fabric = Fabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu")
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -118,19 +118,18 @@ def test_deepspeed_multiple_models():
 def test_deepspeed_auto_batch_size_config_select(dataset_cls, logging_batch_size_per_gpu, expected_batch_size):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            assert isinstance(self._strategy, DeepSpeedStrategy)
-            _ = self.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
-            config = self._strategy.config
-            assert config["train_micro_batch_size_per_gpu"] == expected_batch_size
+    def run(fabric_obj):
+        assert isinstance(fabric_obj._strategy, DeepSpeedStrategy)
+        _ = fabric_obj.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
+        config = fabric_obj._strategy.config
+        assert config["train_micro_batch_size_per_gpu"] == expected_batch_size
 
-    fabric = RunFabric(
+    fabric = Fabric(
         accelerator="cuda",
         devices=1,
         strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=logging_batch_size_per_gpu, zero_optimization=False),
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -138,21 +137,20 @@ def test_deepspeed_configure_optimizers():
     """Test that the deepspeed strategy with default initialization wraps the optimizer correctly."""
     from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 
-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-            model, optimizer = self.setup(model, optimizer)
-            assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
-            assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        model, optimizer = fabric_obj.setup(model, optimizer)
+        assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
+        assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)
 
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=DeepSpeedStrategy(),
         accelerator="cuda",
         devices=1,
         precision="16-mixed",
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -160,25 +158,24 @@ def test_deepspeed_custom_precision_params():
     """Test that if the FP16 parameters are set via the DeepSpeedStrategy, the deepspeed config contains these
    changes."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            assert self._strategy._config_initialized
-            assert self._strategy.config["fp16"]["loss_scale"] == 10
-            assert self._strategy.config["fp16"]["initial_scale_power"] == 11
-            assert self._strategy.config["fp16"]["loss_scale_window"] == 12
-            assert self._strategy.config["fp16"]["hysteresis"] == 13
-            assert self._strategy.config["fp16"]["min_loss_scale"] == 14
+    def run(fabric_obj):
+        assert fabric_obj._strategy._config_initialized
+        assert fabric_obj._strategy.config["fp16"]["loss_scale"] == 10
+        assert fabric_obj._strategy.config["fp16"]["initial_scale_power"] == 11
+        assert fabric_obj._strategy.config["fp16"]["loss_scale_window"] == 12
+        assert fabric_obj._strategy.config["fp16"]["hysteresis"] == 13
+        assert fabric_obj._strategy.config["fp16"]["min_loss_scale"] == 14
 
     strategy = DeepSpeedStrategy(
         loss_scale=10, initial_scale_power=11, loss_scale_window=12, hysteresis=13, min_loss_scale=14
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -187,21 +184,20 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
     correctly."""
     import deepspeed
 
-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.Adam(model.parameters())
-
-            with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
-                self.setup(model, optimizer)
-
-            configure.assert_called_with(
-                mpu_=None,
-                partition_activations=True,
-                contiguous_checkpointing=True,
-                checkpoint_in_cpu=True,
-                profile=None,
-            )
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.Adam(model.parameters())
+
+        with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
+            fabric_obj.setup(model, optimizer)
+
+        configure.assert_called_with(
+            mpu_=None,
+            partition_activations=True,
+            contiguous_checkpointing=True,
+            checkpoint_in_cpu=True,
+            profile=None,
+        )
 
     strategy = DeepSpeedStrategy(
         partition_activations=True,
@@ -209,13 +205,13 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
         contiguous_memory_optimization=True,
         synchronize_checkpoint_boundary=True,
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 class ModelParallelClassification(BoringFabric):
@@ -55,17 +55,13 @@ class BoringModel(nn.Module):
 def test_run_input_output():
     """Test that the dynamically patched run() method receives the input arguments and returns the result."""
 
-    class RunFabric(Fabric):
-        run_args = ()
-        run_kwargs = {}
-
-        def run(self, *args, **kwargs):
-            self.run_args = args
-            self.run_kwargs = kwargs
-            return "result"
-
-    fabric = RunFabric()
-    result = fabric.run(1, 2, three=3)
+    def run(fabric_obj, *args, **kwargs):
+        fabric_obj.run_args = args
+        fabric_obj.run_kwargs = kwargs
+        return "result"
+
+    fabric = Fabric()
+    result = fabric.launch(run, 1, 2, three=3)
     assert result == "result"
     assert fabric.run_args == (1, 2)
     assert fabric.run_kwargs == {"three": 3}
@@ -322,12 +318,12 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Fabric intercepts the DataLoader constructor arguments with a context manager in its run
    method."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            # One for BatchSampler, another for DataLoader
-            assert ctx_manager().__enter__.call_count == 2
+    def run(_):
+        # One for BatchSampler, another for DataLoader
+        assert ctx_manager().__enter__.call_count == 2
 
-    RunFabric().run()
+    fabric = Fabric()
+    fabric.launch(run)
     assert ctx_manager().__exit__.call_count == 2
 
 
@@ -538,27 +534,26 @@ def test_to_device(accelerator, expected):
     if not pjrt.using_pjrt():
         expected = "xla:1"
 
-    class RunFabric(Fabric):
-        def run(self):
-            expected_device = torch.device(expected)
-
-            # module
-            module = torch.nn.Linear(2, 3)
-            module = fabric.to_device(module)
-            assert all(param.device == expected_device for param in module.parameters())
-
-            # tensor
-            tensor = torch.rand(2, 2)
-            tensor = fabric.to_device(tensor)
-            assert tensor.device == expected_device
-
-            # collection
-            collection = {"data": torch.rand(2, 2), "int": 1}
-            collection = fabric.to_device(collection)
-            assert collection["data"].device == expected_device
-
-    fabric = RunFabric(accelerator=accelerator, devices=1)
-    fabric.run()
+    def run(_):
+        expected_device = torch.device(expected)
+
+        # module
+        module = torch.nn.Linear(2, 3)
+        module = fabric.to_device(module)
+        assert all(param.device == expected_device for param in module.parameters())
+
+        # tensor
+        tensor = torch.rand(2, 2)
+        tensor = fabric.to_device(tensor)
+        assert tensor.device == expected_device
+
+        # collection
+        collection = {"data": torch.rand(2, 2), "int": 1}
+        collection = fabric.to_device(collection)
+        assert collection["data"].device == expected_device
+
+    fabric = Fabric(accelerator=accelerator, devices=1)
+    fabric.launch(run)
 
 
 def test_rank_properties():