Update to latest logging format and modify the accuracy method. (#4816)
* Update to latest logging format and modify the accuracy method. * Apply suggestions from code review Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
parent
d3626b7f50
commit
b807c3278d
|
@ -19,7 +19,6 @@ or show all options you can change:
|
|||
"""
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from collections import OrderedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -37,7 +36,6 @@ from pytorch_lightning.core import LightningModule
|
|||
|
||||
|
||||
class ImageNetLightningModel(LightningModule):
|
||||
|
||||
# pull out resnet names from torchvision models
|
||||
MODEL_NAMES = sorted(
|
||||
name for name in models.__dict__
|
||||
|
@ -45,16 +43,16 @@ class ImageNetLightningModel(LightningModule):
|
|||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
arch: str,
|
||||
pretrained: bool,
|
||||
lr: float,
|
||||
momentum: float,
|
||||
weight_decay: int,
|
||||
data_path: str,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
**kwargs,
|
||||
self,
|
||||
arch: str,
|
||||
pretrained: bool,
|
||||
lr: float,
|
||||
momentum: float,
|
||||
weight_decay: int,
|
||||
data_path: str,
|
||||
batch_size: int,
|
||||
workers: int,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.save_hyperparameters()
|
||||
|
@ -74,39 +72,21 @@ class ImageNetLightningModel(LightningModule):
|
|||
def training_step(self, batch, batch_idx):
|
||||
images, target = batch
|
||||
output = self(images)
|
||||
loss_val = F.cross_entropy(output, target)
|
||||
loss_train= F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
tqdm_dict = {'train_loss': loss_val}
|
||||
output = OrderedDict({
|
||||
'loss': loss_val,
|
||||
'acc1': acc1,
|
||||
'acc5': acc5,
|
||||
'progress_bar': tqdm_dict,
|
||||
'log': tqdm_dict
|
||||
})
|
||||
return output
|
||||
self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
|
||||
self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
|
||||
self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
|
||||
return loss_val
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
images, target = batch
|
||||
output = self(images)
|
||||
loss_val = F.cross_entropy(output, target)
|
||||
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
|
||||
|
||||
output = OrderedDict({
|
||||
'val_loss': loss_val,
|
||||
'val_acc1': acc1,
|
||||
'val_acc5': acc5,
|
||||
})
|
||||
return output
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
tqdm_dict = {}
|
||||
for metric_name in ["val_loss", "val_acc1", "val_acc5"]:
|
||||
tqdm_dict[metric_name] = torch.stack([output[metric_name] for output in outputs]).mean()
|
||||
|
||||
result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]}
|
||||
return result
|
||||
self.log('val_loss', loss_val, on_step=True, on_epoch=True)
|
||||
self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True)
|
||||
self.log('val_acc5', acc5, on_step=True, on_epoch=True)
|
||||
|
||||
@staticmethod
|
||||
def __accuracy(output, target, topk=(1,)):
|
||||
|
@ -121,7 +101,7 @@ class ImageNetLightningModel(LightningModule):
|
|||
|
||||
res = []
|
||||
for k in topk:
|
||||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
|
||||
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
|
||||
res.append(correct_k.mul_(100.0 / batch_size))
|
||||
return res
|
||||
|
||||
|
|
Loading…
Reference in New Issue