From 4722cc0bf0fb8c5c902144f0d72dc55c3335e6d8 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 6 Oct 2020 22:57:56 +0200
Subject: [PATCH] prune Results usage in notebooks (#3911)

* notebooks

* notebooks
---
 notebooks/01-mnist-hello-world.ipynb           | 13 ++++++-------
 notebooks/02-datamodules.ipynb                 | 18 ++++++++----------
 .../04-transformers-text-classification.ipynb  | 17 +++++++----------
 3 files changed, 21 insertions(+), 27 deletions(-)

diff --git a/notebooks/01-mnist-hello-world.ipynb b/notebooks/01-mnist-hello-world.ipynb
index c9e81cc990..79bc9ebec9 100644
--- a/notebooks/01-mnist-hello-world.ipynb
+++ b/notebooks/01-mnist-hello-world.ipynb
@@ -123,7 +123,7 @@
     "    def training_step(self, batch, batch_nb):\n",
     "        x, y = batch\n",
     "        loss = F.cross_entropy(self(x), y)\n",
-    "        return pl.TrainResult(loss)\n",
+    "        return loss\n",
     "\n",
     "    def configure_optimizers(self):\n",
     "        return torch.optim.Adam(self.parameters(), lr=0.02)"
@@ -250,7 +250,7 @@
     "        x, y = batch\n",
     "        logits = self(x)\n",
     "        loss = F.nll_loss(logits, y)\n",
-    "        return pl.TrainResult(loss)\n",
+    "        return loss\n",
     "\n",
     "    def validation_step(self, batch, batch_idx):\n",
     "        x, y = batch\n",
@@ -258,12 +258,11 @@
     "        loss = F.nll_loss(logits, y)\n",
     "        preds = torch.argmax(logits, dim=1)\n",
     "        acc = accuracy(preds, y)\n",
-    "        result = pl.EvalResult(checkpoint_on=loss)\n",
     "\n",
-    "        # Calling result.log will surface up scalars for you in TensorBoard\n",
-    "        result.log('val_loss', loss, prog_bar=True)\n",
-    "        result.log('val_acc', acc, prog_bar=True)\n",
-    "        return result\n",
+    "        # Calling self.log will surface up scalars for you in TensorBoard\n",
+    "        self.log('val_loss', loss, prog_bar=True)\n",
+    "        self.log('val_acc', acc, prog_bar=True)\n",
+    "        return loss\n",
     "\n",
     "    def test_step(self, batch, batch_idx):\n",
     "        # Here we just reuse the validation_step for testing\n",
diff --git a/notebooks/02-datamodules.ipynb b/notebooks/02-datamodules.ipynb
index 53468d2c72..3e027cd304 100644
--- a/notebooks/02-datamodules.ipynb
+++ b/notebooks/02-datamodules.ipynb
@@ -169,7 +169,7 @@
     "        x, y = batch\n",
     "        logits = self(x)\n",
     "        loss = F.nll_loss(logits, y)\n",
-    "        return pl.TrainResult(loss)\n",
+    "        return loss\n",
     "\n",
     "    def validation_step(self, batch, batch_idx):\n",
     "        x, y = batch\n",
@@ -177,10 +177,9 @@
     "        logits = self(x)\n",
     "        loss = F.nll_loss(logits, y)\n",
     "        preds = torch.argmax(logits, dim=1)\n",
     "        acc = accuracy(preds, y)\n",
-    "        result = pl.EvalResult(checkpoint_on=loss)\n",
-    "        result.log('val_loss', loss, prog_bar=True)\n",
-    "        result.log('val_acc', acc, prog_bar=True)\n",
-    "        return result\n",
+    "        self.log('val_loss', loss, prog_bar=True)\n",
+    "        self.log('val_acc', acc, prog_bar=True)\n",
+    "        return loss\n",
     "\n",
     "    def configure_optimizers(self):\n",
     "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
@@ -394,7 +393,7 @@
     "        x, y = batch\n",
     "        logits = self(x)\n",
     "        loss = F.nll_loss(logits, y)\n",
-    "        return pl.TrainResult(loss)\n",
+    "        return loss\n",
     "\n",
     "    def validation_step(self, batch, batch_idx):\n",
     "\n",
@@ -403,10 +402,9 @@
     "        x, y = batch\n",
     "        logits = self(x)\n",
     "        loss = F.nll_loss(logits, y)\n",
     "        preds = torch.argmax(logits, dim=1)\n",
     "        acc = accuracy(preds, y)\n",
-    "        result = pl.EvalResult(checkpoint_on=loss)\n",
-    "        result.log('val_loss', loss, prog_bar=True)\n",
-    "        result.log('val_acc', acc, prog_bar=True)\n",
-    "        return result\n",
+    "        self.log('val_loss', loss, prog_bar=True)\n",
+    "        self.log('val_acc', acc, prog_bar=True)\n",
+    "        return loss\n",
     "\n",
     "    def configure_optimizers(self):\n",
     "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
diff --git a/notebooks/04-transformers-text-classification.ipynb b/notebooks/04-transformers-text-classification.ipynb
index e92673e66d..ae7424c7d4 100644
--- a/notebooks/04-transformers-text-classification.ipynb
+++ b/notebooks/04-transformers-text-classification.ipynb
@@ -299,7 +299,7 @@
     "    def training_step(self, batch, batch_idx):\n",
     "        outputs = self(**batch)\n",
     "        loss = outputs[0]\n",
-    "        return pl.TrainResult(loss)\n",
+    "        return loss\n",
     "\n",
     "    def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
     "        outputs = self(**batch)\n",
@@ -322,20 +322,17 @@
     "            preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
     "            labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
     "            loss = torch.stack([x['loss'] for x in output]).mean()\n",
-    "            if i == 0:\n",
-    "                result = pl.EvalResult(checkpoint_on=loss)\n",
-    "            result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
+    "            self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
     "            split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
-    "            result.log_dict(split_metrics, prog_bar=True)\n",
-    "        return result\n",
+    "            self.log_dict(split_metrics, prog_bar=True)\n",
+    "        return loss\n",
     "\n",
     "        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
     "        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
     "        loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
-    "        result = pl.EvalResult(checkpoint_on=loss)\n",
-    "        result.log('val_loss', loss, prog_bar=True)\n",
-    "        result.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
-    "        return result\n",
+    "        self.log('val_loss', loss, prog_bar=True)\n",
+    "        self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
+    "        return loss\n",
     "\n",
     "    def setup(self, stage):\n",
     "        if stage == 'fit':\n",
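For reference, the idiom the notebooks are migrated to looks like the minimal sketch below. Only `self.log(...)`, `prog_bar=True`, and returning the bare loss come from the diff above; the module name, the single linear layer, and the hyperparameters are illustrative placeholders.

```python
import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl


class LitSketch(pl.LightningModule):
    # Hypothetical module sketching the post-Results API used in this patch.

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(28 * 28, 10)  # illustrative placeholder model

    def forward(self, x):
        return self.layer(x.view(x.size(0), -1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Previously: return pl.TrainResult(loss); now the bare loss is returned.
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # Previously logged via pl.EvalResult; self.log sends the scalar to the
        # logger, and prog_bar=True also surfaces it in the progress bar.
        self.log('val_loss', loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)
```

One behavioral detail worth noting: `pl.EvalResult(checkpoint_on=loss)` also selected the metric used for checkpointing. With plain `self.log`, that role moves to a `ModelCheckpoint(monitor='val_loss')` callback; the patch itself does not add one, so this is an assumption about intended usage rather than part of the change.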