From a053d758d03558d2aa5a328b2f6befbc133a0ebc Mon Sep 17 00:00:00 2001 From: chaton Date: Sat, 9 Jan 2021 01:35:47 +0100 Subject: [PATCH] [bugfix] Logging only on `not should_accumulate()` during training (#5417) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * resolve bug * resolve tests * update * Update tests/loggers/test_tensorboard.py Co-authored-by: Carlos MocholĂ­ Co-authored-by: Carlos MocholĂ­ --- .../logger_connector/logger_connector.py | 15 +++++++-------- tests/loggers/test_all.py | 4 ++-- tests/loggers/test_tensorboard.py | 19 +++++++++++-------- .../test_train_loop_logging_1_0.py | 6 +++++- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 54bf2f9a90..6cf020aa65 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from copy import deepcopy import os +from copy import deepcopy from pprint import pprint from typing import Iterable, Union @@ -158,7 +158,7 @@ class LoggerConnector: self.logged_metrics.update(logged_metrics_tmp) self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp) - def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False): + def log_metrics(self, metrics, grad_norm_dic, step=None): """Logs the metric dict passed in. If `step` parameter is None and `step` key is presented is metrics, uses metrics["step"] as a step @@ -186,11 +186,8 @@ class LoggerConnector: elif step is None: # added metrics by Lightning for convenience - if log_train_step_metrics: - step = self.trainer.total_batch_idx - else: - scalar_metrics['epoch'] = self.trainer.current_epoch - step = self.trainer.global_step + scalar_metrics['epoch'] = self.trainer.current_epoch + step = self.trainer.global_step # log actual metrics if self.trainer.logger is not None: @@ -593,6 +590,8 @@ class LoggerConnector: return gathered_epoch_outputs def log_train_step_metrics(self, batch_output): + if self.trainer.train_loop.should_accumulate() and self.trainer.train_loop.automatic_optimization: + return _, batch_log_metrics = self.cached_results.update_logger_connector() # when metrics should be logged if self.should_update_logs or self.trainer.fast_dev_run is True: @@ -601,5 +600,5 @@ class LoggerConnector: if grad_norm_dic is None: grad_norm_dic = {} if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0: - self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True) + self.log_metrics(batch_log_metrics, grad_norm_dic) self.callback_metrics.update(batch_log_metrics) diff --git a/tests/loggers/test_all.py b/tests/loggers/test_all.py index 795b1a91e6..4bf15ff8d9 100644 --- a/tests/loggers/test_all.py +++ b/tests/loggers/test_all.py @@ -126,7 +126,7 @@ def _test_loggers_fit_test(tmpdir, logger_class): if logger_class == TensorBoardLogger: expected = [ (0, ['hp_metric']), - (0, ['train_some_val']), + (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), (0, ['hp_metric']), (1, ['epoch', 'test_acc', 'test_loss']) @@ -134,7 +134,7 @@ def _test_loggers_fit_test(tmpdir, logger_class): assert log_metric_names == expected else: expected = [ - (0, ['train_some_val']), + (0, ['epoch', 'train_some_val']), (0, ['early_stop_on', 'epoch', 'val_acc']), (1, ['epoch', 'test_acc', 'test_loss']) ] diff --git a/tests/loggers/test_tensorboard.py b/tests/loggers/test_tensorboard.py index fa5c711357..148ad550e7 100644 --- a/tests/loggers/test_tensorboard.py +++ b/tests/loggers/test_tensorboard.py @@ -213,8 +213,11 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmp Tests to ensure that tensorboard log properly when accumulated_gradients > 1 """ class TestModel(BoringModel): - _count = 0 - _indexes = [] + + def __init__(self): + super().__init__() + self._count = 0 + self._indexes = [] def training_step(self, batch, batch_idx): output = self.layer(batch) @@ -222,10 +225,10 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmp self.log('count', self._count, on_step=True, on_epoch=True) self.log('loss', loss, on_step=True, on_epoch=True) - if self.trainer.logger_connector.should_update_logs: - self._indexes.append(self._count) + if not self.trainer.train_loop.should_accumulate(): + if self.trainer.logger_connector.should_update_logs: + self._indexes.append(self.trainer.global_step) - self._count += 1 return loss def validation_step(self, batch, batch_idx): @@ -245,14 +248,13 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmp logger_0 = TensorBoardLogger(tmpdir, default_hp_metric=False) - accumulate_grad_batches = 2 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=12, - limit_val_batches=12, + limit_val_batches=0, max_epochs=3, gpus=0, - accumulate_grad_batches=accumulate_grad_batches, + accumulate_grad_batches=2, logger=[logger_0], log_every_n_steps=3, ) @@ -260,5 +262,6 @@ def test_tensorboard_with_accummulated_gradients(mock_log_metrics, expected, tmp mock_count_epochs = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_epoch" in m[2]["metrics"]] assert mock_count_epochs == expected + mock_count_steps = [m[2]["step"] for m in mock_log_metrics.mock_calls if "count_step" in m[2]["metrics"]] assert model._indexes == mock_count_steps diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 51b9c2ac69..617cd6fa3c 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -24,12 +24,16 @@ from unittest import mock import numpy as np import pytest import torch -from torch.utils.data import Dataset +from torch.nn import functional as F +from torch.utils.data import DataLoader, Dataset, random_split +from torchvision import transforms +from torchvision.datasets.mnist import MNIST import pytorch_lightning as pl from pytorch_lightning import callbacks, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers import WandbLogger from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset from tests.base.deterministic_model import DeterministicModel