Removes need to unsqueeze from dp (#1319)

* removes need to unsqueeze from dp

* removes need to unsqueeze from dp

* fixed examples

* added auto unsqueeze

* added auto unsqueeze

* added auto unsqueeze

* added auto unsqueeze

* Update pytorch_lightning/overrides/data_parallel.py

Co-Authored-By: Adrian Wälchli <adrian.waelchli@students.unibe.ch>

* fixed dp parse

* fixed dp parse

Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
This commit is contained in:
William Falcon 2020-04-02 11:46:20 -04:00 committed by GitHub
parent 6b41b5c589
commit 3cb149f4f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 24 additions and 37 deletions

View File

@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Changed
- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
- Made `evalaute` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))

View File

@ -111,10 +111,6 @@ class LightningTemplateModel(LightningModule):
# calculate loss
loss_val = self.loss(y, y_hat)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
@ -145,11 +141,6 @@ class LightningTemplateModel(LightningModule):
if self.on_gpu:
val_acc = val_acc.cuda(loss_val.device.index)
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
val_acc = val_acc.unsqueeze(0)
output = OrderedDict({
'val_loss': loss_val,
'val_acc': val_acc,

View File

@ -99,10 +99,7 @@ class GAN(LightningModule):
if optimizer_idx == 0:
# sample noise
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
# match gpu device (or keep as cpu)
if self.on_gpu:
z = z.cuda(imgs.device.index)
z = z.type_as(imgs)
# generate images
self.generated_imgs = self(z)
@ -115,8 +112,7 @@ class GAN(LightningModule):
# ground truth result (ie: all fake)
# put on GPU because we created this tensor inside training_loop
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)
# adversarial loss is binary cross-entropy
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
@ -134,15 +130,13 @@ class GAN(LightningModule):
# how well can it label as real?
valid = torch.ones(imgs.size(0), 1)
if self.on_gpu:
valid = valid.cuda(imgs.device.index)
valid = valid.type_as(imgs)
real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
# how well can it label as fake?
fake = torch.zeros(imgs.size(0), 1)
if self.on_gpu:
fake = fake.cuda(imgs.device.index)
fake = fake.type_as(fake)
fake_loss = self.adversarial_loss(
self.discriminator(self.generated_imgs.detach()), fake)
@ -174,9 +168,7 @@ class GAN(LightningModule):
def on_epoch_end(self):
z = torch.randn(8, self.hparams.latent_dim)
# match gpu device (or keep as cpu)
if self.on_gpu:
z = z.cuda(self.last_imgs.device.index)
z = z.type_as(self.last_imgs)
# log sampled images
sample_imgs = self(z)

View File

@ -277,9 +277,6 @@ class DQNLightning(pl.LightningModule):
# calculates training loss
loss = self.dqn_mse_loss(batch)
if self.trainer.use_dp or self.trainer.use_ddp2:
loss = loss.unsqueeze(0)
if done:
self.total_reward = self.episode_reward
self.episode_reward = 0

View File

@ -46,12 +46,6 @@ class ImageNetLightningModel(LightningModule):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
acc1 = acc1.unsqueeze(0)
acc5 = acc5.unsqueeze(0)
tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
@ -69,12 +63,6 @@ class ImageNetLightningModel(LightningModule):
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
if self.trainer.use_dp or self.trainer.use_ddp2:
loss_val = loss_val.unsqueeze(0)
acc1 = acc1.unsqueeze(0)
acc5 = acc5.unsqueeze(0)
output = OrderedDict({
'val_loss': loss_val,
'val_acc1': acc1,

View File

@ -163,6 +163,9 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
else:
output = module.validation_step(*input, **kwargs)
if module.use_dp or module.use_ddp2:
auto_squeeze_dim_zeros(output)
# ---------------
with lock:
@ -199,3 +202,18 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
raise output
outputs.append(output)
return outputs
def auto_squeeze_dim_zeros(output):
"""
In DP or DDP2 we need to unsqueeze dim 0
:param output:
:return:
"""
for k, v in output.items():
if not isinstance(v, torch.Tensor):
continue
is_scalar = v.dim() == 0
if is_scalar:
output[k] = output[k].unsqueeze(0)