prune Results usage in notebooks (#3911)

* notebooks

* notebooks
This commit is contained in:
Jirka Borovec 2020-10-06 22:57:56 +02:00 committed by GitHub
parent c510a7f900
commit 4722cc0bf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 27 deletions

View File

@ -123,7 +123,7 @@
" def training_step(self, batch, batch_nb):\n", " def training_step(self, batch, batch_nb):\n",
" x, y = batch\n", " x, y = batch\n",
" loss = F.cross_entropy(self(x), y)\n", " loss = F.cross_entropy(self(x), y)\n",
" return pl.TrainResult(loss)\n", " return loss\n",
"\n", "\n",
" def configure_optimizers(self):\n", " def configure_optimizers(self):\n",
" return torch.optim.Adam(self.parameters(), lr=0.02)" " return torch.optim.Adam(self.parameters(), lr=0.02)"
@ -250,7 +250,7 @@
" x, y = batch\n", " x, y = batch\n",
" logits = self(x)\n", " logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n", " return loss\n",
"\n", "\n",
" def validation_step(self, batch, batch_idx):\n", " def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n", " x, y = batch\n",
@ -258,12 +258,11 @@
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n", " preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n", " acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
"\n", "\n",
" # Calling result.log will surface up scalars for you in TensorBoard\n", " # Calling self.log will surface up scalars for you in TensorBoard\n",
" result.log('val_loss', loss, prog_bar=True)\n", " self.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n", " self.log('val_acc', acc, prog_bar=True)\n",
" return result\n", " return loss\n",
"\n", "\n",
" def test_step(self, batch, batch_idx):\n", " def test_step(self, batch, batch_idx):\n",
" # Here we just reuse the validation_step for testing\n", " # Here we just reuse the validation_step for testing\n",

View File

@ -169,7 +169,7 @@
" x, y = batch\n", " x, y = batch\n",
" logits = self(x)\n", " logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n", " return loss\n",
"\n", "\n",
" def validation_step(self, batch, batch_idx):\n", " def validation_step(self, batch, batch_idx):\n",
" x, y = batch\n", " x, y = batch\n",
@ -177,10 +177,9 @@
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n", " preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n", " acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n", " self.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_loss', loss, prog_bar=True)\n", " self.log('val_acc', acc, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n", " return loss\n",
" return result\n",
"\n", "\n",
" def configure_optimizers(self):\n", " def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",
@ -394,7 +393,7 @@
" x, y = batch\n", " x, y = batch\n",
" logits = self(x)\n", " logits = self(x)\n",
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" return pl.TrainResult(loss)\n", " return loss\n",
"\n", "\n",
" def validation_step(self, batch, batch_idx):\n", " def validation_step(self, batch, batch_idx):\n",
"\n", "\n",
@ -403,10 +402,9 @@
" loss = F.nll_loss(logits, y)\n", " loss = F.nll_loss(logits, y)\n",
" preds = torch.argmax(logits, dim=1)\n", " preds = torch.argmax(logits, dim=1)\n",
" acc = accuracy(preds, y)\n", " acc = accuracy(preds, y)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n", " self.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_loss', loss, prog_bar=True)\n", " self.log('val_acc', acc, prog_bar=True)\n",
" result.log('val_acc', acc, prog_bar=True)\n", " return loss\n",
" return result\n",
"\n", "\n",
" def configure_optimizers(self):\n", " def configure_optimizers(self):\n",
" optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n", " optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)\n",

View File

@ -299,7 +299,7 @@
" def training_step(self, batch, batch_idx):\n", " def training_step(self, batch, batch_idx):\n",
" outputs = self(**batch)\n", " outputs = self(**batch)\n",
" loss = outputs[0]\n", " loss = outputs[0]\n",
" return pl.TrainResult(loss)\n", " return loss\n",
"\n", "\n",
" def validation_step(self, batch, batch_idx, dataloader_idx=0):\n", " def validation_step(self, batch, batch_idx, dataloader_idx=0):\n",
" outputs = self(**batch)\n", " outputs = self(**batch)\n",
@ -322,20 +322,17 @@
" preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n", " preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()\n",
" labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n", " labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()\n",
" loss = torch.stack([x['loss'] for x in output]).mean()\n", " loss = torch.stack([x['loss'] for x in output]).mean()\n",
" if i == 0:\n", " self.log(f'val_loss_{split}', loss, prog_bar=True)\n",
" result = pl.EvalResult(checkpoint_on=loss)\n",
" result.log(f'val_loss_{split}', loss, prog_bar=True)\n",
" split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n", " split_metrics = {f\"{k}_{split}\": v for k, v in self.metric.compute(predictions=preds, references=labels).items()}\n",
" result.log_dict(split_metrics, prog_bar=True)\n", " self.log_dict(split_metrics, prog_bar=True)\n",
" return result\n", " return loss\n",
"\n", "\n",
" preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n", " preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()\n",
" labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n", " labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()\n",
" loss = torch.stack([x['loss'] for x in outputs]).mean()\n", " loss = torch.stack([x['loss'] for x in outputs]).mean()\n",
" result = pl.EvalResult(checkpoint_on=loss)\n", " self.log('val_loss', loss, prog_bar=True)\n",
" result.log('val_loss', loss, prog_bar=True)\n", " self.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n",
" result.log_dict(self.metric.compute(predictions=preds, references=labels), prog_bar=True)\n", " return loss\n",
" return result\n",
"\n", "\n",
" def setup(self, stage):\n", " def setup(self, stage):\n",
" if stage == 'fit':\n", " if stage == 'fit':\n",