ref: (results 1/n) enable tracking original metric when step and epoch are both true (#3685)

* enable tracking original metric when step and epoch are both true
This commit is contained in:
William Falcon 2020-09-27 22:08:31 -04:00 committed by GitHub
parent 931995b55b
commit ff2bab0996
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 48 additions and 25 deletions

View File

@ -152,6 +152,8 @@ class LightningDataModule(DataHooks, metaclass=_DataModuleWrapper):
self._has_setup_fit = False
self._has_setup_test = False
self.trainer = None
@property
def train_transforms(self):
"""

View File

@ -161,21 +161,22 @@ class Result(Dict):
tbptt_pad_token=tbptt_pad_token,
)
self.__setitem__(epoch_name, value)
else:
self.__set_meta(
name,
value,
prog_bar,
logger,
on_step,
on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
)
# set the value
self.__setitem__(name, value)
# always log the original metric
self.__set_meta(
name,
value,
prog_bar,
logger,
on_step,
on_epoch,
reduce_fx,
tbptt_reduce_fx=tbptt_reduce_fx,
tbptt_pad_token=tbptt_pad_token,
)
# set the value
self.__setitem__(name, value)
def __set_meta(
self,
@ -378,12 +379,17 @@ class Result(Dict):
def reduce_across_time(cls, time_outputs):
# auto-reduce across time for tbptt
meta = time_outputs[0]['meta']
# in 1.0 the results have 'extra'. Once we deprecate 0.10.0 we may not need this
if 'extra' in time_outputs[0]:
[x.pop('extra', None) for x in time_outputs]
result = cls()
result = recursive_gather(time_outputs, result)
recursive_stack(result)
for k, value in result.items():
if k == 'meta':
if k in ['meta', 'extra']:
continue
# pick the reduce fx

View File

@ -1,8 +1,9 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule
from typing import Union
def is_overridden(method_name: str, model: LightningModule) -> bool:
def is_overridden(method_name: str, model: Union[LightningModule, LightningDataModule]) -> bool:
# if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
# TODO - refactor this function to accept model_name, instance, parent so it makes more sense
super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule

View File

@ -188,6 +188,12 @@ class DeterministicModel(LightningModule):
# only saw 4 batches
assert isinstance(result, TrainResult)
result.step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()
result.step_epoch_pbar_acc3 = result.step_step_epoch_pbar_acc3.prod()
result.step_epoch_log_and_pbar_acc1 = result.step_epoch_log_and_pbar_acc1.prod()
result.minimize = result.minimize.mean()
result.checkpoint_on = result.checkpoint_on.mean()
result.step_step_epoch_log_and_pbar_acc1 = result.step_step_epoch_log_and_pbar_acc1.prod()
result.epoch_step_epoch_log_and_pbar_acc1 = result.epoch_step_epoch_log_and_pbar_acc1.prod()
result.step_step_epoch_log_acc2 = result.step_step_epoch_log_acc2.prod()

View File

@ -137,6 +137,7 @@ class TrainingStepVariations(ABC):
result.log(f'{eval_name}_step_metric', loss_val + 1, on_step=True)
setattr(self, f'{eval_name}_step_called', True)
return result
def eval_step_end_full_loop_result_obj_dp(self, result):
@ -150,10 +151,14 @@ class TrainingStepVariations(ABC):
reduced = getattr(result, f'epoch_{eval_name}_step_metric').mean()
setattr(result, f'epoch_{eval_name}_step_metric', reduced)
reduced = getattr(result, f'{eval_name}_step_metric').mean()
setattr(result, f'{eval_name}_step_metric', reduced)
result.checkpoint_on = result.checkpoint_on.mean()
result.early_stop_on = result.early_stop_on.mean()
result.log(f'{eval_name}_step_end_metric', torch.tensor(1).type_as(result.checkpoint_on))
setattr(self, f'{eval_name}_step_end_called', True)
return result
def eval_epoch_end_full_loop_result_obj_dp(self, result):
@ -176,6 +181,9 @@ class TrainingStepVariations(ABC):
reduced = getattr(result, f'{eval_name}_step_end_metric').mean()
setattr(result, f'{eval_name}_step_end_metric', reduced)
reduced = getattr(result, f'{eval_name}_step_metric').mean()
setattr(result, f'{eval_name}_step_metric', reduced)
return result
def training_step__using_metrics(self, batch, batch_idx, optimizer_idx=None):

View File

@ -201,7 +201,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert not model.training_step_end_called
assert not model.training_epoch_end_called
assert len(trainer.logger_connector.callback_metrics) == 8
assert len(trainer.logger_connector.callback_metrics) == 11
# make sure correct metrics are logged (one per batch step as requested)
assert len(trainer.dev_debugger.logged_metrics) == (epochs * batches) + epochs
@ -227,7 +227,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_step_epoch_log_acc2'] == expected_val_2
assert 'step_epoch_pbar_acc3' not in logged_metrics
assert len(logged_metrics) == 4
assert len(logged_metrics) == 6
# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
@ -236,7 +236,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['epoch_step_epoch_log_acc2'] == eval_2
assert 'step_epoch_pbar_acc3' not in epoch_end_metrics
assert len(logged_metrics) == 4
assert len(logged_metrics) == 6
# make sure we are using the correct metrics for callbacks
assert trainer.logger_connector.callback_metrics['checkpoint_on'] == 171
@ -268,7 +268,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert logged_metrics['step_step_epoch_log_and_pbar_acc1'] == expected_val_1
assert logged_metrics['step_step_epoch_pbar_acc3'] == expected_val_2
assert 'step_epoch_log_acc2' not in logged_metrics
assert len(logged_metrics) == 3
assert len(logged_metrics) == 5
# make sure the metrics for the epoch end are actual means (the default reduce fx) of all the batches
epoch_end_metrics = epoch_outputs[-1]
@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
assert epoch_end_metrics['epoch_step_epoch_log_and_pbar_acc1'] == eval_1
assert epoch_end_metrics['epoch_step_epoch_pbar_acc3'] == eval_2
assert 'step_epoch_log_acc2' not in epoch_end_metrics
assert len(logged_metrics) == 3
assert len(logged_metrics) == 5
# -----------------------------------------
# make sure training outputs what is expected
@ -287,7 +287,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
assert len(out.batch_log_metrics) == 4
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1
@ -328,7 +328,7 @@ def test_training_step_epoch_end_result(tmpdir):
)
trainer.fit(model)
assert len(trainer.logger_connector.callback_metrics) == 11
assert len(trainer.logger_connector.callback_metrics) == 17
# make sure correct steps were called
assert model.training_step_called
@ -369,7 +369,7 @@ def test_training_step_epoch_end_result(tmpdir):
out = trainer.train_loop.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2
assert len(out.batch_log_metrics) == 4
train_step_out = out.training_step_output_for_epoch_end
assert len(train_step_out) == 1

View File

@ -278,7 +278,7 @@ def test_val_step_epoch_step_metrics(tmpdir):
)
trainer.fit(model)
assert len(trainer.logger_connector.callback_metrics) == 7
assert len(trainer.logger_connector.callback_metrics) == 11
# make sure correct steps were called
assert model.validation_step_called