Trainer.test should return only test metrics (#5214)

* resolve bug

* merge tests
chaton committed on 2020-12-28 15:34:18 +01:00 (committed by GitHub)
parent d1e97a4f11, commit 9ebbfece5e
3 changed files with 28 additions and 40 deletions
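
In user-facing terms, the bug was that metrics logged during validation leaked into the dictionaries returned by Trainer.test. Below is a minimal sketch of the behavior the commit enforces; TinyModel, RandomDataset, and all hyperparameters are illustrative assumptions, not code from the commit:

import torch
from torch.utils.data import DataLoader, Dataset
from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    # hypothetical toy dataset, not part of the commit
    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(32)


class TinyModel(LightningModule):
    # hypothetical module that logs one metric per loop
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log('val_loss', self.layer(batch).sum())

    def test_step(self, batch, batch_idx):
        self.log('test_loss', self.layer(batch).sum())

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


model = TinyModel()
trainer = Trainer(max_epochs=1, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2)
trainer.fit(model, DataLoader(RandomDataset()), DataLoader(RandomDataset()))

# trainer.callback_metrics still aggregates everything seen so far, but the
# return value of Trainer.test is now restricted to test-loop metrics
results = trainer.test(model, DataLoader(RandomDataset()))
assert 'test_loss' in results[0]
assert 'val_loss' not in results[0]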


@@ -399,7 +399,7 @@ class EpochResultStore:
         callback_metrics.update(epoch_log_metrics)
         callback_metrics.update(forked_metrics)

-        if not is_train:
+        if not is_train and self.trainer.testing:
             logger_connector.evaluation_callback_metrics.update(callback_metrics)

         # update callback_metrics
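
The guard above is the core of the fix: EpochResultStore previously copied metrics into evaluation_callback_metrics after every non-training epoch, so validation epochs run inside fit also wrote to the dict that Trainer.test later reads back. A self-contained sketch of the guard with stand-in classes (only trainer.testing and evaluation_callback_metrics come from the diff; everything else is assumed):

class FakeTrainer:
    # stand-in: testing is True only while Trainer.test is running
    testing = False


class FakeLoggerConnector:
    # stand-in holding the dict that Trainer.test reads back
    def __init__(self):
        self.evaluation_callback_metrics = {}


def store_epoch_metrics(trainer, logger_connector, callback_metrics, is_train):
    # the old guard was `if not is_train:`, which also matched validation
    # epochs during fit; the new guard additionally requires a test run
    if not is_train and trainer.testing:
        logger_connector.evaluation_callback_metrics.update(callback_metrics)


trainer, connector = FakeTrainer(), FakeLoggerConnector()

# a validation epoch during fit no longer leaks into evaluation metrics
store_epoch_metrics(trainer, connector, {'val_loss': 0.5}, is_train=False)
assert connector.evaluation_callback_metrics == {}

# an actual test run still records its metrics
trainer.testing = True
store_epoch_metrics(trainer, connector, {'test_loss': 0.3}, is_train=False)
assert connector.evaluation_callback_metrics == {'test_loss': 0.3}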


@@ -11,8 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import os
 from copy import deepcopy
+import os
 from pprint import pprint
 from typing import Iterable, Union
@@ -273,10 +273,13 @@ class LoggerConnector:
             if isinstance(eval_results, list):
                 for eval_result in eval_results:
                     self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
-                    self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
+                    if self.trainer.testing:
+                        self.trainer.logger_connector.evaluation_callback_metrics.update(
+                            eval_result.callback_metrics)
             else:
                 self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
-                self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
+                if self.trainer.testing:
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
         else:
             flat = {}
             if isinstance(eval_results, list):
@@ -292,7 +295,8 @@ class LoggerConnector:
                     flat['checkpoint_on'] = flat['val_loss']
                     flat['early_stop_on'] = flat['val_loss']
                 self.trainer.logger_connector.callback_metrics.update(flat)
-                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
+                if self.trainer.testing:
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
             else:
                 # with a scalar return, auto set it to "val_loss" for callbacks
                 if isinstance(eval_results, torch.Tensor):
@@ -305,7 +309,8 @@
                     flat['checkpoint_on'] = flat['val_loss']
                     flat['early_stop_on'] = flat['val_loss']
                 self.trainer.logger_connector.callback_metrics.update(flat)
-                self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
+                if self.trainer.testing:
+                    self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

     def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
         # eval loop returns all metrics
@@ -322,7 +327,8 @@ class LoggerConnector:
             callback_metrics.update(log_metrics)
             callback_metrics.update(prog_bar_metrics)
             self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
+            if self.trainer.testing:
+                self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)

         if len(dataloader_result_metrics) > 0:
             self.eval_loop_results.append(dataloader_result_metrics)
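
All four hunks above apply the same pattern: each previously unconditional evaluation_callback_metrics.update(...) in LoggerConnector, covering result-object lists, dict returns, bare scalar returns, and the legacy epoch-end path, is now wrapped in an "if self.trainer.testing:" guard. Note that callback_metrics itself is still updated unconditionally, so callbacks such as early stopping and checkpointing keep seeing validation metrics during fit; only the evaluation-specific dict that Trainer.test reads back is narrowed.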


@@ -25,7 +25,7 @@ import pytest
 import torch
 from torch.utils.data import DataLoader, Dataset

-from pytorch_lightning import Trainer, callbacks, seed_everything
+from pytorch_lightning import callbacks, seed_everything, Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import TensorBoardLogger
@@ -813,7 +813,7 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
         def test_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
-            self.log('fake_test_acc', loss)
+            self.log('test_loss', loss)
             return {"y": loss}

     model = ExtendedModel()
@@ -825,7 +825,7 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
         logger=TensorBoardLogger(tmpdir),
         limit_train_batches=2,
         limit_val_batches=2,
-        limit_test_batches=0,
+        limit_test_batches=2,
         max_epochs=2,
         progress_bar_refresh_rate=1,
     )
@@ -877,33 +877,15 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
     expected = torch.stack(model.val_losses[4:]).mean()
     assert get_metrics_at_idx(6)["valid_loss_1"] == expected

-
-def test_progress_bar_dict_contains_values_on_test_epoch_end(tmpdir):
-    class TestModel(BoringModel):
-        def test_step(self, *args):
-            self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
-
-        def test_epoch_end(self, *_):
-            self.epoch_end_called = True
-            self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True,
-                     on_epoch=True, sync_dist=True, sync_dist_op='sum')
-
-        def on_test_epoch_end(self, *_):
-            self.on_test_epoch_end_called = True
-            assert self.trainer.progress_bar_dict["foo"] == self.current_epoch
-            assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch
-
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=2,
-        limit_train_batches=1,
-        num_sanity_val_steps=2,
-        checkpoint_callback=False,
-        logger=False,
-        weights_summary=None,
-        progress_bar_refresh_rate=0,
-    )
-    model = TestModel()
-    trainer.test(model)
-    assert model.epoch_end_called
-    assert model.on_test_epoch_end_called
+    results = trainer.test(model)
+    expected_callback_metrics = {
+        'train_loss',
+        'valid_loss_0_epoch',
+        'valid_loss_0',
+        'debug_epoch',
+        'valid_loss_1',
+        'test_loss',
+        'val_loss'
+    }
+    assert set(trainer.callback_metrics) == expected_callback_metrics
+    assert set(results[0]) == {'test_loss', 'debug_epoch'}
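
The test changes implement the "merge tests" part of the commit message: the standalone test_progress_bar_dict_contains_values_on_test_epoch_end is deleted, and test_validation_step_log_with_tensorboard picks up the Trainer.test coverage instead, now actually running the test loop (limit_test_batches was previously 0, so the old fake_test_acc log never executed). The merged assertions pin down both surfaces of the fix: trainer.callback_metrics still aggregates the train, validation, and test keys for callbacks, while the list returned by trainer.test is narrowed to {'test_loss', 'debug_epoch'}, the test-loop metric plus the internal debug_epoch key.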