Mem crash (#299)

* fixes memory crash

* fixes memory crash
This commit is contained in:
William Falcon 2019-10-04 15:53:44 -04:00 committed by GitHub
parent 36f0b5bbd0
commit 73a7cf3c99
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 3 deletions

View File

@ -35,7 +35,7 @@ class ModelSummary(object):
out_sizes = []
input_ = self.model.example_input_array
if self.model.use_ddp:
if self.model.use_ddp or self.model.use_dp:
input_ = input_.cuda(0)
if self.model.trainer.use_amp:

View File

@ -128,8 +128,8 @@ def test_dp_resume():
dp_model = new_trainer.model
dp_model.eval()
for dataloader in trainer.get_train_dataloader():
run_prediction(dataloader, dp_model, dp=True)
dataloader = trainer.get_train_dataloader()
run_prediction(dataloader, dp_model, dp=True)
# new model
model = LightningTestModel(hparams)