diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e4436b393..4e189b3424 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319)) - Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318)) - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307)) - Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260)) diff --git a/pl_examples/basic_examples/lightning_module_template.py b/pl_examples/basic_examples/lightning_module_template.py index f436db2872..9c65bac9ac 100644 --- a/pl_examples/basic_examples/lightning_module_template.py +++ b/pl_examples/basic_examples/lightning_module_template.py @@ -111,10 +111,6 @@ class LightningTemplateModel(LightningModule): # calculate loss loss_val = self.loss(y, y_hat) - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp or self.trainer.use_ddp2: - loss_val = loss_val.unsqueeze(0) - tqdm_dict = {'train_loss': loss_val} output = OrderedDict({ 'loss': loss_val, @@ -145,11 +141,6 @@ class LightningTemplateModel(LightningModule): if self.on_gpu: val_acc = val_acc.cuda(loss_val.device.index) - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp or self.trainer.use_ddp2: - loss_val = loss_val.unsqueeze(0) - val_acc = val_acc.unsqueeze(0) - output = OrderedDict({ 'val_loss': loss_val, 'val_acc': val_acc, diff --git a/pl_examples/domain_templates/gan.py b/pl_examples/domain_templates/gan.py index 68e6053e7e..3661b3637e 100644 --- a/pl_examples/domain_templates/gan.py +++ b/pl_examples/domain_templates/gan.py @@ -99,10 +99,7 @@ class GAN(LightningModule): if optimizer_idx == 0: # sample noise z = torch.randn(imgs.shape[0], self.hparams.latent_dim) - - # match gpu device (or keep as cpu) - if self.on_gpu: - z = z.cuda(imgs.device.index) + z = z.type_as(imgs) # generate images self.generated_imgs = self(z) @@ -115,8 +112,7 @@ class GAN(LightningModule): # ground truth result (ie: all fake) # put on GPU because we created this tensor inside training_loop valid = torch.ones(imgs.size(0), 1) - if self.on_gpu: - valid = valid.cuda(imgs.device.index) + valid = valid.type_as(imgs) # adversarial loss is binary cross-entropy g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid) @@ -134,15 +130,13 @@ class GAN(LightningModule): # how well can it label as real? valid = torch.ones(imgs.size(0), 1) - if self.on_gpu: - valid = valid.cuda(imgs.device.index) + valid = valid.type_as(imgs) real_loss = self.adversarial_loss(self.discriminator(imgs), valid) # how well can it label as fake? fake = torch.zeros(imgs.size(0), 1) - if self.on_gpu: - fake = fake.cuda(imgs.device.index) + fake = fake.type_as(fake) fake_loss = self.adversarial_loss( self.discriminator(self.generated_imgs.detach()), fake) @@ -174,9 +168,7 @@ class GAN(LightningModule): def on_epoch_end(self): z = torch.randn(8, self.hparams.latent_dim) - # match gpu device (or keep as cpu) - if self.on_gpu: - z = z.cuda(self.last_imgs.device.index) + z = z.type_as(self.last_imgs) # log sampled images sample_imgs = self(z) diff --git a/pl_examples/domain_templates/reinforse_learn_Qnet.py b/pl_examples/domain_templates/reinforse_learn_Qnet.py index 4585c108d5..5a797d7d89 100644 --- a/pl_examples/domain_templates/reinforse_learn_Qnet.py +++ b/pl_examples/domain_templates/reinforse_learn_Qnet.py @@ -277,9 +277,6 @@ class DQNLightning(pl.LightningModule): # calculates training loss loss = self.dqn_mse_loss(batch) - if self.trainer.use_dp or self.trainer.use_ddp2: - loss = loss.unsqueeze(0) - if done: self.total_reward = self.episode_reward self.episode_reward = 0 diff --git a/pl_examples/full_examples/imagenet/imagenet_example.py b/pl_examples/full_examples/imagenet/imagenet_example.py index ad8f90f5a1..52c5cf0642 100644 --- a/pl_examples/full_examples/imagenet/imagenet_example.py +++ b/pl_examples/full_examples/imagenet/imagenet_example.py @@ -46,12 +46,6 @@ class ImageNetLightningModel(LightningModule): loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp or self.trainer.use_ddp2: - loss_val = loss_val.unsqueeze(0) - acc1 = acc1.unsqueeze(0) - acc5 = acc5.unsqueeze(0) - tqdm_dict = {'train_loss': loss_val} output = OrderedDict({ 'loss': loss_val, @@ -69,12 +63,6 @@ class ImageNetLightningModel(LightningModule): loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - # in DP mode (default) make sure if result is scalar, there's another dim in the beginning - if self.trainer.use_dp or self.trainer.use_ddp2: - loss_val = loss_val.unsqueeze(0) - acc1 = acc1.unsqueeze(0) - acc5 = acc5.unsqueeze(0) - output = OrderedDict({ 'val_loss': loss_val, 'val_acc1': acc1, diff --git a/pytorch_lightning/overrides/data_parallel.py b/pytorch_lightning/overrides/data_parallel.py index 6b922f65a5..168cdf7e17 100644 --- a/pytorch_lightning/overrides/data_parallel.py +++ b/pytorch_lightning/overrides/data_parallel.py @@ -163,6 +163,9 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n else: output = module.validation_step(*input, **kwargs) + + if module.use_dp or module.use_ddp2: + auto_squeeze_dim_zeros(output) # --------------- with lock: @@ -199,3 +202,18 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n raise output outputs.append(output) return outputs + + +def auto_squeeze_dim_zeros(output): + """ + In DP or DDP2 we need to unsqueeze dim 0 + :param output: + :return: + """ + for k, v in output.items(): + if not isinstance(v, torch.Tensor): + continue + + is_scalar = v.dim() == 0 + if is_scalar: + output[k] = output[k].unsqueeze(0)