From ca9e0066817d81780ad0d63c25517a2859312e73 Mon Sep 17 00:00:00 2001
From: Bas Krahmer <baskrahmer@gmail.com>
Date: Fri, 19 May 2023 19:42:49 +0200
Subject: [PATCH] refactor Fabric tests to use launch method (#17648)

Co-authored-by: bas <bas.krahmer@talentflyxpert.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 .../strategies/test_deepspeed_integration.py  | 200 +++++++++---------
 tests/tests_fabric/test_fabric.py             |  59 +++---
 2 files changed, 125 insertions(+), 134 deletions(-)

diff --git a/tests/tests_fabric/strategies/test_deepspeed_integration.py b/tests/tests_fabric/strategies/test_deepspeed_integration.py
index 21eb661922..6cca479b65 100644
--- a/tests/tests_fabric/strategies/test_deepspeed_integration.py
+++ b/tests/tests_fabric/strategies/test_deepspeed_integration.py
@@ -31,78 +31,78 @@ from tests_fabric.test_fabric import BoringModel
 
 @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True)
 def test_deepspeed_multiple_models():
-    class RunFabric(Fabric):
-        def run(self):
-            model = BoringModel()
-            optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
-            model, optimizer = self.setup(model, optimizer)
+    def run(fabric_obj):
+        model = BoringModel()
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)
+        model, optimizer = fabric_obj.setup(model, optimizer)
 
-            for i in range(2):
-                optimizer.zero_grad()
-                x = model(torch.randn(1, 32).to(self.device))
-                loss = x.sum()
-                if i == 0:
-                    # the weights are not initialized with stage 3 until backward is run once
-                    assert all(w.nelement() == 0 for w in model.state_dict().values())
-                self.backward(loss, model=model)
-                if i == 0:
-                    # save for later to check that the weights were updated
-                    state_dict = deepcopy(model.state_dict())
-                optimizer.step()
+        for i in range(2):
+            optimizer.zero_grad()
+            x = model(torch.randn(1, 32).to(fabric_obj.device))
+            loss = x.sum()
+            if i == 0:
+                # the weights are not initialized with stage 3 until backward is run once
+                assert all(w.nelement() == 0 for w in model.state_dict().values())
+            fabric_obj.backward(loss, model=model)
+            if i == 0:
+                # save for later to check that the weights were updated
+                state_dict = deepcopy(model.state_dict())
+            optimizer.step()
 
-            # check that the model trained, the weights from step 1 do not match the weights from step 2
-            for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
-                assert not torch.allclose(mw_b, mw_a)
+        # check that the model trained, the weights from step 1 do not match the weights from step 2
+        for mw_b, mw_a in zip(state_dict.values(), model.state_dict().values()):
+            assert not torch.allclose(mw_b, mw_a)
 
-            self.seed_everything(42)
-            model_1 = BoringModel()
-            optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001)
+        fabric_obj.seed_everything(42)
+        model_1 = BoringModel()
+        optimizer_1 = torch.optim.SGD(model_1.parameters(), lr=0.0001)
 
-            self.seed_everything(42)
-            model_2 = BoringModel()
-            optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001)
+        fabric_obj.seed_everything(42)
+        model_2 = BoringModel()
+        optimizer_2 = torch.optim.SGD(model_2.parameters(), lr=0.0001)
 
-            for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
-                assert torch.allclose(mw_1, mw_2)
+        for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
+            assert torch.allclose(mw_1, mw_2)
 
-            model_1, optimizer_1 = self.setup(model_1, optimizer_1)
-            model_2, optimizer_2 = self.setup(model_2, optimizer_2)
+        model_1, optimizer_1 = fabric_obj.setup(model_1, optimizer_1)
+        model_2, optimizer_2 = fabric_obj.setup(model_2, optimizer_2)
 
-            # train model_1 first
-            self.seed_everything(42)
-            data_list = []
-            for _ in range(2):
-                optimizer_1.zero_grad()
-                data = torch.randn(1, 32).to(self.device)
-                data_list.append(data)
-                x = model_1(data)
-                loss = x.sum()
-                self.backward(loss, model=model_1)
-                optimizer_1.step()
+        # train model_1 first
+        fabric_obj.seed_everything(42)
+        data_list = []
+        for _ in range(2):
+            optimizer_1.zero_grad()
+            data = torch.randn(1, 32).to(fabric_obj.device)
+            data_list.append(data)
+            x = model_1(data)
+            loss = x.sum()
+            fabric_obj.backward(loss, model=model_1)
+            optimizer_1.step()
 
-            # the weights do not match
-            assert all(w.nelement() > 1 for w in model_1.state_dict().values())
-            assert all(w.nelement() == 0 for w in model_2.state_dict().values())
+        # the weights do not match
+        assert all(w.nelement() > 1 for w in model_1.state_dict().values())
+        assert all(w.nelement() == 0 for w in model_2.state_dict().values())
 
