diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index d9d7e9d848..6f5e888bad 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -49,7 +49,8 @@ class ModelParallelBoringModel(BoringModel): self.layer = None def configure_sharded_model(self) -> None: - self.layer = torch.nn.Linear(32, 2) + if self.layer is None: + self.layer = torch.nn.Linear(32, 2) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_sharded_model() @@ -73,7 +74,8 @@ class ModelParallelBoringModelManualOptim(BoringModel): opt.step() def configure_sharded_model(self) -> None: - self.layer = torch.nn.Linear(32, 2) + if self.layer is None: + self.layer = torch.nn.Linear(32, 2) def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: self.configure_sharded_model() @@ -571,11 +573,14 @@ class ModelParallelClassificationModel(LightningModule): self.valid_acc = metric.clone() self.test_acc = metric.clone() + self.model = None + def make_block(self): return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU()) def configure_sharded_model(self) -> None: - self.model = nn.Sequential(*(self.make_block() for x in range(self.num_blocks)), nn.Linear(32, 3)) + if self.model is None: + self.model = nn.Sequential(*(self.make_block() for x in range(self.num_blocks)), nn.Linear(32, 3)) def forward(self, x): x = self.model(x) @@ -892,7 +897,8 @@ def test_deepspeed_multigpu_partial_partition_parameters(tmpdir): self.layer_2 = torch.nn.Linear(32, 32) def configure_sharded_model(self) -> None: - self.layer = torch.nn.Linear(32, 2) + if self.layer is None: + self.layer = torch.nn.Linear(32, 2) def forward(self, x): x = self.layer_2(x)