parent
36f0b5bbd0
commit
73a7cf3c99
|
@ -35,7 +35,7 @@ class ModelSummary(object):
|
|||
out_sizes = []
|
||||
input_ = self.model.example_input_array
|
||||
|
||||
if self.model.use_ddp:
|
||||
if self.model.use_ddp or self.model.use_dp:
|
||||
input_ = input_.cuda(0)
|
||||
|
||||
if self.model.trainer.use_amp:
|
||||
|
|
|
@ -128,8 +128,8 @@ def test_dp_resume():
|
|||
dp_model = new_trainer.model
|
||||
dp_model.eval()
|
||||
|
||||
for dataloader in trainer.get_train_dataloader():
|
||||
run_prediction(dataloader, dp_model, dp=True)
|
||||
dataloader = trainer.get_train_dataloader()
|
||||
run_prediction(dataloader, dp_model, dp=True)
|
||||
|
||||
# new model
|
||||
model = LightningTestModel(hparams)
|
||||
|
|
Loading…
Reference in New Issue