-            # now train model_2 with the same data
-            for data in data_list:
-                optimizer_2.zero_grad()
-                x = model_2(data)
-                loss = x.sum()
-                self.backward(loss, model=model_2)
-                optimizer_2.step()
+        # now train model_2 with the same data
+        for data in data_list:
+            optimizer_2.zero_grad()
+            x = model_2(data)
+            loss = x.sum()
+            fabric_obj.backward(loss, model=model_2)
+            optimizer_2.step()
 
-            # the weights should match
-            for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
-                assert torch.allclose(mw_1, mw_2)
+        # the weights should match
+        for mw_1, mw_2 in zip(model_1.state_dict().values(), model_2.state_dict().values()):
+            assert torch.allclose(mw_1, mw_2)
 
-            # Verify collectives works as expected
-            ranks = self.all_gather(torch.tensor([self.local_rank]).to(self.device))
-            assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]]))
-            assert self.broadcast(True)
-            assert self.is_global_zero == (self.local_rank == 0)
+        # Verify collectives work as expected
+        ranks = fabric_obj.all_gather(torch.tensor([fabric_obj.local_rank]).to(fabric_obj.device))
+        assert torch.allclose(ranks.cpu(), torch.tensor([[0], [1]]))
+        assert fabric_obj.broadcast(True)
+        assert fabric_obj.is_global_zero == (fabric_obj.local_rank == 0)
 
