diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a62a051da9..22a06624a4 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -32,6 +32,7 @@ from pytorch_lightning.core.step_result import EvalResult, TrainResult from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.core.step_result import Result from pytorch_lightning.utilities.parsing import ( AttributeDict, @@ -216,7 +217,15 @@ class LightningModule( if self._results is not None: # in any epoch end can't log step metrics (only epoch metric) if 'epoch_end' in self._current_fx_name and on_step: - on_step = False + m = f'on_step=True cannot be used on {self._current_fx_name} method' + raise MisconfigurationException(m) + + if 'epoch_end' in self._current_fx_name and on_epoch == False: + m = f'on_epoch cannot be False when called from the {self._current_fx_name} method' + raise MisconfigurationException(m) + + # add log_dict + # TODO: if logged twice fail with crash # set the default depending on the fx_name on_step = self.__auto_choose_log_on_step(on_step) @@ -238,6 +247,60 @@ class LightningModule( sync_dist_group ) + def log_dict( + self, + dictionary: dict, + prog_bar: bool = False, + logger: bool = True, + on_step: Union[None, bool] = None, + on_epoch: Union[None, bool] = None, + reduce_fx: Callable = torch.mean, + tbptt_reduce_fx: Callable = torch.mean, + tbptt_pad_token: int = 0, + enable_graph: bool = False, + sync_dist: bool = False, + sync_dist_op: Union[Any, str] = 'mean', + sync_dist_group: Optional[Any] = None, + ): + """ + Log a dictonary of values at once + + Example:: + + values = {'loss': loss, 'acc': acc, ..., 'metric_n': metric_n} + self.log_dict(values) + + Args: + dictionary: key value pairs (str, tensors) + prog_bar: if True logs to the progress base + logger: if True logs to the logger + on_step: if True logs at this step. None auto-logs for training_step but not validation/test_step + on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step + reduce_fx: Torch.mean by default + tbptt_reduce_fx: function to reduce on truncated back prop + tbptt_pad_token: token to use for padding + enable_graph: if True, will not auto detach the graph + sync_dist: if True, reduces the metric across GPUs/TPUs + sync_dist_op: the op to sync across + sync_dist_group: the ddp group: + """ + for k, v in dictionary.items(): + self.log( + name=k, + value=v, + prog_bar=prog_bar, + logger=logger, + on_step=on_step, + on_epoch=on_epoch, + reduce_fx=reduce_fx, + enable_graph=enable_graph, + sync_dist=sync_dist, + sync_dist_group=sync_dist_group, + sync_dist_op=sync_dist_op, + tbptt_pad_token=tbptt_pad_token, + tbptt_reduce_fx=tbptt_reduce_fx, + ) + def __auto_choose_log_on_step(self, on_step): if on_step is None: if self._current_fx_name in {'training_step', 'training_step_end'}: diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 7ef23aaee4..28f0e91f87 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -131,7 +131,10 @@ class Result(Dict): # if user requests both step and epoch, then we split the metric in two automatically # one will be logged per step. the other per epoch + was_forked = False if on_step and on_epoch: + was_forked = True + # set step version step_name = f'step_{name}' self.__set_meta( @@ -144,6 +147,7 @@ class Result(Dict): reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, + forked=False ) self.__setitem__(step_name, value) @@ -159,6 +163,7 @@ class Result(Dict): reduce_fx=reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, + forked=False ) self.__setitem__(epoch_name, value) @@ -173,6 +178,7 @@ class Result(Dict): reduce_fx, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, + forked=was_forked ) # set the value @@ -189,6 +195,7 @@ class Result(Dict): reduce_fx: Callable, tbptt_pad_token: int, tbptt_reduce_fx: Callable, + forked: bool ): # set the meta for the item meta_value = value @@ -201,6 +208,7 @@ class Result(Dict): value=meta_value, tbptt_reduce_fx=tbptt_reduce_fx, tbptt_pad_token=tbptt_pad_token, + forked=forked ) self['meta'][name] = meta @@ -222,9 +230,10 @@ class Result(Dict): return result - def get_batch_log_metrics(self) -> dict: + def get_batch_log_metrics(self, include_forked_originals=True) -> dict: """ Gets the metrics to log at the end of the batch step + """ result = {} @@ -232,6 +241,10 @@ class Result(Dict): for k, options in meta.items(): if k == '_internal': continue + + if options['forked'] and not include_forked_originals: + continue + if options['logger'] and options['on_step']: result[k] = self[k] return result @@ -264,7 +277,7 @@ class Result(Dict): result[k] = self[k] return result - def get_batch_pbar_metrics(self): + def get_batch_pbar_metrics(self, include_forked_originals=True): """ Gets the metrics to log at the end of the batch step """ @@ -274,6 +287,9 @@ class Result(Dict): for k, options in meta.items(): if k == '_internal': continue + if options['forked'] and not include_forked_originals: + continue + if options['prog_bar'] and options['on_step']: result[k] = self[k] return result diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index de4e49eba7..b1a2b21b03 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -202,6 +202,7 @@ class EvaluationLoop(object): if self.testing: if is_overridden('test_epoch_end', model=model): + model._current_fx_name = 'test_epoch_end' if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) @@ -210,6 +211,7 @@ class EvaluationLoop(object): else: if is_overridden('validation_epoch_end', model=model): + model._current_fx_name = 'validation_epoch_end' if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) @@ -314,8 +316,8 @@ class EvaluationLoop(object): self.__log_result_step_metrics(output, batch_idx) def __log_result_step_metrics(self, output, batch_idx): - step_log_metrics = output.batch_log_metrics - step_pbar_metrics = output.batch_pbar_metrics + step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False) + step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False) if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph diff --git a/tests/trainer/test_eval_loop_logging_1_0.py b/tests/trainer/test_eval_loop_logging_1_0.py index 0bdf6afa54..aa7ff6580e 100644 --- a/tests/trainer/test_eval_loop_logging_1_0.py +++ b/tests/trainer/test_eval_loop_logging_1_0.py @@ -52,8 +52,6 @@ def test__validation_step__log(tmpdir): 'b', 'step_b/epoch_0', 'step_b/epoch_1', - 'b/epoch_0', - 'b/epoch_1', 'epoch_b', 'epoch', } @@ -67,7 +65,7 @@ def test__validation_step__log(tmpdir): assert expected_cb_metrics == callback_metrics -def test__validation_step__epoch_end__log(tmpdir): +def test__validation_step__step_end__epoch_end__log(tmpdir): """ Tests that validation_step can log """ @@ -88,16 +86,22 @@ def test__validation_step__epoch_end__log(tmpdir): self.log('c', acc) self.log('d', acc, on_step=True, on_epoch=True) self.validation_step_called = True + return acc + + def validation_step_end(self, acc): + self.validation_step_end_called = True + self.log('e', acc) + self.log('f', acc, on_step=True, on_epoch=True) + return ['random_thing'] def validation_epoch_end(self, outputs): - self.log('e', torch.tensor(2, device=self.device), on_step=True, on_epoch=True) + self.log('g', torch.tensor(2, device=self.device), on_epoch=True) self.validation_epoch_end_called = True def backward(self, trainer, loss, optimizer, optimizer_idx): loss.backward() model = TestModel() - model.validation_step_end = None trainer = Trainer( default_root_dir=tmpdir, @@ -110,38 +114,32 @@ def test__validation_step__epoch_end__log(tmpdir): trainer.fit(model) # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) expected_logged_metrics = { + 'epoch', 'a', 'b', 'step_b', 'epoch_b', 'c', 'd', - 'd/epoch_0', - 'd/epoch_1', 'step_d/epoch_0', 'step_d/epoch_1', 'epoch_d', 'e', - 'epoch_e', - 'epoch', + 'f', + 'step_f/epoch_0', + 'step_f/epoch_1', + 'epoch_f', + 'g', } - - logged_metrics = set(trainer.logged_metrics.keys()) assert expected_logged_metrics == logged_metrics - # we don't want to enable val metrics during steps because it is not something that users should do - expected_cb_metrics = { - 'a', - 'b', - 'step_b', - 'epoch_b', - 'c', - 'd', - 'epoch_d', - 'e', - 'epoch_e', - } + progress_bar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = set() + assert expected_pbar_metrics == progress_bar_metrics + # we don't want to enable val metrics during steps because it is not something that users should do callback_metrics = set(trainer.callback_metrics.keys()) + expected_cb_metrics = {'a', 'b', 'c', 'd', 'e', 'epoch_b', 'epoch_d', 'epoch_f', 'f', 'g', 'step_b'} assert expected_cb_metrics == callback_metrics diff --git a/tests/trainer/test_train_loop_logging_1_0.py b/tests/trainer/test_train_loop_logging_1_0.py index 4936bf525f..ff969f324f 100644 --- a/tests/trainer/test_train_loop_logging_1_0.py +++ b/tests/trainer/test_train_loop_logging_1_0.py @@ -17,12 +17,35 @@ def test__training_step__log(tmpdir): def training_step(self, batch, batch_idx): acc = self.step(batch, batch_idx) acc = acc + batch_idx - self.log('step_acc', acc, on_step=True, on_epoch=False) - self.log('epoch_acc', acc, on_step=False, on_epoch=True) - self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True) - self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False) - self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False) - self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False) + + # ----------- + # default + # ----------- + self.log('default', acc) + + # ----------- + # logger + # ----------- + # on_step T on_epoch F + self.log('l_s', acc, on_step=True, on_epoch=False, prog_bar=False, logger=True) + + # on_step F on_epoch T + self.log('l_e', acc, on_step=False, on_epoch=True, prog_bar=False, logger=True) + + # on_step T on_epoch T + self.log('l_se', acc, on_step=True, on_epoch=True, prog_bar=False, logger=True) + + # ----------- + # pbar + # ----------- + # on_step T on_epoch F + self.log('p_s', acc, on_step=True, on_epoch=False, prog_bar=True, logger=False) + + # on_step F on_epoch T + self.log('p_e', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False) + + # on_step T on_epoch T + self.log('p_se', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False) self.training_step_called = True return acc @@ -46,19 +69,38 @@ def test__training_step__log(tmpdir): # make sure correct steps were called assert model.training_step_called assert not model.training_step_end_called + assert not model.training_epoch_end_called # make sure all the metrics are available for callbacks - metrics = [ - 'step_acc', - 'epoch_acc', - 'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc', - 'pbar_step_acc', - 'pbar_epoch_acc', - 'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc', - ] - expected_metrics = set(metrics + ['debug_epoch']) + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'epoch', + 'default', + 'l_e', + 'l_s', + 'l_se', + 'step_l_se', + 'epoch_l_se', + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = { + 'p_e', + 'p_s', + 'p_se', + 'step_p_se', + 'epoch_p_se', + } + assert pbar_metrics == expected_pbar_metrics + callback_metrics = set(trainer.callback_metrics.keys()) - assert expected_metrics == callback_metrics + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics def test__training_step__epoch_end__log(tmpdir): @@ -69,22 +111,17 @@ def test__training_step__epoch_end__log(tmpdir): class TestModel(DeterministicModel): def training_step(self, batch, batch_idx): + self.training_step_called = True acc = self.step(batch, batch_idx) acc = acc + batch_idx - self.log('step_acc', acc, on_step=True, on_epoch=False) - self.log('epoch_acc', acc, on_step=False, on_epoch=True) - self.log('no_prefix_step_epoch_acc', acc, on_step=True, on_epoch=True) - self.log('pbar_step_acc', acc, on_step=True, prog_bar=True, on_epoch=False, logger=False) - self.log('pbar_epoch_acc', acc, on_step=False, on_epoch=True, prog_bar=True, logger=False) - self.log('pbar_step_epoch_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=False) - - self.training_step_called = True + self.log('a', acc, on_step=True, on_epoch=True) + self.log_dict({'a1': acc, 'a2': acc}) return acc def training_epoch_end(self, outputs): self.training_epoch_end_called = True - # logging is independent of epoch_end loops - self.log('custom_epoch_end_metric', torch.tensor(37.2)) + self.log('b1', outputs[0]['loss']) + self.log('b', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) def backward(self, trainer, loss, optimizer, optimizer_idx): loss.backward() @@ -108,27 +145,100 @@ def test__training_step__epoch_end__log(tmpdir): assert model.training_epoch_end_called # make sure all the metrics are available for callbacks - metrics = [ - 'step_acc', - 'epoch_acc', - 'no_prefix_step_epoch_acc', 'step_no_prefix_step_epoch_acc', 'epoch_no_prefix_step_epoch_acc', - 'pbar_step_acc', - 'pbar_epoch_acc', - 'pbar_step_epoch_acc', 'step_pbar_step_epoch_acc', 'epoch_pbar_step_epoch_acc', - 'custom_epoch_end_metric' - ] - expected_metrics = set(metrics + ['debug_epoch']) + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'epoch', + 'a', + 'step_a', + 'epoch_a', + 'b', + 'b1', + 'a1', + 'a2' + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = { + 'b', + } + assert pbar_metrics == expected_pbar_metrics + callback_metrics = set(trainer.callback_metrics.keys()) - assert expected_metrics == callback_metrics + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics - # verify global steps were correctly called - # epoch 0 - assert trainer.dev_debugger.logged_metrics[0]['global_step'] == 0 - assert trainer.dev_debugger.logged_metrics[1]['global_step'] == 1 - assert trainer.dev_debugger.logged_metrics[2]['global_step'] == 1 +def test__training_step__step_end__epoch_end__log(tmpdir): + """ + Tests that only training_step can be used + """ + os.environ['PL_DEV_DEBUG'] = '1' - # epoch 1 - assert trainer.dev_debugger.logged_metrics[3]['global_step'] == 2 - assert trainer.dev_debugger.logged_metrics[4]['global_step'] == 3 - assert trainer.dev_debugger.logged_metrics[5]['global_step'] == 3 + class TestModel(DeterministicModel): + def training_step(self, batch, batch_idx): + self.training_step_called = True + acc = self.step(batch, batch_idx) + acc = acc + batch_idx + self.log('a', acc, on_step=True, on_epoch=True) + return acc + + def training_step_end(self, out): + self.training_step_end_called = True + self.log('b', out, on_step=True, on_epoch=True, prog_bar=True, logger=True) + return out + + def training_epoch_end(self, outputs): + self.training_epoch_end_called = True + self.log('c', outputs[0]['loss'], on_epoch=True, prog_bar=True, logger=True) + + def backward(self, trainer, loss, optimizer, optimizer_idx): + loss.backward() + + model = TestModel() + model.val_dataloader = None + + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=2, + limit_val_batches=2, + max_epochs=2, + row_log_interval=1, + weights_summary=None, + ) + trainer.fit(model) + + # make sure correct steps were called + assert model.training_step_called + assert model.training_step_end_called + assert model.training_epoch_end_called + + # make sure all the metrics are available for callbacks + logged_metrics = set(trainer.logged_metrics.keys()) + expected_logged_metrics = { + 'a', + 'step_a', + 'epoch_a', + 'b', + 'step_b', + 'epoch_b', + 'c', + 'epoch', + } + assert logged_metrics == expected_logged_metrics + + pbar_metrics = set(trainer.progress_bar_metrics.keys()) + expected_pbar_metrics = {'b', 'c', 'epoch_b', 'step_b'} + assert pbar_metrics == expected_pbar_metrics + + callback_metrics = set(trainer.callback_metrics.keys()) + callback_metrics.remove('debug_epoch') + expected_callback_metrics = set() + expected_callback_metrics = expected_callback_metrics.union(logged_metrics) + expected_callback_metrics = expected_callback_metrics.union(pbar_metrics) + expected_callback_metrics.remove('epoch') + assert callback_metrics == expected_callback_metrics