Removes need to unsqueeze from dp (#1319)
* removes need to unsqueeze from dp
* removes need to unsqueeze from dp
* fixed examples
* added auto unsqueeze
* added auto unsqueeze
* added auto unsqueeze
* added auto unsqueeze
* Update pytorch_lightning/overrides/data_parallel.py

  Co-Authored-By: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
* fixed dp parse
* fixed dp parse

Co-authored-by: Adrian Wälchli <adrian.waelchli@students.unibe.ch>
This commit is contained in:
parent
6b41b5c589
commit
3cb149f4f4
|
@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Changed
|
||||
|
||||
- On DP and DDP2 unsqueeze is automated now ([#1319](https://github.com/PyTorchLightning/pytorch-lightning/pull/1319))
|
||||
- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
|
||||
- Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
|
||||
- Made `evaluate` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))
|
||||
|
|
|
@ -111,10 +111,6 @@ class LightningTemplateModel(LightningModule):
|
|||
# calculate loss
|
||||
loss_val = self.loss(y, y_hat)
|
||||
|
||||
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
|
||||
if self.trainer.use_dp or self.trainer.use_ddp2:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
|
||||
tqdm_dict = {'train_loss': loss_val}
|
||||
output = OrderedDict({
|
||||
'loss': loss_val,
|
||||
|
@ -145,11 +141,6 @@ class LightningTemplateModel(LightningModule):
|
|||
if self.on_gpu:
|
||||
val_acc = val_acc.cuda(loss_val.device.index)
|
||||
|
||||
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
|
||||
if self.trainer.use_dp or self.trainer.use_ddp2:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
val_acc = val_acc.unsqueeze(0)
|
||||
|
||||
output = OrderedDict({
|
||||
'val_loss': loss_val,
|
||||
'val_acc': val_acc,
|
||||
|
|
|
@ -99,10 +99,7 @@ class GAN(LightningModule):
|
|||
if optimizer_idx == 0:
|
||||
# sample noise
|
||||
z = torch.randn(imgs.shape[0], self.hparams.latent_dim)
|
||||
|
||||
# match gpu device (or keep as cpu)
|
||||
if self.on_gpu:
|
||||
z = z.cuda(imgs.device.index)
|
||||
z = z.type_as(imgs)
|
||||
|
||||
# generate images
|
||||
self.generated_imgs = self(z)
|
||||
|
@ -115,8 +112,7 @@ class GAN(LightningModule):
|
|||
# ground truth result (ie: all fake)
|
||||
# put on GPU because we created this tensor inside training_loop
|
||||
valid = torch.ones(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
valid = valid.cuda(imgs.device.index)
|
||||
valid = valid.type_as(imgs)
|
||||
|
||||
# adversarial loss is binary cross-entropy
|
||||
g_loss = self.adversarial_loss(self.discriminator(self.generated_imgs), valid)
|
||||
|
@ -134,15 +130,13 @@ class GAN(LightningModule):
|
|||
|
||||
# how well can it label as real?
|
||||
valid = torch.ones(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
valid = valid.cuda(imgs.device.index)
|
||||
valid = valid.type_as(imgs)
|
||||
|
||||
real_loss = self.adversarial_loss(self.discriminator(imgs), valid)
|
||||
|
||||
# how well can it label as fake?
|
||||
fake = torch.zeros(imgs.size(0), 1)
|
||||
if self.on_gpu:
|
||||
fake = fake.cuda(imgs.device.index)
|
||||
fake = fake.type_as(fake)
|
||||
|
||||
fake_loss = self.adversarial_loss(
|
||||
self.discriminator(self.generated_imgs.detach()), fake)
|
||||
|
@ -174,9 +168,7 @@ class GAN(LightningModule):
|
|||
|
||||
def on_epoch_end(self):
|
||||
z = torch.randn(8, self.hparams.latent_dim)
|
||||
# match gpu device (or keep as cpu)
|
||||
if self.on_gpu:
|
||||
z = z.cuda(self.last_imgs.device.index)
|
||||
z = z.type_as(self.last_imgs)
|
||||
|
||||
# log sampled images
|
||||
sample_imgs = self(z)
|
||||
|
|
|
@ -277,9 +277,6 @@ class DQNLightning(pl.LightningModule):
|
|||
# calculates training loss
|
||||
loss = self.dqn_mse_loss(batch)
|
||||
|
||||
if self.trainer.use_dp or self.trainer.use_ddp2:
|
||||
loss = loss.unsqueeze(0)
|
||||
|
||||
if done:
|
||||
self.total_reward = self.episode_reward
|
||||
self.episode_reward = 0
|
||||
|
|
|
@ -46,12 +46,6 @@ class ImageNetLightningModel(LightningModule):
|
|||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
|
||||
if self.trainer.use_dp or self.trainer.use_ddp2:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
acc1 = acc1.unsqueeze(0)
|
||||
acc5 = acc5.unsqueeze(0)
|
||||
|
||||
tqdm_dict = {'train_loss': loss_val}
|
||||
output = OrderedDict({
|
||||
'loss': loss_val,
|
||||
|
@ -69,12 +63,6 @@ class ImageNetLightningModel(LightningModule):
|
|||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
# in DP mode (default) make sure if result is scalar, there's another dim in the beginning
|
||||
if self.trainer.use_dp or self.trainer.use_ddp2:
|
||||
loss_val = loss_val.unsqueeze(0)
|
||||
acc1 = acc1.unsqueeze(0)
|
||||
acc5 = acc5.unsqueeze(0)
|
||||
|
||||
output = OrderedDict({
|
||||
'val_loss': loss_val,
|
||||
'val_acc1': acc1,
|
||||
|
|
|
@ -163,6 +163,9 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
|
|||
|
||||
else:
|
||||
output = module.validation_step(*input, **kwargs)
|
||||
|
||||
if module.use_dp or module.use_ddp2:
|
||||
auto_squeeze_dim_zeros(output)
|
||||
# ---------------
|
||||
|
||||
with lock:
|
||||
|
@ -199,3 +202,18 @@ def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: n
|
|||
raise output
|
||||
outputs.append(output)
|
||||
return outputs
|
||||
|
||||
|
||||
def auto_squeeze_dim_zeros(output):
    """
    In DP or DDP2 we need to unsqueeze dim 0 of scalar results.

    Despite its name, this function only ever *un*squeezes: every top-level
    value of ``output`` that is a 0-dim ``torch.Tensor`` is replaced by the
    same tensor with a new leading dimension (shape ``(1,)``). Non-tensor
    values and tensors that already have at least one dimension are left
    as they are. Nested containers are not traversed.

    :param output: the result of a step call. Dict-like objects (anything
        exposing ``items()`` and item assignment) are updated in place.
        Anything else (e.g. ``None`` or a bare tensor) is ignored, because
        it cannot be updated in place.
    :return: None
    """
    # A step may return nothing or something that is not dict-like; there is
    # nothing to fix up in place then, so do not fail on ``.items()``.
    if not hasattr(output, 'items'):
        return

    # Only existing keys are re-assigned, so the mapping never changes size
    # while it is being iterated.
    for k, v in output.items():
        if isinstance(v, torch.Tensor) and v.dim() == 0:
            output[k] = v.unsqueeze(0)
|
||||
|
|
Loading…
Reference in New Issue