Fix Imagenet example (#10179)
This commit is contained in:
parent
4bc73b2b76
commit
e2b1967b38
|
@ -575,6 +575,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
|
|||
|
||||
### Fixed
|
||||
|
||||
- Fixed imagenet example evaluation ([#10179](https://github.com/PyTorchLightning/pytorch-lightning/pull/10179))
|
||||
|
||||
- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8685](https://github.com/PyTorchLightning/pytorch-lightning/pull/8685))
|
||||
|
||||
|
|
|
@ -99,14 +99,17 @@ class ImageNetLightningModel(LightningModule):
|
|||
self.log("train_acc5", acc5, on_step=True, on_epoch=True, logger=True)
|
||||
return loss_train
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
def eval_step(self, batch, batch_idx, prefix: str):
|
||||
images, target = batch
|
||||
output = self(images)
|
||||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
self.log("val_loss", loss_val, on_step=True, on_epoch=True)
|
||||
self.log("val_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
|
||||
self.log("val_acc5", acc5, on_step=True, on_epoch=True)
|
||||
self.log(f"{prefix}_loss", loss_val, on_step=True, on_epoch=True)
|
||||
self.log(f"{prefix}_acc1", acc1, on_step=True, prog_bar=True, on_epoch=True)
|
||||
self.log(f"{prefix}_acc5", acc5, on_step=True, on_epoch=True)
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
return self.eval_step(batch, batch_idx, "val")
|
||||
|
||||
@staticmethod
|
||||
def __accuracy(output, target, topk=(1,)):
|
||||
|
@ -165,21 +168,8 @@ class ImageNetLightningModel(LightningModule):
|
|||
def test_dataloader(self):
|
||||
return self.val_dataloader()
|
||||
|
||||
def test_step(self, *args, **kwargs):
|
||||
return self.validation_step(*args, **kwargs)
|
||||
|
||||
def test_epoch_end(self, *args, **kwargs):
|
||||
outputs = self.validation_epoch_end(*args, **kwargs)
|
||||
|
||||
def substitute_val_keys(out):
|
||||
return {k.replace("val", "test"): v for k, v in out.items()}
|
||||
|
||||
outputs = {
|
||||
"test_loss": outputs["val_loss"],
|
||||
"progress_bar": substitute_val_keys(outputs["progress_bar"]),
|
||||
"log": substitute_val_keys(outputs["log"]),
|
||||
}
|
||||
return outputs
|
||||
def test_step(self, batch, batch_idx):
|
||||
return self.eval_step(batch, batch_idx, "test")
|
||||
|
||||
@staticmethod
|
||||
def add_model_specific_args(parent_parser): # pragma: no-cover
|
||||
|
|
Loading…
Reference in New Issue