Fix Imagenet example (#10179)

This commit is contained in:
Justus Schock 2021-10-29 14:05:05 +02:00 committed by GitHub
parent 4bc73b2b76
commit e2b1967b38
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 19 deletions

View File

@ -575,6 +575,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed
- Fixed imagenet example evaluation ([#10179](https://github.com/PyTorchLightning/pytorch-lightning/pull/10179))
- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685))

View File

@ -99,14 +99,17 @@ class ImageNetLightningModel(LightningModule):
self.log("train_acc5", acc5, on_step=True, on_epoch=True, logger=True)
return loss_train
def validation_step(self, batch, batch_idx):
def eval_step(self, batch, batch_idx, prefix: str):
images, target = batch
output = self(images)
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
self.log("val_loss", loss_val, on_step=True, on_epoch=True)
self.log("val_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
self.log("val_acc5", acc5, on_step=True, on_epoch=True)
self.log(f"{prefix}_loss", loss_val, on_step=True, on_epoch=True)
self.log(f"{prefix}_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
self.log(f"{prefix}_acc5", acc5, on_step=True, on_epoch=True)
def validation_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "val")
@staticmethod
def __accuracy(output, target, topk=(1,)):
@ -165,21 +168,8 @@ class ImageNetLightningModel(LightningModule):
def test_dataloader(self):
return self.val_dataloader()
def test_step(self, *args, **kwargs):
return self.validation_step(*args, **kwargs)
def test_epoch_end(self, *args, **kwargs):
outputs = self.validation_epoch_end(*args, **kwargs)
def substitute_val_keys(out):
return {k.replace("val", "test"): v for k, v in out.items()}
outputs = {
"test_loss": outputs["val_loss"],
"progress_bar": substitute_val_keys(outputs["progress_bar"]),
"log": substitute_val_keys(outputs["log"]),
}
return outputs
def test_step(self, batch, batch_idx):
return self.eval_step(batch, batch_idx, "test")
@staticmethod
def add_model_specific_args(parent_parser): # pragma: no-cover