From e2b1967b3846dd5966ab6809b00cf0d40eeec7bf Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Fri, 29 Oct 2021 14:05:05 +0200 Subject: [PATCH] Fix Imagenet example (#10179) --- CHANGELOG.md | 1 + pl_examples/domain_templates/imagenet.py | 28 ++++++++---------------- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 89fcc8ace1..286d64850f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -575,6 +575,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed imagenet example evaluation ([#10179](https://github.com/PyTorchLightning/pytorch-lightning/pull/10179)) - Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685)) diff --git a/pl_examples/domain_templates/imagenet.py b/pl_examples/domain_templates/imagenet.py index baefc7c944..7f57d6ff02 100644 --- a/pl_examples/domain_templates/imagenet.py +++ b/pl_examples/domain_templates/imagenet.py @@ -99,14 +99,17 @@ class ImageNetLightningModel(LightningModule): self.log("train_acc5", acc5, on_step=True, on_epoch=True, logger=True) return loss_train - def validation_step(self, batch, batch_idx): + def eval_step(self, batch, batch_idx, prefix: str): images, target = batch output = self(images) loss_val = F.cross_entropy(output, target) acc1, acc5 = self.__accuracy(output, target, topk=(1, 5)) - self.log("val_loss", loss_val, on_step=True, on_epoch=True) - self.log("val_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True) - self.log("val_acc5", acc5, on_step=True, on_epoch=True) + self.log(f"{prefix}_loss", loss_val, on_step=True, on_epoch=True) + self.log(f"{prefix}_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True) + self.log(f"{prefix}_acc5", acc5, on_step=True, on_epoch=True) + + def validation_step(self, batch, batch_idx): + return self.eval_step(batch, batch_idx, "val") @staticmethod def __accuracy(output, target, topk=(1,)): @@ -165,21 +168,8 @@ class ImageNetLightningModel(LightningModule): def test_dataloader(self): return self.val_dataloader() - def test_step(self, *args, **kwargs): - return self.validation_step(*args, **kwargs) - - def test_epoch_end(self, *args, **kwargs): - outputs = self.validation_epoch_end(*args, **kwargs) - - def substitute_val_keys(out): - return {k.replace("val", "test"): v for k, v in out.items()} - - outputs = { - "test_loss": outputs["val_loss"], - "progress_bar": substitute_val_keys(outputs["progress_bar"]), - "log": substitute_val_keys(outputs["log"]), - } - return outputs + def test_step(self, batch, batch_idx): + return self.eval_step(batch, batch_idx, "test") @staticmethod def add_model_specific_args(parent_parser): # pragma: no-cover