ref: (results 1/n) enable tracking original metric when step and epoch are both true (#3685)
* enable tracking original metric when step and epoch are both true
parent 931995b55b
commit ff2bab0996
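In practice, a metric logged with both on_step=True and on_epoch=True is now kept under its original name in addition to the step_/epoch_ prefixed copies. A minimal sketch of the effect, assuming the 0.10-era Result API touched below (the metric name 'acc' is illustrative, not taken from the diff):

import torch
from pytorch_lightning.core.step_result import TrainResult

result = TrainResult()
result.log('acc', torch.tensor(0.9), on_step=True, on_epoch=True)

assert 'step_acc' in result    # per-step copy (pre-existing behaviour)
assert 'epoch_acc' in result   # epoch-aggregated copy (pre-existing behaviour)
assert 'acc' in result         # original metric, only tracked after this change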
@@ -152,6 +152,8 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
         self._has_setup_fit = False
         self._has_setup_test = False

+        self.trainer = None
+
     @property
     def train_transforms(self):
         """
@@ -161,21 +161,22 @@ class Result(Dict):
                 tbptt_pad_token=tbptt_pad_token,
             )
             self.__setitem__(epoch_name, value)
-        else:
-            self.__set_meta(
-                name,
-                value,
-                prog_bar,
-                logger,
-                on_step,
-                on_epoch,
-                reduce_fx,
-                tbptt_reduce_fx=tbptt_reduce_fx,
-                tbptt_pad_token=tbptt_pad_token,
-            )
-
-            # set the value
-            self.__setitem__(name, value)
+
+        # always log the original metric
+        self.__set_meta(
+            name,
+            value,
+            prog_bar,
+            logger,
+            on_step,
+            on_epoch,
+            reduce_fx,
+            tbptt_reduce_fx=tbptt_reduce_fx,
+            tbptt_pad_token=tbptt_pad_token,
+        )
+
+        # set the value
+        self.__setitem__(name, value)

     def __set_meta(
         self,
@@ -378,12 +379,17 @@ class Result(Dict):
     def reduce_across_time(cls, time_outputs):
         # auto-reduce across time for tbptt
         meta = time_outputs[0]['meta']
+
+        # in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this
+        if 'extra' in time_outputs[0]:
+            [x.pop('extra', None) for x in time_outputs]
+
         result = cls()
         result = recursive_gather(time_outputs, result)
         recursive_stack(result)

         for k, value in result.items():
-            if k == 'meta':
+            if k in ['meta', 'extra']:
                 continue

             # pick the reduce fx
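The reduction above stacks every non-bookkeeping key across the truncated-BPTT time steps; 'meta' (and now 'extra') are skipped because they carry logging metadata rather than metric values. A standalone sketch of that pattern using plain dicts and a fixed mean reduction, not the library code itself:

import torch

def reduce_time_outputs(time_outputs, skip=('meta', 'extra')):
    # stack each metric across time steps and mean-reduce it, skipping bookkeeping keys
    reduced = {}
    for key in time_outputs[0]:
        if key in skip:
            continue
        stacked = torch.stack([step[key] for step in time_outputs])  # shape: (time, ...)
        reduced[key] = stacked.mean(dim=0)
    return reduced

# two tbptt time steps that produced the same metric
outs = [{'meta': {}, 'loss': torch.tensor(1.0)}, {'meta': {}, 'loss': torch.tensor(3.0)}]
print(reduce_time_outputs(outs))  # {'loss': tensor(2.)}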
@@ -1,8 +1,9 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.core.datamodule import LightningDataModule
+from typing import Union


-def is_overridden(method_name: str, model: LightningModule) -> bool:
+def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
     # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
     # TODO - refactor this function to accept model_name, instance, parent so it makes more sense
     super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
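Widening the signature lets the same check run against either a LightningModule or a LightningDataModule by choosing the matching parent class. A self-contained sketch of that parent-selection logic with dummy classes, mirroring (not reproducing) the function above:

class BaseModule:
    def setup(self):
        pass

class BaseDataModule:
    def setup(self):
        pass

class MyDataModule(BaseDataModule):
    def setup(self):  # overridden by the user
        return 'ready'

def is_overridden(method_name, instance):
    # pick the parent to compare against, as in the diff above
    parent = BaseDataModule if isinstance(instance, BaseDataModule) else BaseModule
    if not hasattr(instance, method_name):
        return False
    return getattr(instance, method_name).__code__ is not getattr(parent, method_name).__code__

print(is_overridden('setup', MyDataModule()))    # True
print(is_overridden('setup', BaseDataModule()))  # False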
@@ -188,6 +188,12 @@ class DeterministicModel(LightningModule):
         # only saw 4 batches
         assert isinstance(result, TrainResult)

+        result.step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()
+        result.step_epoch_pbar_acc3 = result.step_step_epoch_pbar_acc3.prod()
+        result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
+        result.minimize = result.minimize.mean()
+        result.checkpoint_on = result.checkpoint_on.mean()
+
         result.step_step_epoch_log_and_pbar_acc1 = result.step_step_epoch_log_and_pbar_acc1.prod()
         result.epoch_step_epoch_log_and_pbar_acc1 = result.epoch_step_epoch_log_and_pbar_acc1.prod()
         result.step_step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()
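Because the original metric now survives into training_epoch_end, the deterministic test model reduces three families of keys: the per-step copies (step_step_epoch_*), the epoch copies (epoch_step_epoch_*), and the newly tracked un-prefixed originals (step_epoch_log_acc2, step_epoch_pbar_acc3, step_epoch_log_and_pbar_acc1).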
@@ -137,6 +137,7 @@ class TrainingStepVariations(ABC):
         result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True)
+
         setattr(self, f'{eval_name}_step_called', True)

         return result

     def eval_step_end_full_loop_result_obj_dp(self, result):
@@ -150,10 +151,14 @@ class TrainingStepVariations(ABC):
         reduced = getattr(result, f'epoch_{eval_name}_step_metric').mean()
         setattr(result, f'epoch_{eval_name}_step_metric', reduced)
+
+        reduced = getattr(result, f'{eval_name}_step_metric').mean()
+        setattr(result, f'{eval_name}_step_metric', reduced)
+
         result.checkpoint_on = result.checkpoint_on.mean()
         result.early_stop_on = result.early_stop_on.mean()
         result.log(f'{eval_name}_step_end_metric', torch.tensor(1).type_as(result.checkpoint_on))
         setattr(self, f'{eval_name}_step_end_called', True)

         return result

     def eval_epoch_end_full_loop_result_obj_dp(self, result):
@@ -176,6 +181,9 @@ class TrainingStepVariations(ABC):
         reduced = getattr(result, f'{eval_name}_step_end_metric').mean()
         setattr(result, f'{eval_name}_step_end_metric', reduced)

+        reduced = getattr(result, f'{eval_name}_step_metric').mean()
+        setattr(result, f'{eval_name}_step_metric', reduced)
+
         return result

     def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):
@@ -201,7 +201,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert not model.training_step_end_called
     assert not model.training_epoch_end_called

-    assert len(trainer.logger_connector.callback_metrics) == 8
+    assert len(trainer.logger_connector.callback_metrics) == 11

     # make sure correct metrics are logged (one per batch step as requested)
     assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
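A plausible reading of the new counts: the test logs three metrics with both on_step and on_epoch enabled, so callback_metrics gains one un-prefixed entry per metric (8 + 3 = 11), while the per-batch logged_metrics below grow by only two (4 -> 6 and 3 -> 5) because the progress-bar-only metric is never sent to the logger.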
@@ -227,7 +227,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
         assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
         assert logged_metrics['step_step_epoch_log_acc2'] == expected_val_2
         assert 'step_epoch_pbar_acc3' not in logged_metrics
-        assert len(logged_metrics) == 4
+        assert len(logged_metrics) == 6

     # make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
     epoch_end_metrics = epoch_outputs[-1]
@@ -236,7 +236,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
     assert epoch_end_metrics['epoch_step_epoch_log_acc2'] == eval_2
     assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
-    assert len(logged_metrics) == 4
+    assert len(logged_metrics) == 6

     # make sure we are using the correct metrics for callbacks
     assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171
@@ -268,7 +268,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
         assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
         assert logged_metrics['step_step_epoch_pbar_acc3'] == expected_val_2
         assert 'step_epoch_log_acc2' not in logged_metrics
-        assert len(logged_metrics) == 3
+        assert len(logged_metrics) == 5

     # make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
     epoch_end_metrics = epoch_outputs[-1]
@@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
     assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
     assert epoch_end_metrics['epoch_step_epoch_pbar_acc3'] == eval_2
     assert 'step_epoch_log_acc2' not in epoch_end_metrics
-    assert len(logged_metrics) == 3
+    assert len(logged_metrics) == 5

     # -----------------------------------------
     # make sure training outputs what is expected
@@ -287,7 +287,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):

     out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
-    assert len(out.batch_log_metrics) == 2
+    assert len(out.batch_log_metrics) == 4

     train_step_out = out.training_step_output_for_epoch_end
     assert len(train_step_out) == 1
@@ -328,7 +328,7 @@ def test_training_step_epoch_end_result(tmpdir):
     )
     trainer.fit(model)

-    assert len(trainer.logger_connector.callback_metrics) == 11
+    assert len(trainer.logger_connector.callback_metrics) == 17

     # make sure correct steps were called
     assert model.training_step_called
@@ -369,7 +369,7 @@ def test_training_step_epoch_end_result(tmpdir):

     out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
     assert out.signal == 0
-    assert len(out.batch_log_metrics) == 2
+    assert len(out.batch_log_metrics) == 4

     train_step_out = out.training_step_output_for_epoch_end
     assert len(train_step_out) == 1
@@ -278,7 +278,7 @@ def test_val_step_epoch_step_metrics(tmpdir):
     )
     trainer.fit(model)

-    assert len(trainer.logger_connector.callback_metrics) == 7
+    assert len(trainer.logger_connector.callback_metrics) == 11

     # make sure correct steps were called
     assert model.validation_step_called