[bugfix] Logging only on `not should_accumulate()` during training (#5417)
* resolve bug
* resolve tests
* update
* Update tests/loggers/test_tensorboard.py

Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
parent f2e99d617f
commit a053d758d0
@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from copy import deepcopy
 import os
+from copy import deepcopy
 from pprint import pprint
 from typing import Iterable, Union

@@ -158,7 +158,7 @@ class LoggerConnector:
         self.logged_metrics.update(logged_metrics_tmp)
         self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

-    def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False):
+    def log_metrics(self, metrics, grad_norm_dic, step=None):
         """Logs the metric dict passed in.
         If `step` parameter is None and `step` key is presented is metrics,
         uses metrics["step"] as a step

@@ -186,11 +186,8 @@ class LoggerConnector:

         elif step is None:
             # added metrics by Lightning for convenience
-            if log_train_step_metrics:
-                step = self.trainer.total_batch_idx
-            else:
-                scalar_metrics['epoch'] = self.trainer.current_epoch
-                step = self.trainer.global_step
+            scalar_metrics['epoch'] = self.trainer.current_epoch
+            step = self.trainer.global_step

         # log actual metrics
         if self.trainer.logger is not None:

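For orientation, the deleted branch stamped metrics with `total_batch_idx`, which advances on every batch, while the surviving path attaches `epoch` and uses `global_step`, which advances only when the optimizer actually steps. A minimal standalone sketch of how the two counters diverge under accumulation (plain Python, not Lightning code):

```python
# Sketch: total_batch_idx vs. global_step under accumulate_grad_batches=2.
accumulate_grad_batches = 2
total_batch_idx = 0
global_step = 0

for batch_idx in range(6):
    finishes_window = (batch_idx + 1) % accumulate_grad_batches == 0
    if finishes_window:
        global_step += 1  # the optimizer would step here
    total_batch_idx += 1

print(total_batch_idx, global_step)  # 6 batches -> only 3 optimizer steps
```

Keying train-step metrics to `total_batch_idx` is what let log points land on accumulation-only batches; after this change everything is keyed to `global_step`.
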
@@ -593,6 +590,8 @@ class LoggerConnector:
         return gathered_epoch_outputs

     def log_train_step_metrics(self, batch_output):
+        if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization:
+            return
         _, batch_log_metrics = self.cached_results.update_logger_connector()
         # when metrics should be logged
         if self.should_update_logs or self.trainer.fast_dev_run is True:

@@ -601,5 +600,5 @@ class LoggerConnector:
             if grad_norm_dic is None:
                 grad_norm_dic = {}
             if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
-                self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
+                self.log_metrics(batch_log_metrics, grad_norm_dic)
                 self.callback_metrics.update(batch_log_metrics)

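The heart of the fix is the two lines added to `log_train_step_metrics`: under automatic optimization, the method now returns early while gradients are still being accumulated, so metrics are emitted only on batches that end an accumulation window. A self-contained sketch of the same guard pattern (the `MiniLoop` class below is hypothetical, not Lightning's API):

```python
class MiniLoop:
    """Hypothetical loop illustrating the early-return guard from the fix."""

    def __init__(self, accumulate_grad_batches, automatic_optimization=True):
        self.accumulate_grad_batches = accumulate_grad_batches
        self.automatic_optimization = automatic_optimization
        self.batch_idx = 0
        self.logged = []

    def should_accumulate(self):
        # True while the current batch only adds gradients, with no optimizer step
        return (self.batch_idx + 1) % self.accumulate_grad_batches != 0

    def log_train_step_metrics(self):
        if self.should_accumulate() and self.automatic_optimization:
            return  # same skip as the added lines above
        self.logged.append(self.batch_idx)

loop = MiniLoop(accumulate_grad_batches=3)
for i in range(9):
    loop.batch_idx = i
    loop.log_train_step_metrics()
print(loop.logged)  # [2, 5, 8]: only window-ending batches are logged
```
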
@@ -126,7 +126,7 @@ def _test_loggers_fit_test(tmpdir, logger_class):
     if logger_class == TensorBoardLogger:
         expected = [
             (0, ['hp_metric']),
-            (0, ['train_some_val']),
+            (0, ['epoch', 'train_some_val']),
             (0, ['early_stop_on', 'epoch', 'val_acc']),
             (0, ['hp_metric']),
             (1, ['epoch', 'test_acc', 'test_loss'])

@@ -134,7 +134,7 @@ def _test_loggers_fit_test(tmpdir, logger_class):
         assert log_metric_names == expected
     else:
         expected = [
-            (0, ['train_some_val']),
+            (0, ['epoch', 'train_some_val']),
             (0, ['early_stop_on', 'epoch', 'val_acc']),
             (1, ['epoch', 'test_acc', 'test_loss'])
         ]

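These fixture changes follow directly from the connector edit above: with the train-step branch gone from `log_metrics`, every `step=None` logging call now has the `epoch` scalar injected before reaching the logger, so `train_some_val` arrives together with `epoch`. A hedged one-function sketch of that injection (`with_epoch` is an illustrative name, not Lightning code):

```python
def with_epoch(scalar_metrics, current_epoch):
    # Mirrors the surviving branch: 'epoch' is always attached for step=None calls.
    out = dict(scalar_metrics)
    out['epoch'] = current_epoch
    return out

assert sorted(with_epoch({'train_some_val': 0.5}, 0)) == ['epoch', 'train_some_val']
```
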
@@ -213,8 +213,11 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
     Tests to ensure that tensorboard log properly when accumulated_gradients > 1
     """
     class TestModel(BoringModel):
-        _count = 0
-        _indexes = []
+
+        def __init__(self):
+            super().__init__()
+            self._count = 0
+            self._indexes = []

         def training_step(self, batch, batch_idx):
             output = self.layer(batch)

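Beyond tidying the test, moving `_count` and `_indexes` into `__init__` avoids Python's shared mutable class attribute pitfall: a list assigned at class level is a single object shared by every instance, while assignment in `__init__` gives each instance its own. A quick standalone illustration:

```python
class SharedState:
    items = []  # one list object, shared across all instances

class PerInstanceState:
    def __init__(self):
        self.items = []  # a fresh list for each instance

a, b = SharedState(), SharedState()
a.items.append(1)
print(b.items)  # [1]: b observes a's append

c, d = PerInstanceState(), PerInstanceState()
c.items.append(1)
print(d.items)  # []: fully isolated
```
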
@@ -222,10 +225,10 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):
             self.log('count', self._count, on_step=True, on_epoch=True)
             self.log('loss', loss, on_step=True, on_epoch=True)

-            if self.trainer.logger_connector.should_update_logs:
-                self._indexes.append(self._count)
+            if not self.trainer.train_loop.should_accumulate():
+                if self.trainer.logger_connector.should_update_logs:
+                    self._indexes.append(self.trainer.global_step)

             self._count += 1
             return loss

         def validation_step(self, batch, batch_idx):

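The rewritten `training_step` records `global_step` only when a batch both closes an accumulation window and falls on a logging step. A standalone sketch of which steps that selects for the Trainer configuration used below; this assumes `should_update_logs` reduces to `(global_step + 1) % log_every_n_steps == 0`, an assumption about the connector internals rather than something shown in this diff:

```python
# Assumption: should_update_logs ~ (global_step + 1) % log_every_n_steps == 0
accumulate_grad_batches, log_every_n_steps = 2, 3
limit_train_batches, max_epochs = 12, 3

indexes, global_step = [], 0
for epoch in range(max_epochs):
    for batch_idx in range(limit_train_batches):
        closes_window = (batch_idx + 1) % accumulate_grad_batches == 0
        if closes_window:
            if (global_step + 1) % log_every_n_steps == 0:
                indexes.append(global_step)  # what the test's _indexes would hold
            global_step += 1  # one optimizer step per accumulation window

print(indexes)  # [2, 5, 8, 11, 14, 17] under these assumptions
```
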
@@ -245,14 +248,13 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):

     logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False)

-    accumulate_grad_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=12,
-        limit_val_batches=12,
+        limit_val_batches=0,
         max_epochs=3,
         gpus=0,
-        accumulate_grad_batches=accumulate_grad_batches,
+        accumulate_grad_batches=2,
         logger=[logger_0],
         log_every_n_steps=3,
     )

@@ -260,5 +262,6 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmpdir):

     mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]]
     assert mock_count_epochs == expected
+
     mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
     assert model._indexes == mock_count_steps

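The assertions pull data straight out of `mock_log_metrics.mock_calls`: each entry unpacks as `(name, args, kwargs)`, so `m[2]` is the keyword dict of one recorded `log_metrics` call. A small standalone illustration of the same pattern with `unittest.mock`:

```python
from unittest import mock

log_metrics = mock.MagicMock()
log_metrics(metrics={"count_step": 4.0, "epoch": 0}, step=2)
log_metrics(metrics={"count_epoch": 11.0}, step=5)

# m[2] is the kwargs dict of each call, exactly as indexed in the test
steps = [m[2]["step"] for m in log_metrics.mock_calls if "count_step" in m[2]["metrics"]]
assert steps == [2]
```
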
@@ -24,12 +24,16 @@ from unittest import mock
 import numpy as np
 import pytest
 import torch
-from torch.utils.data import Dataset
+from torch.nn import functional as F
+from torch.utils.data import DataLoader, Dataset, random_split
+from torchvision import transforms
+from torchvision.datasets.mnist import MNIST

+import pytorch_lightning as pl
 from pytorch_lightning import callbacks, Trainer
 from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import WandbLogger
 from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset
 from tests.base.deterministic_model import DeterministicModel