updated args
This commit is contained in:
parent
8df13035eb
commit
cf57be9dca
|
@ -63,6 +63,7 @@ class ExampleModel(RootModule):
|
|||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
pdb.set_trace()
|
||||
# forward pass
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
|
@ -80,6 +81,7 @@ class ExampleModel(RootModule):
|
|||
:param data_batch:
|
||||
:return:
|
||||
"""
|
||||
pdb.set_trace()
|
||||
x, y = data_batch
|
||||
x = x.view(x.size(0), -1)
|
||||
y_hat = self.forward(x)
|
||||
|
|
|
@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
|
|||
import traceback
|
||||
from pytorch_lightning.root_module.model_saving import TrainerIO
|
||||
from torch.optim.lr_scheduler import MultiStepLR
|
||||
from torch.nn import DataParallel
|
||||
from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
|
||||
import pdb
|
||||
|
||||
try:
|
||||
|
@ -250,7 +250,7 @@ class Trainer(TrainerIO):
|
|||
|
||||
# put on gpu if needed
|
||||
if self.on_gpu:
|
||||
model = DataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
|
||||
|
||||
# run tiny validation to make sure program won't crash during val
|
||||
_ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
|
||||
|
|
Loading…
Reference in New Issue