refactor Fabric tests to use launch method ()

Co-authored-by: bas <bas.krahmer@talentflyxpert.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Bas Krahmer authored 2023-05-19 19:42:49 +02:00, committed by GitHub
parent 3a68493d0a
commit ca9e006681
2 changed files with 125 additions and 134 deletions
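
In short: instead of subclassing Fabric and overriding run(), each test now defines a plain function and hands it to Fabric.launch(), which calls it with the Fabric instance as its first argument. A minimal sketch of the before/after pattern (illustrative model and names, assuming the lightning.fabric import path; not code from this diff):

    import torch
    from lightning.fabric import Fabric

    # Before: subclass Fabric, override run(), and call it directly.
    class RunFabric(Fabric):
        def run(self):
            model = torch.nn.Linear(2, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            model, optimizer = self.setup(model, optimizer)

    RunFabric(accelerator="cpu", devices=1).run()

    # After: a free function receives the Fabric object from launch().
    def run(fabric):
        model = torch.nn.Linear(2, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        model, optimizer = fabric.setup(model, optimizer)

    fabric = Fabric(accelerator="cpu", devices=1)
    fabric.launch(run)

As the test_run_input_output hunk below shows, launch() also forwards extra positional and keyword arguments to the function and returns its result.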


@@ -31,78 +31,78 @@ from tests_fabric.test_fabric import BoringModel

 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multiple_models():
-    class RunFabric(Fabric):
-        def run(self):
+    def run(fabric_obj):
         model = BoringModel()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
-        model, optimizer = self.setup(model, optimizer)
+        model, optimizer = fabric_obj.setup(model, optimizer)

         for i in range(2):
             optimizer.zero_grad()
-            x = model(torch.randn(1, 32).to(self.device))
+            x = model(torch.randn(1, 32).to(fabric_obj.device))
             loss = x.sum()
             if i == 0:
                 # the weights are not initialized with stage 3 until backward is run once
                 assert all(w.nelement() == 0 for w in model.state_dict().values())
-            self.backward(loss, model=model)
+            fabric_obj.backward(loss, model=model)
             if i == 0:
                 # save for later to check that the weights were updated
                 state_dict = deepcopy(model.state_dict())
             optimizer.step()

         # check that the model trained, the weights from step 1 do not match the weights from step 2
         for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
             assert not torch.allclose(mw_b, mw_a)

-        self.seed_everything(42)
+        fabric_obj.seed_everything(42)
         model_1 = BoringModel()
         optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001)

-        self.seed_everything(42)
+        fabric_obj.seed_everything(42)
         model_2 = BoringModel()
         optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001)

         for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
             assert torch.allclose(mw_1, mw_2)

-        model_1, optimizer_1 = self.setup(model_1, optimizer_1)
-        model_2, optimizer_2 = self.setup(model_2, optimizer_2)
+        model_1, optimizer_1 = fabric_obj.setup(model_1, optimizer_1)
+        model_2, optimizer_2 = fabric_obj.setup(model_2, optimizer_2)

         # train model_1 first
-        self.seed_everything(42)
+        fabric_obj.seed_everything(42)
         data_list = []
         for _ in range(2):
             optimizer_1.zero_grad()
-            data = torch.randn(1, 32).to(self.device)
+            data = torch.randn(1, 32).to(fabric_obj.device)
             data_list.append(data)
             x = model_1(data)
             loss = x.sum()
-            self.backward(loss, model=model_1)
+            fabric_obj.backward(loss, model=model_1)
             optimizer_1.step()

         # the weights do not match
         assert all(w.nelement() > 1 for w in model_1.state_dict().values())
         assert all(w.nelement() == 0 for w in model_2.state_dict().values())

         # now train model_2 with the same data
         for data in data_list:
             optimizer_2.zero_grad()
             x = model_2(data)
             loss = x.sum()
-            self.backward(loss, model=model_2)
+            fabric_obj.backward(loss, model=model_2)
             optimizer_2.step()

         # the weights should match
         for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
             assert torch.allclose(mw_1, mw_2)

         # Verify collectives works as expected
-        ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device))
+        ranks = fabric_obj.all_gather(torch.tensor([fabric_obj.local_rank]).to(fabric_obj.device))
         assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]]))
-        assert self.broadcast(True)
-        assert self.is_global_zero == (self.local_rank == 0)
+        assert fabric_obj.broadcast(True)
+        assert fabric_obj.is_global_zero == (fabric_obj.local_rank == 0)

-    RunFabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+    fabric = Fabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu")
+    fabric.launch(run)


 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -118,19 +118,18 @@ def test_deepspeed_multiple_models():
 def test_deepspeed_auto_batch_size_config_select(dataset_cls, logging_batch_size_per_gpu, expected_batch_size):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""

-    class RunFabric(Fabric):
-        def run(self):
-            assert isinstance(self._strategy, DeepSpeedStrategy)
-            _ = self.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
-            config = self._strategy.config
-            assert config["train_micro_batch_size_per_gpu"] == expected_batch_size
+    def run(fabric_obj):
+        assert isinstance(fabric_obj._strategy, DeepSpeedStrategy)
+        _ = fabric_obj.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
+        config = fabric_obj._strategy.config
+        assert config["train_micro_batch_size_per_gpu"] == expected_batch_size

-    fabric = RunFabric(
+    fabric = Fabric(
         accelerator="cuda",
         devices=1,
         strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=logging_batch_size_per_gpu, zero_optimization=False),
     )
-    fabric.run()
+    fabric.launch(run)


 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -138,21 +137,20 @@ def test_deepspeed_configure_optimizers():
     """Test that the deepspeed strategy with default initialization wraps the optimizer correctly."""
     from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer

-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-            model, optimizer = self.setup(model, optimizer)
-            assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
-            assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        model, optimizer = fabric_obj.setup(model, optimizer)
+        assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
+        assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)

-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=DeepSpeedStrategy(),
         accelerator="cuda",
         devices=1,
         precision="16-mixed",
     )
-    fabric.run()
+    fabric.launch(run)


 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -160,25 +158,24 @@ def test_deepspeed_custom_precision_params():
     """Test that if the FP16 parameters are set via the DeepSpeedStrategy, the deepspeed config contains these
     changes."""

-    class RunFabric(Fabric):
-        def run(self):
-            assert self._strategy._config_initialized
-            assert self._strategy.config["fp16"]["loss_scale"] == 10
-            assert self._strategy.config["fp16"]["initial_scale_power"] == 11
-            assert self._strategy.config["fp16"]["loss_scale_window"] == 12
-            assert self._strategy.config["fp16"]["hysteresis"] == 13
-            assert self._strategy.config["fp16"]["min_loss_scale"] == 14
+    def run(fabric_obj):
+        assert fabric_obj._strategy._config_initialized
+        assert fabric_obj._strategy.config["fp16"]["loss_scale"] == 10
+        assert fabric_obj._strategy.config["fp16"]["initial_scale_power"] == 11
+        assert fabric_obj._strategy.config["fp16"]["loss_scale_window"] == 12
+        assert fabric_obj._strategy.config["fp16"]["hysteresis"] == 13
+        assert fabric_obj._strategy.config["fp16"]["min_loss_scale"] == 14

     strategy = DeepSpeedStrategy(
         loss_scale=10, initial_scale_power=11, loss_scale_window=12, hysteresis=13, min_loss_scale=14
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)


 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -187,21 +184,20 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
     correctly."""
     import deepspeed

-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.Adam(model.parameters())
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.Adam(model.parameters())

-            with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
-                self.setup(model, optimizer)
+        with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
+            fabric_obj.setup(model, optimizer)

-            configure.assert_called_with(
-                mpu_=None,
-                partition_activations=True,
-                contiguous_checkpointing=True,
-                checkpoint_in_cpu=True,
-                profile=None,
-            )
+        configure.assert_called_with(
+            mpu_=None,
+            partition_activations=True,
+            contiguous_checkpointing=True,
+            checkpoint_in_cpu=True,
+            profile=None,
+        )

     strategy = DeepSpeedStrategy(
         partition_activations=True,
@@ -209,13 +205,13 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
         contiguous_memory_optimization=True,
         synchronize_checkpoint_boundary=True,
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)


 class ModelParallelClassification(BoringFabric):
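
One point the DeepSpeed hunks above rely on: with spawn-based strategies, launch() itself creates the worker processes and runs the function once per rank (in the standalone GPU tests above the processes are launched externally, and launch() simply runs the function in each). A rough sketch with an assumed CPU ddp_spawn setup, not taken from this diff:

    import torch
    from lightning.fabric import Fabric

    def run(fabric):
        # executed once in each of the two spawned processes
        rank = fabric.global_rank
        tensor = fabric.to_device(torch.tensor([rank]))
        print("rank", rank, tensor)

    fabric = Fabric(accelerator="cpu", strategy="ddp_spawn", devices=2)
    fabric.launch(run)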


@@ -55,17 +55,13 @@ class BoringModel(nn.Module):
 def test_run_input_output():
     """Test that the dynamically patched run() method receives the input arguments and returns the result."""

-    class RunFabric(Fabric):
-        run_args = ()
-        run_kwargs = {}
-
-        def run(self, *args, **kwargs):
-            self.run_args = args
-            self.run_kwargs = kwargs
-            return "result"
-
-    fabric = RunFabric()
-    result = fabric.run(1, 2, three=3)
+    def run(fabric_obj, *args, **kwargs):
+        fabric_obj.run_args = args
+        fabric_obj.run_kwargs = kwargs
+        return "result"
+
+    fabric = Fabric()
+    result = fabric.launch(run, 1, 2, three=3)
     assert result == "result"
     assert fabric.run_args == (1, 2)
     assert fabric.run_kwargs == {"three": 3}
@@ -322,12 +318,12 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Fabric intercepts the DataLoader constructor arguments with a context manager in its run
     method."""

-    class RunFabric(Fabric):
-        def run(self):
-            # One for BatchSampler, another for DataLoader
-            assert ctx_manager().__enter__.call_count == 2
+    def run(_):
+        # One for BatchSampler, another for DataLoader
+        assert ctx_manager().__enter__.call_count == 2

-    RunFabric().run()
+    fabric = Fabric()
+    fabric.launch(run)
     assert ctx_manager().__exit__.call_count == 2
@@ -538,27 +534,26 @@ def test_to_device(accelerator, expected):
     if not pjrt.using_pjrt():
         expected = "xla:1"

-    class RunFabric(Fabric):
-        def run(self):
-            expected_device = torch.device(expected)
+    def run(_):
+        expected_device = torch.device(expected)

         # module
         module = torch.nn.Linear(2, 3)
         module = fabric.to_device(module)
         assert all(param.device == expected_device for param in module.parameters())

         # tensor
         tensor = torch.rand(2, 2)
         tensor = fabric.to_device(tensor)
         assert tensor.device == expected_device

         # collection
         collection = {"data": torch.rand(2, 2), "int": 1}
         collection = fabric.to_device(collection)
         assert collection["data"].device == expected_device

-    fabric = RunFabric(accelerator=accelerator, devices=1)
-    fabric.run()
+    fabric = Fabric(accelerator=accelerator, devices=1)
+    fabric.launch(run)


 def test_rank_properties():