[HOTFIX] Logging for evaluation (#4684)

* resolve bugs

* add should_flush_logs

* remove should_flush

* should work

* update test

* use something else

* Update pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

* log mock_log_metrics.mock_calls

* typo

* don't use keys

* convert to list

* typo

* check kwargs

* resolve bug

* resolve flake8

Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
chaton 2020-11-15 15:41:33 +00:00 committed by GitHub
parent 53f14391da
commit 867eef0e4c
3 changed files with 124 additions and 21 deletions
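
In short: EpochResultStore stops pushing evaluation-epoch metrics into the dev debugger itself; instead, LoggerConnector.get_evaluate_epoch_results flushes the epoch's logged metrics through log_metrics, skipping the sanity check, and a regression test asserts the exact TensorBoardLogger.log_metrics calls. The new test depends on Lightning's debug mode, which is toggled by an environment variable; a minimal sketch of that switch (the tests below set it via mock.patch.dict instead):

import os

# PL_DEV_DEBUG=1 enables trainer.dev_debugger history tracking (e.g. logged
# metrics). Must be set before the Trainer is constructed.
os.environ["PL_DEV_DEBUG"] = "1"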

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

@@ -12,12 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections import ChainMap, defaultdict
 from copy import deepcopy
-from collections import defaultdict, ChainMap
 from enum import Enum
-from typing import Union, Tuple, Any, Dict, Optional, List
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
 # used to map boolean to right LoggerStage values
@@ -445,9 +446,6 @@ class EpochResultStore:
         epoch_log_metrics = self.get_epoch_log_metrics()
         logger_connector.logged_metrics.update(epoch_log_metrics)
         logger_connector.logged_metrics.update(epoch_dict)
-        if not self.trainer.running_sanity_check and not is_train:
-            if len(epoch_log_metrics) > 0:
-                self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics))
 
         # get forked_metrics
         forked_metrics = self.get_forked_metrics()

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

@@ -12,23 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from collections import ChainMap
+from copy import deepcopy
 from pprint import pprint
 from typing import Iterable, Union, cast
-from copy import deepcopy
-from collections import ChainMap
+
 import torch
+
 from pytorch_lightning.core import memory
-from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
-from pytorch_lightning.utilities import flatten_dict
-from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.core.step_result import EvalResult, Result
-from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
 from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import (
+    LOOKUP_TABLE,
     EpochResultStore,
     LoggerStages,
-    LOOKUP_TABLE
 )
+from pytorch_lightning.utilities import flatten_dict
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_utils import is_overridden
 
 
 class LoggerConnector:
@@ -258,6 +260,11 @@ class LoggerConnector:
         self.add_to_eval_loop_results(dl_idx, has_been_initialized)
 
     def get_evaluate_epoch_results(self, test_mode):
+        if not self.trainer.running_sanity_check:
+            # log all the metrics as a single dict
+            metrics_to_log = self.cached_results.get_epoch_log_metrics()
+            if len(metrics_to_log) > 0:
+                self.log_metrics(metrics_to_log, {})
         self.prepare_eval_loop_results()
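
Taken together, the two hunks above mean that metrics logged with on_epoch=True during validation reach the attached logger once per evaluation epoch, and never during the sanity check. A minimal, self-contained sketch of the user-visible behaviour (hypothetical module and dataset names, not part of this commit; Trainer flags mirror the test below):

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomData(Dataset):
    # hypothetical toy dataset: four random 32-dim vectors
    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return torch.randn(32)


class SketchModel(LightningModule):
    # hypothetical minimal module exercising the fixed code path
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # aggregated by EpochResultStore; with this hotfix the epoch value
        # is flushed to the logger by get_evaluate_epoch_results
        self.log("valid_loss", loss, on_step=False, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2)
trainer.fit(SketchModel(), DataLoader(RandomData()), DataLoader(RandomData()))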

tests/trainer/logging/test_eval_loop_logging_1_0.py

@@ -14,19 +14,22 @@
 """
 Tests to ensure that the training loop works with a dict (1.0)
 """
+import collections
+import itertools
 import os
 from unittest import mock
 from unittest.mock import call, patch
-from pytorch_lightning.core.lightning import LightningModule
-from pytorch_lightning import Trainer
-from pytorch_lightning import callbacks, seed_everything
-from tests.base.deterministic_model import DeterministicModel
-from tests.base import SimpleModule, BoringModel, RandomDataset
+
 import numpy as np
-import itertools
-import collections
-import torch
 import pytest
+import torch
 from torch.utils.data import DataLoader, Dataset
+
+from pytorch_lightning import Trainer, callbacks, seed_everything
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.loggers import TensorBoardLogger
+from tests.base import BoringModel, RandomDataset, SimpleModule
+from tests.base.deterministic_model import DeterministicModel
+
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -780,3 +783,98 @@ def test_log_works_in_test_callback(tmpdir):
             assert func_name in trainer.logger_connector.progress_bar_metrics
         else:
             assert func_name not in trainer.logger_connector.progress_bar_metrics
+
+
+@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
+@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
+def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
+    """
+    This test makes sure we properly call log_metrics on the loggers
+    """
+
+    class ExtendedModel(BoringModel):
+
+        val_losses = []
+
+        def training_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('train_loss', loss)
+            return {"loss": loss}
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.val_losses.append(loss)
+            self.log('valid_loss_0', loss, on_step=True, on_epoch=True)
+            self.log('valid_loss_1', loss, on_step=False, on_epoch=True)
+            self.log('valid_loss_2', loss, on_step=True, on_epoch=False)
+            self.log('valid_loss_3', loss, on_step=False, on_epoch=False)
+            return {"val_loss": loss}
+
+        def test_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('fake_test_acc', loss)
+            return {"y": loss}
+
+    model = ExtendedModel()
+    model.validation_epoch_end = None
+
+    # Initialize a trainer
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        logger=TensorBoardLogger(tmpdir),
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=0,
+        max_epochs=2,
+        progress_bar_refresh_rate=1,
+    )
+
+    # Train the model ⚡
+    trainer.fit(model)
+
+    # hp_metric + 2 steps + epoch + 2 steps + epoch
+    expected_num_calls = 1 + 2 + 1 + 2 + 1
+
+    assert len(mock_log_metrics.mock_calls) == expected_num_calls
+
+    assert mock_log_metrics.mock_calls[0] == call({'hp_metric': -1}, 0)
+
+    def get_metrics_at_idx(idx):
+        mock_calls = list(mock_log_metrics.mock_calls)
+        if isinstance(mock_calls[idx].kwargs, dict):
+            return mock_calls[idx].kwargs["metrics"]
+        else:
+            return mock_calls[idx][2]["metrics"]
+
+    expected = ['valid_loss_0_step/epoch_0', 'valid_loss_2/epoch_0', 'global_step']
+    assert sorted(get_metrics_at_idx(1)) == sorted(expected)
+    assert sorted(get_metrics_at_idx(2)) == sorted(expected)
+
+    expected = model.val_losses[2]
+    assert get_metrics_at_idx(1)["valid_loss_0_step/epoch_0"] == expected
+    expected = model.val_losses[3]
+    assert get_metrics_at_idx(2)["valid_loss_0_step/epoch_0"] == expected
+
+    expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
+    assert sorted(get_metrics_at_idx(3)) == sorted(expected)
+
+    expected = torch.stack(model.val_losses[2:4]).mean()
+    assert get_metrics_at_idx(3)["valid_loss_1"] == expected
+
+    expected = ['valid_loss_0_step/epoch_1', 'valid_loss_2/epoch_1', 'global_step']
+    assert sorted(get_metrics_at_idx(4)) == sorted(expected)
+    assert sorted(get_metrics_at_idx(5)) == sorted(expected)
+
+    expected = model.val_losses[4]
+    assert get_metrics_at_idx(4)["valid_loss_0_step/epoch_1"] == expected
+    expected = model.val_losses[5]
+    assert get_metrics_at_idx(5)["valid_loss_0_step/epoch_1"] == expected
+
+    expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
+    assert sorted(get_metrics_at_idx(6)) == sorted(expected)
+
+    expected = torch.stack(model.val_losses[4:]).mean()
+    assert get_metrics_at_idx(6)["valid_loss_1"] == expected
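
A note on the get_metrics_at_idx helper in the test above: unittest.mock only gained a real Call.kwargs property in Python 3.8. On earlier versions, attribute access on a recorded call goes through _Call.__getattr__ and returns another _Call rather than a dict, which is why the helper checks isinstance(..., dict) before falling back to tuple indexing (a _Call is a (name, args, kwargs) tuple). A standalone sketch of the same pattern, with hypothetical metric values:

from unittest.mock import MagicMock

logger = MagicMock()
logger.log_metrics(metrics={"valid_loss_1": 0.25}, step=0)

recorded = logger.log_metrics.mock_calls[0]
# Python 3.8+: recorded.kwargs is a dict; older versions: it is a _Call,
# so index into the (name, args, kwargs) tuple instead.
if isinstance(recorded.kwargs, dict):
    metrics = recorded.kwargs["metrics"]
else:
    metrics = recorded[2]["metrics"]
assert metrics == {"valid_loss_1": 0.25}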