Update to latest logging format and modify the accuracy method. (#4816)

* Update to latest logging format and modify the accuracy method.

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: chaton <thomas@grid.ai>
This commit is contained in:
Limber Cheng 2020-12-04 22:30:51 +08:00 committed by GitHub
parent d3626b7f50
commit b807c3278d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 19 additions and 39 deletions

View File

@ -19,7 +19,6 @@ or show all options you can change:
"""
import os
from argparse import ArgumentParser, Namespace
from collections import OrderedDict
import torch
import torch.nn.functional as F
@ -37,7 +36,6 @@ from pytorch_lightning.core import LightningModule
class ImageNetLightningModel(LightningModule):
# pull out resnet names from torchvision models
MODEL_NAMES = sorted(
name for name in models.__dict__
@ -45,16 +43,16 @@ class ImageNetLightningModel(LightningModule):
)
def __init__(
self,
arch: str,
pretrained: bool,
lr: float,
momentum: float,
weight_decay: int,
data_path: str,
batch_size: int,
workers: int,
**kwargs,
self,
arch: str,
pretrained: bool,
lr: float,
momentum: float,
weight_decay: int,
data_path: str,
batch_size: int,
workers: int,
**kwargs,
):
super().__init__()
self.save_hyperparameters()
@ -74,39 +72,21 @@ class ImageNetLightningModel(LightningModule):
def training_step(self, batch, batch_idx):
images, target = batch
output = self(images)
loss_val = F.cross_entropy(output, target)
loss_train= F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
tqdm_dict = {'train_loss': loss_val}
output = OrderedDict({
'loss': loss_val,
'acc1': acc1,
'acc5': acc5,
'progress_bar': tqdm_dict,
'log': tqdm_dict
})
return output
self.log('train_loss', loss_train, on_step=True, on_epoch=True, logger=True)
self.log('train_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True, logger=True)
self.log('train_acc5', acc5, on_step=True, on_epoch=True, logger=True)
return loss_val
def validation_step(self, batch, batch_idx):
images, target = batch
output = self(images)
loss_val = F.cross_entropy(output, target)
acc1, acc5 = self.__accuracy(output, target, topk=(1, 5))
output = OrderedDict({
'val_loss': loss_val,
'val_acc1': acc1,
'val_acc5': acc5,
})
return output
def validation_epoch_end(self, outputs):
tqdm_dict = {}
for metric_name in ["val_loss", "val_acc1", "val_acc5"]:
tqdm_dict[metric_name] = torch.stack([output[metric_name] for output in outputs]).mean()
result = {'progress_bar': tqdm_dict, 'log': tqdm_dict, 'val_loss': tqdm_dict["val_loss"]}
return result
self.log('val_loss', loss_val, on_step=True, on_epoch=True)
self.log('val_acc1', acc1, on_step=True, prog_bar=True, on_epoch=True)
self.log('val_acc5', acc5, on_step=True, on_epoch=True)
@staticmethod
def __accuracy(output, target, topk=(1,)):
@ -121,7 +101,7 @@ class ImageNetLightningModel(LightningModule):
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res