[HOTFIX] Logging for evaluation (#4684)
* resolve bugs * add should_flush_logs * remove should_flush * should work * update test * use something else * Update pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py * log mock_log_metrics.mock_calls * typo * don't use keys * convert to list * typo * check kwargs * resolve bug * resolve flake8 Co-authored-by: Sean Naren <sean.narenthiran@gmail.com>
This commit is contained in:
parent
53f14391da
commit
867eef0e4c
|
@ -12,12 +12,13 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from collections import ChainMap, defaultdict
|
||||
from copy import deepcopy
|
||||
from collections import defaultdict, ChainMap
|
||||
from enum import Enum
|
||||
from typing import Union, Tuple, Any, Dict, Optional, List
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
|
||||
|
||||
# used to map boolean to right LoggerStage values
|
||||
|
@ -445,9 +446,6 @@ class EpochResultStore:
|
|||
epoch_log_metrics = self.get_epoch_log_metrics()
|
||||
logger_connector.logged_metrics.update(epoch_log_metrics)
|
||||
logger_connector.logged_metrics.update(epoch_dict)
|
||||
if not self.trainer.running_sanity_check and not is_train:
|
||||
if len(epoch_log_metrics) > 0:
|
||||
self.trainer.dev_debugger.track_logged_metrics_history(deepcopy(epoch_log_metrics))
|
||||
|
||||
# get forked_metrics
|
||||
forked_metrics = self.get_forked_metrics()
|
||||
|
|
|
@ -12,23 +12,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from collections import ChainMap
|
||||
from copy import deepcopy
|
||||
from pprint import pprint
|
||||
from typing import Iterable, Union, cast
|
||||
from copy import deepcopy
|
||||
from collections import ChainMap
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_lightning.core import memory
|
||||
from pytorch_lightning.loggers import TensorBoardLogger, LoggerCollection
|
||||
from pytorch_lightning.utilities import flatten_dict
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
from pytorch_lightning.core.step_result import EvalResult, Result
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
|
||||
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import (
|
||||
LOOKUP_TABLE,
|
||||
EpochResultStore,
|
||||
LoggerStages,
|
||||
LOOKUP_TABLE
|
||||
)
|
||||
from pytorch_lightning.utilities import flatten_dict
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
|
||||
|
||||
class LoggerConnector:
|
||||
|
@ -258,6 +260,11 @@ class LoggerConnector:
|
|||
self.add_to_eval_loop_results(dl_idx, has_been_initialized)
|
||||
|
||||
def get_evaluate_epoch_results(self, test_mode):
|
||||
if not self.trainer.running_sanity_check:
|
||||
# log all the metrics as a single dict
|
||||
metrics_to_log = self.cached_results.get_epoch_log_metrics()
|
||||
if len(metrics_to_log) > 0:
|
||||
self.log_metrics(metrics_to_log, {})
|
||||
|
||||
self.prepare_eval_loop_results()
|
||||
|
||||
|
|
|
@ -14,19 +14,22 @@
|
|||
"""
|
||||
Tests to ensure that the training loop works with a dict (1.0)
|
||||
"""
|
||||
import collections
|
||||
import itertools
|
||||
import os
|
||||
from unittest import mock
|
||||
from unittest.mock import call, patch
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning import Trainer
|
||||
from pytorch_lightning import callbacks, seed_everything
|
||||
from tests.base.deterministic_model import DeterministicModel
|
||||
from tests.base import SimpleModule, BoringModel, RandomDataset
|
||||
import numpy as np
|
||||
import itertools
|
||||
import collections
|
||||
import torch
|
||||
import pytest
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from pytorch_lightning import Trainer, callbacks, seed_everything
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
from tests.base import BoringModel, RandomDataset, SimpleModule
|
||||
from tests.base.deterministic_model import DeterministicModel
|
||||
|
||||
|
||||
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
|
||||
|
@ -780,3 +783,98 @@ def test_log_works_in_test_callback(tmpdir):
|
|||
assert func_name in trainer.logger_connector.progress_bar_metrics
|
||||
else:
|
||||
assert func_name not in trainer.logger_connector.progress_bar_metrics
|
||||
|
||||
|
||||
@mock.patch("pytorch_lightning.loggers.TensorBoardLogger.log_metrics")
@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
    """
    Check which metrics reach ``TensorBoardLogger.log_metrics`` during validation.

    ``log_metrics`` is replaced by a mock, a model logging four validation metrics
    (one per ``on_step`` / ``on_epoch`` combination) is fitted for two epochs with
    two validation batches each, and every recorded call is compared against the
    metric names and loss values it is expected to carry.
    """

    class ExtendedModel(BoringModel):

        # every validation loss in the order it was produced, so the
        # assertions below can compare logged values against it
        val_losses = []

        def training_step(self, batch, batch_idx):
            loss = self.loss(batch, self.layer(batch))
            self.log('train_loss', loss)
            return {"loss": loss}

        def validation_step(self, batch, batch_idx):
            loss = self.loss(batch, self.layer(batch))
            self.val_losses.append(loss)
            # one metric per (on_step, on_epoch) combination
            logging_modes = (
                ('valid_loss_0', True, True),
                ('valid_loss_1', False, True),
                ('valid_loss_2', True, False),
                ('valid_loss_3', False, False),
            )
            for metric_name, log_on_step, log_on_epoch in logging_modes:
                self.log(metric_name, loss, on_step=log_on_step, on_epoch=log_on_epoch)
            return {"val_loss": loss}

        def test_step(self, batch, batch_idx):
            loss = self.loss(batch, self.layer(batch))
            self.log('fake_test_acc', loss)
            return {"y": loss}

    model = ExtendedModel()
    model.validation_epoch_end = None

    # two train and two validation batches per epoch, no test batches
    trainer = Trainer(
        default_root_dir=tmpdir,
        logger=TensorBoardLogger(tmpdir),
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=0,
        max_epochs=2,
        progress_bar_refresh_rate=1,
    )
    trainer.fit(model)

    # hp_metric + 2 steps + epoch + 2 steps + epoch
    expected_num_calls = 1 + 2 + 1 + 2 + 1
    assert len(mock_log_metrics.mock_calls) == expected_num_calls

    assert mock_log_metrics.mock_calls[0] == call({'hp_metric': -1}, 0)

    def metrics_at(idx):
        # Return the ``metrics`` keyword argument of the idx-th recorded call.
        recorded = list(mock_log_metrics.mock_calls)[idx]
        call_kwargs = recorded.kwargs
        if not isinstance(call_kwargs, dict):
            # ``.kwargs`` is not a plain dict here (presumably an older
            # ``unittest.mock`` -- confirm); use the tuple form of the call
            call_kwargs = recorded[2]
        return call_kwargs["metrics"]

    # NOTE(review): ``val_losses[0:2]` are skipped below; presumably they come
    # from the sanity-check validation run -- confirm.

    # first epoch: one logger call per validation step
    step_keys = ['valid_loss_0_step/epoch_0', 'valid_loss_2/epoch_0', 'global_step']
    for call_idx in (1, 2):
        assert sorted(metrics_at(call_idx)) == sorted(step_keys)

    for call_idx, loss_idx in ((1, 2), (2, 3)):
        step_loss = model.val_losses[loss_idx]
        assert metrics_at(call_idx)["valid_loss_0_step/epoch_0"] == step_loss

    # first epoch: one call carrying the epoch-level metrics
    epoch_keys = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
    assert sorted(metrics_at(3)) == sorted(epoch_keys)

    epoch_mean = torch.stack(model.val_losses[2:4]).mean()
    assert metrics_at(3)["valid_loss_1"] == epoch_mean

    # second epoch: one logger call per validation step
    step_keys = ['valid_loss_0_step/epoch_1', 'valid_loss_2/epoch_1', 'global_step']
    for call_idx in (4, 5):
        assert sorted(metrics_at(call_idx)) == sorted(step_keys)

    for call_idx, loss_idx in ((4, 4), (5, 5)):
        step_loss = model.val_losses[loss_idx]
        assert metrics_at(call_idx)["valid_loss_0_step/epoch_1"] == step_loss

    # second epoch: one call carrying the epoch-level metrics
    epoch_keys = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
    assert sorted(metrics_at(6)) == sorted(epoch_keys)

    epoch_mean = torch.stack(model.val_losses[4:]).mean()
    assert metrics_at(6)["valid_loss_1"] == epoch_mean
|
||||
|
|
Loading…
Reference in New Issue