Make `configure_sharded_model` implementation in test models idempotent (#17625)

Adrian Wälchli authored on 2023-05-17 23:51:27 +02:00, committed by GitHub
parent a37f5a546c
commit ccdd563bd5
1 changed file with 10 additions and 4 deletions


@@ -49,7 +49,8 @@ class ModelParallelBoringModel(BoringModel):
         self.layer = None
 
     def configure_sharded_model(self) -> None:
-        self.layer = torch.nn.Linear(32, 2)
+        if self.layer is None:
+            self.layer = torch.nn.Linear(32, 2)
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         self.configure_sharded_model()
@@ -73,7 +74,8 @@ class ModelParallelBoringModelManualOptim(BoringModel):
         opt.step()
 
     def configure_sharded_model(self) -> None:
-        self.layer = torch.nn.Linear(32, 2)
+        if self.layer is None:
+            self.layer = torch.nn.Linear(32, 2)
 
     def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
         self.configure_sharded_model()
@@ -571,11 +573,14 @@ class ModelParallelClassificationModel(LightningModule):
         self.valid_acc = metric.clone()
         self.test_acc = metric.clone()
 
+        self.model = None
+
     def make_block(self):
         return nn.Sequential(nn.Linear(32, 32, bias=False), nn.ReLU())
 
     def configure_sharded_model(self) -> None:
-        self.model = nn.Sequential(*(self.make_block() for x in range(self.num_blocks)), nn.Linear(32, 3))
+        if self.model is None:
+            self.model = nn.Sequential(*(self.make_block() for x in range(self.num_blocks)), nn.Linear(32, 3))
 
     def forward(self, x):
         x = self.model(x)
@@ -892,7 +897,8 @@ def test_deepspeed_multigpu_partial_partition_parameters(tmpdir):
             self.layer_2 = torch.nn.Linear(32, 32)
 
         def configure_sharded_model(self) -> None:
-            self.layer = torch.nn.Linear(32, 2)
+            if self.layer is None:
+                self.layer = torch.nn.Linear(32, 2)
 
         def forward(self, x):
             x = self.layer_2(x)
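
All four hunks apply the same idempotence guard: `configure_sharded_model` can run more than once on the same module instance (for example, `on_load_checkpoint` calls it again in the hunks above), so each test model now builds its layers only when they do not exist yet instead of unconditionally re-creating them. As a minimal sketch of that pattern in a standalone module, under the assumption of a Lightning 2.0-style import path; the class name, layer sizes, and comments are illustrative and not part of this commit:

import torch
from lightning.pytorch import LightningModule


class DemoShardedModel(LightningModule):
    # Illustrative module: the layer is created lazily so the strategy
    # (e.g. DeepSpeed) can shard it at creation time.

    def __init__(self):
        super().__init__()
        self.layer = None  # placeholder until configure_sharded_model runs

    def configure_sharded_model(self) -> None:
        # Idempotent: a repeated call (such as one triggered from
        # on_load_checkpoint) must not replace an already-created,
        # possibly already-sharded layer.
        if self.layer is None:
            self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)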