diff --git a/docs/source/examples/example_model.py b/docs/source/examples/example_model.py
index f608a02dbe..9f48e6f363 100644
--- a/docs/source/examples/example_model.py
+++ b/docs/source/examples/example_model.py
@@ -63,6 +63,7 @@ class ExampleModel(RootModule):
         :param data_batch:
         :return:
         """
+        pdb.set_trace()
         # forward pass
         x, y = data_batch
         x = x.view(x.size(0), -1)
@@ -80,6 +81,7 @@ class ExampleModel(RootModule):
         :param data_batch:
         :return:
         """
+        pdb.set_trace()
         x, y = data_batch
         x = x.view(x.size(0), -1)
         y_hat = self.forward(x)
diff --git a/pytorch_lightning/models/trainer.py b/pytorch_lightning/models/trainer.py
index 1ec1d47fd9..a9c7462b59 100644
--- a/pytorch_lightning/models/trainer.py
+++ b/pytorch_lightning/models/trainer.py
@@ -5,7 +5,7 @@ from pytorch_lightning.root_module.memory import get_gpu_memory_map
 import traceback
 from pytorch_lightning.root_module.model_saving import TrainerIO
 from torch.optim.lr_scheduler import MultiStepLR
-from torch.nn import DataParallel
+from pytorch_lightning.pt_overrides.override_data_parallel import LightningDataParallel
 import pdb
 
 try:
@@ -250,7 +250,7 @@ class Trainer(TrainerIO):
 
         # put on gpu if needed
         if self.on_gpu:
-            model = DataParallel(model, device_ids=self.data_parallel_device_ids)
+            model = LightningDataParallel(model, device_ids=self.data_parallel_device_ids)
 
         # run tiny validation to make sure program won't crash during val
         _ = self.validate(model, self.val_dataloader, max_batches=self.nb_sanity_val_steps)
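
The trainer change above swaps `torch.nn.DataParallel` for `LightningDataParallel` from `pytorch_lightning.pt_overrides.override_data_parallel`, evidently so the data-parallel wrapper can run the LightningModule's step methods on each scattered replica instead of a bare `forward()` call. Below is a minimal sketch of that override pattern, not the actual pytorch_lightning implementation; the class name `OverrideDataParallel` and the decision to always dispatch to `training_step` are assumptions made for illustration.

```python
from torch.nn import DataParallel


class OverrideDataParallel(DataParallel):
    """Sketch: run the wrapped module's training_step on each replica instead of forward()."""

    def forward(self, *inputs, **kwargs):
        # scatter the batch and kwargs across the configured devices
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)

        if len(self.device_ids) == 1:
            # single-device case: no replication needed
            return self.module.training_step(*inputs[0], **kwargs[0])

        # copy the module onto every device and call training_step on each copy
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(
            [replica.training_step for replica in replicas], inputs, kwargs
        )

        # merge the per-device outputs back onto the output device
        return self.gather(outputs, self.output_device)
```

Usage mirrors the trainer diff: `model = OverrideDataParallel(model, device_ids=...)`, after which calling `model(data_batch)` executes the step function on the replicas while scatter/replicate/gather behave exactly as in stock `DataParallel`.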