Trainer.test should return only test metrics (#5214)
* resolve bug
* merge tests
This commit is contained in:
parent
d1e97a4f11
commit
9ebbfece5e
|
@ -399,7 +399,7 @@ class EpochResultStore:
|
|||
callback_metrics.update(epoch_log_metrics)
|
||||
callback_metrics.update(forked_metrics)
|
||||
|
||||
if not is_train:
|
||||
if not is_train and self.trainer.testing:
|
||||
logger_connector.evaluation_callback_metrics.update(callback_metrics)
|
||||
|
||||
# update callback_metrics
|
||||
|
|
|
@ -11,8 +11,8 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from copy import deepcopy
|
||||
import os
|
||||
from pprint import pprint
|
||||
from typing import Iterable, Union
|
||||
|
||||
|
@ -273,10 +273,13 @@ class LoggerConnector:
|
|||
if isinstance(eval_results, list):
|
||||
for eval_result in eval_results:
|
||||
self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics)
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics)
|
||||
if self.trainer.testing:
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(
|
||||
eval_result.callback_metrics)
|
||||
else:
|
||||
self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics)
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
|
||||
if self.trainer.testing:
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics)
|
||||
else:
|
||||
flat = {}
|
||||
if isinstance(eval_results, list):
|
||||
|
@ -292,7 +295,8 @@ class LoggerConnector:
|
|||
flat['checkpoint_on'] = flat['val_loss']
|
||||
flat['early_stop_on'] = flat['val_loss']
|
||||
self.trainer.logger_connector.callback_metrics.update(flat)
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
|
||||
if self.trainer.testing:
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
|
||||
else:
|
||||
# with a scalar return, auto set it to "val_loss" for callbacks
|
||||
if isinstance(eval_results, torch.Tensor):
|
||||
|
@ -305,7 +309,8 @@ class LoggerConnector:
|
|||
flat['checkpoint_on'] = flat['val_loss']
|
||||
flat['early_stop_on'] = flat['val_loss']
|
||||
self.trainer.logger_connector.callback_metrics.update(flat)
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
|
||||
if self.trainer.testing:
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
|
||||
|
||||
def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
|
||||
# eval loop returns all metrics
|
||||
|
@ -322,7 +327,8 @@ class LoggerConnector:
|
|||
callback_metrics.update(log_metrics)
|
||||
callback_metrics.update(prog_bar_metrics)
|
||||
self.trainer.logger_connector.callback_metrics.update(callback_metrics)
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
|
||||
if self.trainer.testing:
|
||||
self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
|
||||
|
||||
if len(dataloader_result_metrics) > 0:
|
||||
self.eval_loop_results.append(dataloader_result_metrics)
|
||||
|
|
|
@ -25,7 +25,7 @@ import pytest
|
|||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from pytorch_lightning import Trainer, callbacks, seed_everything
|
||||
from pytorch_lightning import callbacks, seed_everything, Trainer
|
||||
from pytorch_lightning.callbacks import ModelCheckpoint
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.loggers import TensorBoardLogger
|
||||
|
@ -813,7 +813,7 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
|
|||
def test_step(self, batch, batch_idx):
|
||||
output = self.layer(batch)
|
||||
loss = self.loss(batch, output)
|
||||
self.log('fake_test_acc', loss)
|
||||
self.log('test_loss', loss)
|
||||
return {"y": loss}
|
||||
|
||||
model = ExtendedModel()
|
||||
|
@ -825,7 +825,7 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
|
|||
logger=TensorBoardLogger(tmpdir),
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
limit_test_batches=0,
|
||||
limit_test_batches=2,
|
||||
max_epochs=2,
|
||||
progress_bar_refresh_rate=1,
|
||||
)
|
||||
|
@ -877,33 +877,15 @@ def test_validation_step_log_with_tensorboard(mock_log_metrics, tmpdir):
|
|||
expected = torch.stack(model.val_losses[4:]).mean()
|
||||
assert get_metrics_at_idx(6)["valid_loss_1"] == expected
|
||||
|
||||
|
||||
def test_progress_bar_dict_contains_values_on_test_epoch_end(tmpdir):
|
||||
class TestModel(BoringModel):
|
||||
def test_step(self, *args):
|
||||
self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True)
|
||||
|
||||
def test_epoch_end(self, *_):
|
||||
self.epoch_end_called = True
|
||||
self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True,
|
||||
on_epoch=True, sync_dist=True, sync_dist_op='sum')
|
||||
|
||||
def on_test_epoch_end(self, *_):
|
||||
self.on_test_epoch_end_called = True
|
||||
assert self.trainer.progress_bar_dict["foo"] == self.current_epoch
|
||||
assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
max_epochs=2,
|
||||
limit_train_batches=1,
|
||||
num_sanity_val_steps=2,
|
||||
checkpoint_callback=False,
|
||||
logger=False,
|
||||
weights_summary=None,
|
||||
progress_bar_refresh_rate=0,
|
||||
)
|
||||
model = TestModel()
|
||||
trainer.test(model)
|
||||
assert model.epoch_end_called
|
||||
assert model.on_test_epoch_end_called
|
||||
results = trainer.test(model)
|
||||
expected_callback_metrics = {
|
||||
'train_loss',
|
||||
'valid_loss_0_epoch',
|
||||
'valid_loss_0',
|
||||
'debug_epoch',
|
||||
'valid_loss_1',
|
||||
'test_loss',
|
||||
'val_loss'
|
||||
}
|
||||
assert set(trainer.callback_metrics) == expected_callback_metrics
|
||||
assert set(results[0]) == {'test_loss', 'debug_epoch'}
|
||||
|
|
Loading…
Reference in New Issue