diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 7ea2e8208a..1d03f3e885 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections import ChainMap, defaultdict from copy import deepcopy -from collections import defaultdict, ChainMap from enum import Enum -from typing import Union, Tuple, Any, Dict, Optional, List -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from typing import Any, Dict, List, Optional, Tuple, Union + from pytorch_lightning.core.step_result import Result +from pytorch_lightning.utilities.exceptions import MisconfigurationException # used to map boolean to right LoggerStage values @@ -445,9 +446,6 @@ class EpochResultStore: epoch_log_metrics = self.get_epoch_log_metrics() logger_connector.logged_metrics.update(epoch_log_metrics) logger_connector.logged_metrics.update(epoch_dict) - if not self.trainer.running_sanity_check and not is_train: - if len(epoch_log_metrics) > 0: - self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics)) # get forked_metrics forked_metrics = self.get_forked_metrics() diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 7116739c8a..9eacdb6dc3 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -12,23 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from collections import ChainMap +from copy import deepcopy from pprint import pprint from typing import Iterable, Union, cast -from copy import deepcopy -from collections import ChainMap + import torch + from pytorch_lightning.core import memory -from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection -from pytorch_lightning.utilities import flatten_dict -from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.core.step_result import EvalResult, Result -from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import ( + LOOKUP_TABLE, EpochResultStore, LoggerStages, - LOOKUP_TABLE ) +from pytorch_lightning.utilities import flatten_dict +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_utils import is_overridden class LoggerConnector: @@ -258,6 +260,11 @@ class LoggerConnector: self.add_to_eval_loop_results(dl_idx, has_been_initialized) def get_evaluate_epoch_results(self, test_mode): + if not self.trainer.running_sanity_check: + # log all the metrics as a single dict + metrics_to_log = self.cached_results.get_epoch_log_metrics() + if len(metrics_to_log) > 0: + self.log_metrics(metrics_to_log, {}) self.prepare_eval_loop_results() diff --git a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py index 552c472e12..b6d4b107ee 100644 --- a/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_eval_loop_logging_1_0.py @@ -14,19 +14,22 @@ """ Tests to ensure that the training loop works with a dict (1.0) """ +import collections +import itertools import os from unittest import mock +from unittest.mock import call, patch -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning import Trainer -from pytorch_lightning import callbacks, seed_everything -from tests.base.deterministic_model import DeterministicModel -from tests.base import SimpleModule, BoringModel, RandomDataset import numpy as np -import itertools -import collections -import torch import pytest +import torch +from torch.utils.data import DataLoader, Dataset + +from pytorch_lightning import Trainer, callbacks, seed_everything +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.loggers import TensorBoardLogger +from tests.base import BoringModel, RandomDataset, SimpleModule +from tests.base.deterministic_model import DeterministicModel @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) @@ -780,3 +783,98 @@ def test_log_works_in_test_callback(tmpdir): assert func_name in trainer.logger_connector.progress_bar_metrics else: assert func_name not in trainer.logger_connector.progress_bar_metrics + + +@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics") +@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"}) +def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir): + """ + This tests make sure we properly log_metrics to loggers + """ + + class ExtendedModel(BoringModel): + + val_losses = [] + + def training_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('train_loss', loss) + return {"loss": loss} + + def validation_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.val_losses.append(loss) + self.log('valid_loss_0', loss, on_step=True, on_epoch=True) + self.log('valid_loss_1', loss, on_step=False, on_epoch=True) + self.log('valid_loss_2', loss, on_step=True, on_epoch=False) + self.log('valid_loss_3', loss, on_step=False, on_epoch=False) + return {"val_loss": loss} + + def test_step(self, batch, batch_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.log('fake_test_acc', loss) + return {"y": loss} + + model = ExtendedModel() + model.validation_epoch_end = None + + # Initialize a trainer + trainer = Trainer( + default_root_dir=tmpdir, + logger=TensorBoardLogger(tmpdir), + limit_train_batches=2, + limit_val_batches=2, + limit_test_batches=0, + max_epochs=2, + progress_bar_refresh_rate=1, + ) + + # Train the model ⚡ + trainer.fit(model) + + # hp_metric + 2 steps + epoch + 2 steps + epoch + expected_num_calls = 1 + 2 + 1 + 2 + 1 + + assert len(mock_log_metrics.mock_calls) == expected_num_calls + + assert mock_log_metrics.mock_calls[0] == call({'hp_metric': -1}, 0) + + def get_metrics_at_idx(idx): + mock_calls = list(mock_log_metrics.mock_calls) + if isinstance(mock_calls[idx].kwargs, dict): + return mock_calls[idx].kwargs["metrics"] + else: + return mock_calls[idx][2]["metrics"] + + expected = ['valid_loss_0_step/epoch_0', 'valid_loss_2/epoch_0', 'global_step'] + assert sorted(get_metrics_at_idx(1)) == sorted(expected) + assert sorted(get_metrics_at_idx(2)) == sorted(expected) + + expected = model.val_losses[2] + assert get_metrics_at_idx(1)["valid_loss_0_step/epoch_0"] == expected + expected = model.val_losses[3] + assert get_metrics_at_idx(2)["valid_loss_0_step/epoch_0"] == expected + + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + assert sorted(get_metrics_at_idx(3)) == sorted(expected) + + expected = torch.stack(model.val_losses[2:4]).mean() + assert get_metrics_at_idx(3)["valid_loss_1"] == expected + expected = ['valid_loss_0_step/epoch_1', 'valid_loss_2/epoch_1', 'global_step'] + + assert sorted(get_metrics_at_idx(4)) == sorted(expected) + assert sorted(get_metrics_at_idx(5)) == sorted(expected) + + expected = model.val_losses[4] + assert get_metrics_at_idx(4)["valid_loss_0_step/epoch_1"] == expected + expected = model.val_losses[5] + assert get_metrics_at_idx(5)["valid_loss_0_step/epoch_1"] == expected + + expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step'] + assert sorted(get_metrics_at_idx(6)) == sorted(expected) + + expected = torch.stack(model.val_losses[4:]).mean() + assert get_metrics_at_idx(6)["valid_loss_1"] == expected