-    RunFabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
+    fabric = Fabric(strategy=DeepSpeedStrategy(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu")
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -118,19 +118,18 @@ def test_deepspeed_multiple_models():
 def test_deepspeed_auto_batch_size_config_select(dataset_cls, logging_batch_size_per_gpu, expected_batch_size):
     """Test to ensure that the batch size is correctly set as expected for deepspeed logging purposes."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            assert isinstance(self._strategy, DeepSpeedStrategy)
-            _ = self.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
-            config = self._strategy.config
-            assert config["train_micro_batch_size_per_gpu"] == expected_batch_size
+    def run(fabric_obj):
+        assert isinstance(fabric_obj._strategy, DeepSpeedStrategy)
+        _ = fabric_obj.setup_dataloaders(DataLoader(dataset_cls(32, 64)))
+        config = fabric_obj._strategy.config
+        assert config["train_micro_batch_size_per_gpu"] == expected_batch_size
 
-    fabric = RunFabric(
+    fabric = Fabric(
         accelerator="cuda",
         devices=1,
         strategy=DeepSpeedStrategy(logging_batch_size_per_gpu=logging_batch_size_per_gpu, zero_optimization=False),
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -138,21 +137,20 @@ def test_deepspeed_configure_optimizers():
     """Test that the deepspeed strategy with default initialization wraps the optimizer correctly."""
     from deepspeed.runtime.zero.stage_1_and_2 import DeepSpeedZeroOptimizer
 
-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
-            model, optimizer = self.setup(model, optimizer)
-            assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
-            assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+        model, optimizer = fabric_obj.setup(model, optimizer)
+        assert isinstance(optimizer.optimizer, DeepSpeedZeroOptimizer)
+        assert isinstance(optimizer.optimizer.optimizer, torch.optim.SGD)
 
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=DeepSpeedStrategy(),
         accelerator="cuda",
         devices=1,
         precision="16-mixed",
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, deepspeed=True)
@@ -160,25 +158,24 @@ def test_deepspeed_custom_precision_params():
     """Test that if the FP16 parameters are set via the DeepSpeedStrategy, the deepspeed config contains these
     changes."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            assert self._strategy._config_initialized
-            assert self._strategy.config["fp16"]["loss_scale"] == 10
-            assert self._strategy.config["fp16"]["initial_scale_power"] == 11
-            assert self._strategy.config["fp16"]["loss_scale_window"] == 12
-            assert self._strategy.config["fp16"]["hysteresis"] == 13
-            assert self._strategy.config["fp16"]["min_loss_scale"] == 14
+    def run(fabric_obj):
+        assert fabric_obj._strategy._config_initialized
+        assert fabric_obj._strategy.config["fp16"]["loss_scale"] == 10
+        assert fabric_obj._strategy.config["fp16"]["initial_scale_power"] == 11
+        assert fabric_obj._strategy.config["fp16"]["loss_scale_window"] == 12
+        assert fabric_obj._strategy.config["fp16"]["hysteresis"] == 13
+        assert fabric_obj._strategy.config["fp16"]["min_loss_scale"] == 14
 
     strategy = DeepSpeedStrategy(
         loss_scale=10, initial_scale_power=11, loss_scale_window=12, hysteresis=13, min_loss_scale=14
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 @RunIf(min_cuda_gpus=1, standalone=True, deepspeed=True)
@@ -187,21 +184,20 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
     correctly."""
     import deepspeed
 
-    class RunFabric(Fabric):
-        def run(self):
-            model = nn.Linear(3, 3)
-            optimizer = torch.optim.Adam(model.parameters())
+    def run(fabric_obj):
+        model = nn.Linear(3, 3)
+        optimizer = torch.optim.Adam(model.parameters())
 
-            with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
-                self.setup(model, optimizer)
+        with mock.patch("deepspeed.checkpointing.configure", wraps=deepspeed.checkpointing.configure) as configure:
+            fabric_obj.setup(model, optimizer)
 
-            configure.assert_called_with(
-                mpu_=None,
-                partition_activations=True,
-                contiguous_checkpointing=True,
-                checkpoint_in_cpu=True,
-                profile=None,
-            )
+        configure.assert_called_with(
+            mpu_=None,
+            partition_activations=True,
+            contiguous_checkpointing=True,
+            checkpoint_in_cpu=True,
+            profile=None,
+        )
 
     strategy = DeepSpeedStrategy(
         partition_activations=True,
@@ -209,13 +205,13 @@ def test_deepspeed_custom_activation_checkpointing_params_forwarded():
         contiguous_memory_optimization=True,
         synchronize_checkpoint_boundary=True,
     )
-    fabric = RunFabric(
+    fabric = Fabric(
         strategy=strategy,
         precision="16-mixed",
         accelerator="cuda",
         devices=1,
     )
-    fabric.run()
+    fabric.launch(run)
 
 
 class ModelParallelClassification(BoringFabric):
diff --git a/tests/tests_fabric/test_fabric.py b/tests/tests_fabric/test_fabric.py
index 7e2ceeefe7..71ab649d5e 100644
--- a/tests/tests_fabric/test_fabric.py
+++ b/tests/tests_fabric/test_fabric.py
@@ -55,17 +55,13 @@ class BoringModel(nn.Module):
 def test_run_input_output():
     """Test that the dynamically patched run() method receives the input arguments and returns the result."""
 
-    class RunFabric(Fabric):
-        run_args = ()
-        run_kwargs = {}
+    def run(fabric_obj, *args, **kwargs):
+        fabric_obj.run_args = args
+        fabric_obj.run_kwargs = kwargs
+        return "result"
 
-        def run(self, *args, **kwargs):
-            self.run_args = args
-            self.run_kwargs = kwargs
-            return "result"
-
-    fabric = RunFabric()
-    result = fabric.run(1, 2, three=3)
+    fabric = Fabric()
+    result = fabric.launch(run, 1, 2, three=3)
     assert result == "result"
     assert fabric.run_args == (1, 2)
     assert fabric.run_kwargs == {"three": 3}
@@ -322,12 +318,12 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Fabric intercepts the DataLoader constructor arguments with a context manager in its run
     method."""
 
-    class RunFabric(Fabric):
-        def run(self):
-            # One for BatchSampler, another for DataLoader
-            assert ctx_manager().__enter__.call_count == 2
+    def run(_):
+        # One for BatchSampler, another for DataLoader
+        assert ctx_manager().__enter__.call_count == 2
 
-    RunFabric().run()
+    fabric = Fabric()
+    fabric.launch(run)
     assert ctx_manager().__exit__.call_count == 2
 
 
@@ -538,27 +534,26 @@ def test_to_device(accelerator, expected):
         if not pjrt.using_pjrt():
             expected = "xla:1"
 
-    class RunFabric(Fabric):
-        def run(self):
-            expected_device = torch.device(expected)
+    def run(_):
+        expected_device = torch.device(expected)
 
-            # module
-            module = torch.nn.Linear(2, 3)
-            module = fabric.to_device(module)
-            assert all(param.device == expected_device for param in module.parameters())
+        # module
+        module = torch.nn.Linear(2, 3)
+        module = fabric.to_device(module)
+        assert all(param.device == expected_device for param in module.parameters())
 
-            # tensor
-            tensor = torch.rand(2, 2)
-            tensor = fabric.to_device(tensor)
-            assert tensor.device == expected_device
+        # tensor
+        tensor = torch.rand(2, 2)
+        tensor = fabric.to_device(tensor)
+        assert tensor.device == expected_device
 
-            # collection
-            collection = {"data": torch.rand(2, 2), "int": 1}
-            collection = fabric.to_device(collection)
-            assert collection["data"].device == expected_device
+        # collection
+        collection = {"data": torch.rand(2, 2), "int": 1}
+        collection = fabric.to_device(collection)
+        assert collection["data"].device == expected_device
 
-    fabric = RunFabric(accelerator=accelerator, devices=1)
-    fabric.run()
+    fabric = Fabric(accelerator=accelerator, devices=1)
+    fabric.launch(run)
 
 
 def test_rank_properties():