tests for val step flow and logging (#3731)
* ref: test val epoch end * ref: test val epoch end * ref: test val epoch end * ref: test log dict * ref: test log dict * ref: test log dict * ref: test log dict
This commit is contained in:
parent
3dcf7130c5
commit
b3be8022bd
|
@ -32,6 +32,7 @@ from pytorch_lightning.core.step_result import EvalResult, TrainResult
|
|||
from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
|
||||
from pytorch_lightning.utilities import rank_zero_warn
|
||||
from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.core.step_result import Result
|
||||
from pytorch_lightning.utilities.parsing import (
|
||||
AttributeDict,
|
||||
|
@ -216,7 +217,15 @@ class LightningModule(
|
|||
if self._results is not None:
|
||||
# in any epoch end can't log step metrics (only epoch metric)
|
||||
if 'epoch_end' in self._current_fx_name and on_step:
|
||||
on_step = False
|
||||
m = f'on_step=True cannot be used on {self._current_fx_name} method'
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
if 'epoch_end' in self._current_fx_name and on_epoch == False:
|
||||
m = f'on_epoch cannot be False when called from the {self._current_fx_name} method'
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
# add log_dict
|
||||
# TODO: if logged twice fail with crash
|
||||
|
||||
# set the default depending on the fx_name
|
||||
on_step = self.__auto_choose_log_on_step(on_step)
|
||||
|
@ -238,6 +247,60 @@ class LightningModule(
|
|||
sync_dist_group
|
||||
)
|
||||
|
||||
def log_dict(
|
||||
self,
|
||||
dictionary: dict,
|
||||
prog_bar: bool = False,
|
||||
logger: bool = True,
|
||||
on_step: Union[None, bool] = None,
|
||||
on_epoch: Union[None, bool] = None,
|
||||
reduce_fx: Callable = torch.mean,
|
||||
tbptt_reduce_fx: Callable = torch.mean,
|
||||
tbptt_pad_token: int = 0,
|
||||
enable_graph: bool = False,
|
||||
sync_dist: bool = False,
|
||||
sync_dist_op: Union[Any, str] = 'mean',
|
||||
sync_dist_group: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
Log a dictonary of values at once
|
||||
|
||||
Example::
|
||||
|
||||
values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n}
|
||||
self.log_dict(values)
|
||||
|
||||
Args:
|
||||
dictionary: key value pairs (str, tensors)
|
||||
prog_bar: if True logs to the progress base
|
||||
logger: if True logs to the logger
|
||||
on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step
|
||||
on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
|
||||
reduce_fx: Torch.mean by default
|
||||
tbptt_reduce_fx: function to reduce on truncated back prop
|
||||
tbptt_pad_token: token to use for padding
|
||||
enable_graph: if True, will not auto detach the graph
|
||||
sync_dist: if True, reduces the metric across GPUs/TPUs
|
||||
sync_dist_op: the op to sync across
|
||||
sync_dist_group: the ddp group:
|
||||
"""
|
||||
for k, v in dictionary.items():
|
||||
self.log(
|
||||
name=k,
|
||||
value=v,
|
||||
prog_bar=prog_bar,
|
||||
logger=logger,
|
||||
on_step=on_step,
|
||||
on_epoch=on_epoch,
|
||||
reduce_fx=reduce_fx,
|
||||
enable_graph=enable_graph,
|
||||
sync_dist=sync_dist,
|
||||
sync_dist_group=sync_dist_group,
|
||||
sync_dist_op=sync_dist_op,
|
||||
tbptt_pad_token=tbptt_pad_token,
|
||||
tbptt_reduce_fx=tbptt_reduce_fx,
|
||||
)
|
||||
|
||||
def __auto_choose_log_on_step(self, on_step):
|
||||
if on_step is None:
|
||||
if self._current_fx_name in {'training_step', 'training_step_end'}:
|
||||
|
|
|
@ -131,7 +131,10 @@ class Result(Dict):
|
|||
|
||||
# if user requests both step and epoch, then we split the metric in two automatically
|
||||
# one will be logged per step. the other per epoch
|
||||
was_forked = False
|
||||
if on_step and on_epoch:
|
||||
was_forked = True
|
||||
|
||||
# set step version
|
||||
step_name = f'step_{name}'
|
||||
self.__set_meta(
|
||||
|
@ -144,6 +147,7 @@ class Result(Dict):
|
|||
reduce_fx=reduce_fx,
|
||||
tbptt_reduce_fx=tbptt_reduce_fx,
|
||||
tbptt_pad_token=tbptt_pad_token,
|
||||
forked=False
|
||||
)
|
||||
self.__setitem__(step_name, value)
|
||||
|
||||
|
@ -159,6 +163,7 @@ class Result(Dict):
|
|||
reduce_fx=reduce_fx,
|
||||
tbptt_reduce_fx=tbptt_reduce_fx,
|
||||
tbptt_pad_token=tbptt_pad_token,
|
||||
forked=False
|
||||
)
|
||||
self.__setitem__(epoch_name, value)
|
||||
|
||||
|
@ -173,6 +178,7 @@ class Result(Dict):
|
|||
reduce_fx,
|
||||
tbptt_reduce_fx=tbptt_reduce_fx,
|
||||
tbptt_pad_token=tbptt_pad_token,
|
||||
forked=was_forked
|
||||
)
|
||||
|
||||
# set the value
|
||||
|
@ -189,6 +195,7 @@ class Result(Dict):
|
|||
reduce_fx: Callable,
|
||||
tbptt_pad_token: int,
|
||||
tbptt_reduce_fx: Callable,
|
||||
forked: bool
|
||||
):
|
||||
# set the meta for the item
|
||||
meta_value = value
|
||||
|
@ -201,6 +208,7 @@ class Result(Dict):
|
|||
value=meta_value,
|
||||
tbptt_reduce_fx=tbptt_reduce_fx,
|
||||
tbptt_pad_token=tbptt_pad_token,
|
||||
forked=forked
|
||||
)
|
||||
|
||||
self['meta'][name] = meta
|
||||
|
@ -222,9 +230,10 @@ class Result(Dict):
|
|||
|
||||
return result
|
||||
|
||||
def get_batch_log_metrics(self) -> dict:
|
||||
def get_batch_log_metrics(self, include_forked_originals=True) -> dict:
|
||||
"""
|
||||
Gets the metrics to log at the end of the batch step
|
||||
|
||||
"""
|
||||
result = {}
|
||||
|
||||
|
@ -232,6 +241,10 @@ class Result(Dict):
|
|||
for k, options in meta.items():
|
||||
if k == '_internal':
|
||||
continue
|
||||
|
||||
if options['forked'] and not include_forked_originals:
|
||||
continue
|
||||
|
||||
if options['logger'] and options['on_step']:
|
||||
result[k] = self[k]
|
||||
return result
|
||||
|
@ -264,7 +277,7 @@ class Result(Dict):
|
|||
result[k] = self[k]
|
||||
return result
|
||||
|
||||
def get_batch_pbar_metrics(self):
|
||||
def get_batch_pbar_metrics(self, include_forked_originals=True):
|
||||
"""
|
||||
Gets the metrics to log at the end of the batch step
|
||||
"""
|
||||
|
@ -274,6 +287,9 @@ class Result(Dict):
|
|||
for k, options in meta.items():
|
||||
if k == '_internal':
|
||||
continue
|
||||
if options['forked'] and not include_forked_originals:
|
||||
continue
|
||||
|
||||
if options['prog_bar'] and options['on_step']:
|
||||
result[k] = self[k]
|
||||
return result
|
||||
|
|
|
@ -202,6 +202,7 @@ class EvaluationLoop(object):
|
|||
|
||||
if self.testing:
|
||||
if is_overridden('test_epoch_end', model=model):
|
||||
model._current_fx_name = 'test_epoch_end'
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
|
@ -210,6 +211,7 @@ class EvaluationLoop(object):
|
|||
|
||||
else:
|
||||
if is_overridden('validation_epoch_end', model=model):
|
||||
model._current_fx_name = 'validation_epoch_end'
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
|
@ -314,8 +316,8 @@ class EvaluationLoop(object):
|
|||
self.__log_result_step_metrics(output, batch_idx)
|
||||
|
||||
def __log_result_step_metrics(self, output, batch_idx):
|
||||
step_log_metrics = output.batch_log_metrics
|
||||
step_pbar_metrics = output.batch_pbar_metrics
|
||||
step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
|
||||
step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)
|
||||
|
||||
if len(step_log_metrics) > 0:
|
||||
# make the metrics appear as a different line in the same graph
|
||||
|
|
|
@ -52,8 +52,6 @@ def test__validation_step__log(tmpdir):
|
|||
'b',
|
||||
'step_b/epoch_0',
|
||||
'step_b/epoch_1',
|
||||
'b/epoch_0',
|
||||
'b/epoch_1',
|
||||
'epoch_b',
|
||||
'epoch',
|
||||
}
|
||||
|
@ -67,7 +65,7 @@ def test__validation_step__log(tmpdir):
|
|||
assert expected_cb_metrics == callback_metrics
|
||||
|
||||
|
||||
def test__validation_step__epoch_end__log(tmpdir):
|
||||
def test__validation_step__step_end__epoch_end__log(tmpdir):
|
||||
"""
|
||||
Tests that validation_step can log
|
||||
"""
|
||||
|
@ -88,16 +86,22 @@ def test__validation_step__epoch_end__log(tmpdir):
|
|||
self.log('c', acc)
|
||||
self.log('d', acc, on_step=True, on_epoch=True)
|
||||
self.validation_step_called = True
|
||||
return acc
|
||||
|
||||
def validation_step_end(self, acc):
|
||||
self.validation_step_end_called = True
|
||||
self.log('e', acc)
|
||||
self.log('f', acc, on_step=True, on_epoch=True)
|
||||
return ['random_thing']
|
||||
|
||||
def validation_epoch_end(self, outputs):
|
||||
self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True)
|
||||
self.log('g', torch.tensor(2, device=self.device), on_epoch=True)
|
||||
self.validation_epoch_end_called = True
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
loss.backward()
|
||||
|
||||
model = TestModel()
|
||||
model.validation_step_end = None
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
|
@ -110,38 +114,32 @@ def test__validation_step__epoch_end__log(tmpdir):
|
|||
trainer.fit(model)
|
||||
|
||||
# make sure all the metrics are available for callbacks
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
expected_logged_metrics = {
|
||||
'epoch',
|
||||
'a',
|
||||
'b',
|
||||
'step_b',
|
||||
'epoch_b',
|
||||
'c',
|
||||
'd',
|
||||
'd/epoch_0',
|
||||
'd/epoch_1',
|
||||
'step_d/epoch_0',
|
||||
'step_d/epoch_1',
|
||||
'epoch_d',
|
||||
'e',
|
||||
'epoch_e',
|
||||
'epoch',
|
||||
'f',
|
||||
'step_f/epoch_0',
|
||||
'step_f/epoch_1',
|
||||
'epoch_f',
|
||||
'g',
|
||||
}
|
||||
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
assert expected_logged_metrics == logged_metrics
|
||||
|
||||
# we don't want to enable val metrics during steps because it is not something that users should do
|
||||
expected_cb_metrics = {
|
||||
'a',
|
||||
'b',
|
||||
'step_b',
|
||||
'epoch_b',
|
||||
'c',
|
||||
'd',
|
||||
'epoch_d',
|
||||
'e',
|
||||
'epoch_e',
|
||||
}
|
||||
progress_bar_metrics = set(trainer.progress_bar_metrics.keys())
|
||||
expected_pbar_metrics = set()
|
||||
assert expected_pbar_metrics == progress_bar_metrics
|
||||
|
||||
# we don't want to enable val metrics during steps because it is not something that users should do
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'epoch_b', 'epoch_d', 'epoch_f', 'f', 'g', 'step_b'}
|
||||
assert expected_cb_metrics == callback_metrics
|
||||
|
|
|
@ -17,12 +17,35 @@ def test__training_step__log(tmpdir):
|
|||
def training_step(self, batch, batch_idx):
|
||||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('step_acc', acc, on_step=True, on_epoch=False)
|
||||
self.log('epoch_acc', acc, on_step=False, on_epoch=True)
|
||||
self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True)
|
||||
self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False)
|
||||
self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
|
||||
self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
|
||||
|
||||
# -----------
|
||||
# default
|
||||
# -----------
|
||||
self.log('default', acc)
|
||||
|
||||
# -----------
|
||||
# logger
|
||||
# -----------
|
||||
# on_step T on_epoch F
|
||||
self.log('l_s', acc, on_step=True, on_epoch=False, prog_bar=False, logger=True)
|
||||
|
||||
# on_step F on_epoch T
|
||||
self.log('l_e', acc, on_step=False, on_epoch=True, prog_bar=False, logger=True)
|
||||
|
||||
# on_step T on_epoch T
|
||||
self.log('l_se', acc, on_step=True, on_epoch=True, prog_bar=False, logger=True)
|
||||
|
||||
# -----------
|
||||
# pbar
|
||||
# -----------
|
||||
# on_step T on_epoch F
|
||||
self.log('p_s', acc, on_step=True, on_epoch=False, prog_bar=True, logger=False)
|
||||
|
||||
# on_step F on_epoch T
|
||||
self.log('p_e', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
|
||||
|
||||
# on_step T on_epoch T
|
||||
self.log('p_se', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
|
||||
|
||||
self.training_step_called = True
|
||||
return acc
|
||||
|
@ -46,19 +69,38 @@ def test__training_step__log(tmpdir):
|
|||
# make sure correct steps were called
|
||||
assert model.training_step_called
|
||||
assert not model.training_step_end_called
|
||||
assert not model.training_epoch_end_called
|
||||
|
||||
# make sure all the metrics are available for callbacks
|
||||
metrics = [
|
||||
'step_acc',
|
||||
'epoch_acc',
|
||||
'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc',
|
||||
'pbar_step_acc',
|
||||
'pbar_epoch_acc',
|
||||
'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc',
|
||||
]
|
||||
expected_metrics = set(metrics + ['debug_epoch'])
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
expected_logged_metrics = {
|
||||
'epoch',
|
||||
'default',
|
||||
'l_e',
|
||||
'l_s',
|
||||
'l_se',
|
||||
'step_l_se',
|
||||
'epoch_l_se',
|
||||
}
|
||||
assert logged_metrics == expected_logged_metrics
|
||||
|
||||
pbar_metrics = set(trainer.progress_bar_metrics.keys())
|
||||
expected_pbar_metrics = {
|
||||
'p_e',
|
||||
'p_s',
|
||||
'p_se',
|
||||
'step_p_se',
|
||||
'epoch_p_se',
|
||||
}
|
||||
assert pbar_metrics == expected_pbar_metrics
|
||||
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
assert expected_metrics == callback_metrics
|
||||
callback_metrics.remove('debug_epoch')
|
||||
expected_callback_metrics = set()
|
||||
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
|
||||
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
|
||||
expected_callback_metrics.remove('epoch')
|
||||
assert callback_metrics == expected_callback_metrics
|
||||
|
||||
|
||||
def test__training_step__epoch_end__log(tmpdir):
|
||||
|
@ -69,22 +111,17 @@ def test__training_step__epoch_end__log(tmpdir):
|
|||
|
||||
class TestModel(DeterministicModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.training_step_called = True
|
||||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('step_acc', acc, on_step=True, on_epoch=False)
|
||||
self.log('epoch_acc', acc, on_step=False, on_epoch=True)
|
||||
self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True)
|
||||
self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False)
|
||||
self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False)
|
||||
self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False)
|
||||
|
||||
self.training_step_called = True
|
||||
self.log('a', acc, on_step=True, on_epoch=True)
|
||||
self.log_dict({'a1': acc, 'a2': acc})
|
||||
return acc
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
self.training_epoch_end_called = True
|
||||
# logging is independent of epoch_end loops
|
||||
self.log('custom_epoch_end_metric', torch.tensor(37.2))
|
||||
self.log('b1', outputs[0]['loss'])
|
||||
self.log('b', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True)
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
loss.backward()
|
||||
|
@ -108,27 +145,100 @@ def test__training_step__epoch_end__log(tmpdir):
|
|||
assert model.training_epoch_end_called
|
||||
|
||||
# make sure all the metrics are available for callbacks
|
||||
metrics = [
|
||||
'step_acc',
|
||||
'epoch_acc',
|
||||
'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc',
|
||||
'pbar_step_acc',
|
||||
'pbar_epoch_acc',
|
||||
'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc',
|
||||
'custom_epoch_end_metric'
|
||||
]
|
||||
expected_metrics = set(metrics + ['debug_epoch'])
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
expected_logged_metrics = {
|
||||
'epoch',
|
||||
'a',
|
||||
'step_a',
|
||||
'epoch_a',
|
||||
'b',
|
||||
'b1',
|
||||
'a1',
|
||||
'a2'
|
||||
}
|
||||
assert logged_metrics == expected_logged_metrics
|
||||
|
||||
pbar_metrics = set(trainer.progress_bar_metrics.keys())
|
||||
expected_pbar_metrics = {
|
||||
'b',
|
||||
}
|
||||
assert pbar_metrics == expected_pbar_metrics
|
||||
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
assert expected_metrics == callback_metrics
|
||||
callback_metrics.remove('debug_epoch')
|
||||
expected_callback_metrics = set()
|
||||
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
|
||||
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
|
||||
expected_callback_metrics.remove('epoch')
|
||||
assert callback_metrics == expected_callback_metrics
|
||||
|
||||
# verify global steps were correctly called
|
||||
|
||||
# epoch 0
|
||||
assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0
|
||||
assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1
|
||||
assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1
|
||||
def test__training_step__step_end__epoch_end__log(tmpdir):
|
||||
"""
|
||||
Tests that only training_step can be used
|
||||
"""
|
||||
os.environ['PL_DEV_DEBUG'] = '1'
|
||||
|
||||
# epoch 1
|
||||
assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2
|
||||
assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3
|
||||
assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3
|
||||
class TestModel(DeterministicModel):
|
||||
def training_step(self, batch, batch_idx):
|
||||
self.training_step_called = True
|
||||
acc = self.step(batch, batch_idx)
|
||||
acc = acc + batch_idx
|
||||
self.log('a', acc, on_step=True, on_epoch=True)
|
||||
return acc
|
||||
|
||||
def training_step_end(self, out):
|
||||
self.training_step_end_called = True
|
||||
self.log('b', out, on_step=True, on_epoch=True, prog_bar=True, logger=True)
|
||||
return out
|
||||
|
||||
def training_epoch_end(self, outputs):
|
||||
self.training_epoch_end_called = True
|
||||
self.log('c', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True)
|
||||
|
||||
def backward(self, trainer, loss, optimizer, optimizer_idx):
|
||||
loss.backward()
|
||||
|
||||
model = TestModel()
|
||||
model.val_dataloader = None
|
||||
|
||||
trainer = Trainer(
|
||||
default_root_dir=tmpdir,
|
||||
limit_train_batches=2,
|
||||
limit_val_batches=2,
|
||||
max_epochs=2,
|
||||
row_log_interval=1,
|
||||
weights_summary=None,
|
||||
)
|
||||
trainer.fit(model)
|
||||
|
||||
# make sure correct steps were called
|
||||
assert model.training_step_called
|
||||
assert model.training_step_end_called
|
||||
assert model.training_epoch_end_called
|
||||
|
||||
# make sure all the metrics are available for callbacks
|
||||
logged_metrics = set(trainer.logged_metrics.keys())
|
||||
expected_logged_metrics = {
|
||||
'a',
|
||||
'step_a',
|
||||
'epoch_a',
|
||||
'b',
|
||||
'step_b',
|
||||
'epoch_b',
|
||||
'c',
|
||||
'epoch',
|
||||
}
|
||||
assert logged_metrics == expected_logged_metrics
|
||||
|
||||
pbar_metrics = set(trainer.progress_bar_metrics.keys())
|
||||
expected_pbar_metrics = {'b', 'c', 'epoch_b', 'step_b'}
|
||||
assert pbar_metrics == expected_pbar_metrics
|
||||
|
||||
callback_metrics = set(trainer.callback_metrics.keys())
|
||||
callback_metrics.remove('debug_epoch')
|
||||
expected_callback_metrics = set()
|
||||
expected_callback_metrics = expected_callback_metrics.union(logged_metrics)
|
||||
expected_callback_metrics = expected_callback_metrics.union(pbar_metrics)
|
||||
expected_callback_metrics.remove('epoch')
|
||||
assert callback_metrics == expected_callback_metrics
|
||||
|
|
Loading…
Reference in New Issue