diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py index 21eb661922..6cca479b65 100644 --- a/tests/tests_fabric/strategies/test_deepspeed_integration.py +++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py @@ -31,78 +31,78 @@ from tests_fabric.test_fabric import BoringModel @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multiple_models(): - class RunFabric(Fabric): - def run(self): - model = BoringModel() - optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) - model, optimizer = self.setup(model, optimizer) + def run(fabric_obj): + model = BoringModel() + optimizer = torch.optim.SGD(model.parameters(), lr=0.0001) + model, optimizer = fabric_obj.setup(model, optimizer) - for i in range(2): - optimizer.zero_grad() - x = model(torch.randn(1, 32).to(self.device)) - loss = x.sum() - if i == 0: - # the weights are not initialized with stage 3 until backward is run once - assert all(w.nelement() == 0 for w in model.state_dict().values()) - self.backward(loss, model=model) - if i == 0: - # save for later to check that the weights were updated - state_dict = deepcopy(model.state_dict()) - optimizer.step() + for i in range(2): + optimizer.zero_grad() + x = model(torch.randn(1, 32).to(fabric_obj.device)) + loss = x.sum() + if i == 0: + # the weights are not initialized with stage 3 until backward is run once + assert all(w.nelement() == 0 for w in model.state_dict().values()) + fabric_obj.backward(loss, model=model) + if i == 0: + # save for later to check that the weights were updated + state_dict = deepcopy(model.state_dict()) + optimizer.step() - # check that the model trained, the weights from step 1 do not match the weights from step 2 - for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): - assert not torch.allclose(mw_b, mw_a) + # check that the model trained, the weights from step 1 do not match the weights from step 2 + for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()): + assert not torch.allclose(mw_b, mw_a) - self.seed_everything(42) - model_1 = BoringModel() - optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001) + fabric_obj.seed_everything(42) + model_1 = BoringModel() + optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001) - self.seed_everything(42) - model_2 = BoringModel() - optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) + fabric_obj.seed_everything(42) + model_2 = BoringModel() + optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001) - for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): - assert torch.allclose(mw_1, mw_2) + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert torch.allclose(mw_1, mw_2) - model_1, optimizer_1 = self.setup(model_1, optimizer_1) - model_2, optimizer_2 = self.setup(model_2, optimizer_2) + model_1, optimizer_1 = fabric_obj.setup(model_1, optimizer_1) + model_2, optimizer_2 = fabric_obj.setup(model_2, optimizer_2) - # train model_1 first - self.seed_everything(42) - data_list = [] - for _ in range(2): - optimizer_1.zero_grad() - data = torch.randn(1, 32).to(self.device) - data_list.append(data) - x = model_1(data) - loss = x.sum() - self.backward(loss, model=model_1) - optimizer_1.step() + # train model_1 first + fabric_obj.seed_everything(42) + data_list = [] + for _ in range(2): + optimizer_1.zero_grad() + data = torch.randn(1, 32).to(fabric_obj.device) + data_list.append(data) + x = model_1(data) + loss = x.sum() + fabric_obj.backward(loss, model=model_1) + optimizer_1.step() - # the weights do not match - assert all(w.nelement() > 1 for w in model_1.state_dict().values()) - assert all(w.nelement() == 0 for w in model_2.state_dict().values()) + # the weights do not match + assert all(w.nelement() > 1 for w in model_1.state_dict().values()) + assert all(w.nelement() == 0 for w in model_2.state_dict().values()) - # now train model_2 with the same data - for data in data_list: - optimizer_2.zero_grad() - x = model_2(data) - loss = x.sum() - self.backward(loss, model=model_2) - optimizer_2.step() + # now train model_2 with the same data + for data in data_list: + optimizer_2.zero_grad() + x = model_2(data) + loss = x.sum() + fabric_obj.backward(loss, model=model_2) + optimizer_2.step() - # the weights should match - for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): - assert torch.allclose(mw_1, mw_2) + # the weights should match + for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()): + assert torch.allclose(mw_1, mw_2) - # Verify collectives works as expected - ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device)) - assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]])) - assert self.broadcast(True) - assert self.is_global_zero == (self.local_rank == 0) + # Verify collectives works as expected + ranks = fabric_obj.all_gather(torch.tensor([fabric_obj.local_rank]).to(fabric_obj.device)) + assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]])) + assert fabric_obj.broadcast(True) + assert fabric_obj.is_global_zero == (fabric_obj.local_rank == 0) - RunFabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() + fabric = Fabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu") + fabric.launch(run) @RunIf(min_cuda_gpus=1, deepspeed=True) @@ -118,19 +118,18 @@ def test_deepspeed_multiple_models(): def test_deepspeed_auto_batch_size_config_select(dataset_cls, logging_batch_size_per_gpu, expected_batch_size): """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes.""" - class RunFabric(Fabric): - def run(self): - assert isinstance(self._strategy, DeepSpeedStrategy) - _ = self.setup_dataloaders(DataLoader(dataset_cls(32, 64))) - config = self._strategy.config - assert config["train_micro_batch_size_per_gpu"] == expected_batch_size + def run(fabric_obj): + assert isinstance(fabric_obj._strategy, DeepSpeedStrategy) + _ = fabric_obj.setup_dataloaders(DataLoader(dataset_cls(32, 64))) + config = fabric_obj._strategy.config + assert config["train_micro_batch_size_per_gpu"] == expected_batch_size - fabric = RunFabric( + fabric = Fabric( accelerator="cuda", devices=1, strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=logging_batch_size_per_gpu, zero_optimization=False), ) - fabric.run() + fabric.launch(run) @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) @@ -138,21 +137,20 @@ def test_deepspeed_configure_optimizers(): """Test that the deepspeed strategy with default initialization wraps the optimizer correctly.""" from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer - class RunFabric(Fabric): - def run(self): - model = nn.Linear(3, 3) - optimizer = torch.optim.SGD(model.parameters(), lr=0.01) - model, optimizer = self.setup(model, optimizer) - assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer) - assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD) + def run(fabric_obj): + model = nn.Linear(3, 3) + optimizer = torch.optim.SGD(model.parameters(), lr=0.01) + model, optimizer = fabric_obj.setup(model, optimizer) + assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer) + assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD) - fabric = RunFabric( + fabric = Fabric( strategy=DeepSpeedStrategy(), accelerator="cuda", devices=1, precision="16-mixed", ) - fabric.run() + fabric.launch(run) @RunIf(min_cuda_gpus=1, deepspeed=True) @@ -160,25 +158,24 @@ def test_deepspeed_custom_precision_params(): """Test that if the FP16 parameters are set via the DeepSpeedStrategy, the deepspeed config contains these changes.""" - class RunFabric(Fabric): - def run(self): - assert self._strategy._config_initialized - assert self._strategy.config["fp16"]["loss_scale"] == 10 - assert self._strategy.config["fp16"]["initial_scale_power"] == 11 - assert self._strategy.config["fp16"]["loss_scale_window"] == 12 - assert self._strategy.config["fp16"]["hysteresis"] == 13 - assert self._strategy.config["fp16"]["min_loss_scale"] == 14 + def run(fabric_obj): + assert fabric_obj._strategy._config_initialized + assert fabric_obj._strategy.config["fp16"]["loss_scale"] == 10 + assert fabric_obj._strategy.config["fp16"]["initial_scale_power"] == 11 + assert fabric_obj._strategy.config["fp16"]["loss_scale_window"] == 12 + assert fabric_obj._strategy.config["fp16"]["hysteresis"] == 13 + assert fabric_obj._strategy.config["fp16"]["min_loss_scale"] == 14 strategy = DeepSpeedStrategy( loss_scale=10, initial_scale_power=11, loss_scale_window=12, hysteresis=13, min_loss_scale=14 ) - fabric = RunFabric( + fabric = Fabric( strategy=strategy, precision="16-mixed", accelerator="cuda", devices=1, ) - fabric.run() + fabric.launch(run) @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True) @@ -187,21 +184,20 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded(): correctly.""" import deepspeed - class RunFabric(Fabric): - def run(self): - model = nn.Linear(3, 3) - optimizer = torch.optim.Adam(model.parameters()) + def run(fabric_obj): + model = nn.Linear(3, 3) + optimizer = torch.optim.Adam(model.parameters()) - with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure: - self.setup(model, optimizer) + with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure: + fabric_obj.setup(model, optimizer) - configure.assert_called_with( - mpu_=None, - partition_activations=True, - contiguous_checkpointing=True, - checkpoint_in_cpu=True, - profile=None, - ) + configure.assert_called_with( + mpu_=None, + partition_activations=True, + contiguous_checkpointing=True, + checkpoint_in_cpu=True, + profile=None, + ) strategy = DeepSpeedStrategy( partition_activations=True, @@ -209,13 +205,13 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded(): contiguous_memory_optimization=True, synchronize_checkpoint_boundary=True, ) - fabric = RunFabric( + fabric = Fabric( strategy=strategy, precision="16-mixed", accelerator="cuda", devices=1, ) - fabric.run() + fabric.launch(run) class ModelParallelClassification(BoringFabric): diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py index 7e2ceeefe7..71ab649d5e 100644 --- a/tests/tests_fabric/test_fabric.py +++ b/tests/tests_fabric/test_fabric.py @@ -55,17 +55,13 @@ class BoringModel(nn.Module): def test_run_input_output(): """Test that the dynamically patched run() method receives the input arguments and returns the result.""" - class RunFabric(Fabric): - run_args = () - run_kwargs = {} + def run(fabric_obj, *args, **kwargs): + fabric_obj.run_args = args + fabric_obj.run_kwargs = kwargs + return "result" - def run(self, *args, **kwargs): - self.run_args = args - self.run_kwargs = kwargs - return "result" - - fabric = RunFabric() - result = fabric.run(1, 2, three=3) + fabric = Fabric() + result = fabric.launch(run, 1, 2, three=3) assert result == "result" assert fabric.run_args == (1, 2) assert fabric.run_kwargs == {"three": 3} @@ -322,12 +318,12 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Fabric intercepts the DataLoader constructor arguments with a context manager in its run method.""" - class RunFabric(Fabric): - def run(self): - # One for BatchSampler, another for DataLoader - assert ctx_manager().__enter__.call_count == 2 + def run(_): + # One for BatchSampler, another for DataLoader + assert ctx_manager().__enter__.call_count == 2 - RunFabric().run() + fabric = Fabric() + fabric.launch(run) assert ctx_manager().__exit__.call_count == 2 @@ -538,27 +534,26 @@ def test_to_device(accelerator, expected): if not pjrt.using_pjrt(): expected = "xla:1" - class RunFabric(Fabric): - def run(self): - expected_device = torch.device(expected) + def run(_): + expected_device = torch.device(expected) - # module - module = torch.nn.Linear(2, 3) - module = fabric.to_device(module) - assert all(param.device == expected_device for param in module.parameters()) + # module + module = torch.nn.Linear(2, 3) + module = fabric.to_device(module) + assert all(param.device == expected_device for param in module.parameters()) - # tensor - tensor = torch.rand(2, 2) - tensor = fabric.to_device(tensor) - assert tensor.device == expected_device + # tensor + tensor = torch.rand(2, 2) + tensor = fabric.to_device(tensor) + assert tensor.device == expected_device - # collection - collection = {"data": torch.rand(2, 2), "int": 1} - collection = fabric.to_device(collection) - assert collection["data"].device == expected_device + # collection + collection = {"data": torch.rand(2, 2), "int": 1} + collection = fabric.to_device(collection) + assert collection["data"].device == expected_device - fabric = RunFabric(accelerator=accelerator, devices=1) - fabric.run() + fabric = Fabric(accelerator=accelerator, devices=1) + fabric.launch(run) def test_rank_properties():