updated args

This commit is contained in:
William Falcon 2019-06-25 19:43:25 -04:00
parent 8df13035eb
commit cf57be9dca
2 changed files with 4 additions and 2 deletions

View File

@@ -63,6 +63,7 @@ class ExampleModel(RootModule):
:param data_batch:
:return:
"""
pdb.set_trace()
# forward pass
x, y = data_batch
x = x.view(x.size(0), -1)
@@ -80,6 +81,7 @@ class ExampleModel(RootModule):
:param data_batch:
:return:
"""
pdb.set_trace()
x, y = data_batch
x = x.view(x.size(0), -1)
y_hat = self.forward(x)

View File

@@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
import traceback
from pytorch_lightning.root_module.model_saving import TrainerIO
from torch.optim.lr_scheduler import MultiStepLR
from torch.nn import DataParallel
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
import pdb
try:
@@ -250,7 +250,7 @@ class Trainer(TrainerIO):
# put on gpu if needed
if self.on_gpu:
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
# run tiny validation to make sure program won't crash during val
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)