From 7e874d70d6dd57caa1d6be12995a647062951cfc Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Mon, 7 Sep 2020 11:55:14 -0400
Subject: [PATCH] ref: inner train loop (intermediate step) 19/n (#3385)

* ref: inner train loop (intermediate step) 19/n

* Update debugging.py

* ref: inner train loop (intermediate step) 19/n
---
 .../callbacks/model_checkpoint.py             |  2 +-
 pytorch_lightning/trainer/evaluate_loop.py    |  2 +-
 pytorch_lightning/trainer/evaluation_loop.py  |  6 +--
 pytorch_lightning/trainer/logger_connector.py | 40 +++++++++++++++++++
 pytorch_lightning/trainer/logging.py          | 38 ------------------
 pytorch_lightning/trainer/trainer.py          |  1 -
 pytorch_lightning/trainer/training_loop.py    |  8 +---
 7 files changed, 45 insertions(+), 52 deletions(-)

diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 74167691c1..c6b96fe307 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -345,7 +345,7 @@ class ModelCheckpoint(Callback):
 
         self.epoch_last_check = epoch
 
-        ckpt_name_metrics = trainer.logged_metrics
+        ckpt_name_metrics = trainer.logger_connector.logged_metrics
         filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while self._fs.exists(filepath):
diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py
index e0f5ada384..565cb2785c 100644
--- a/pytorch_lightning/trainer/evaluate_loop.py
+++ b/pytorch_lightning/trainer/evaluate_loop.py
@@ -276,7 +276,7 @@ class EvaluationLoop(object):
             for k, v in step_log_metrics.items():
                 metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
 
-            self.trainer.log_metrics(metrics_by_epoch, {}, step=batch_idx)
+            self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
 
         if len(step_pbar_metrics) > 0:
             self.trainer.add_progress_bar_metrics(step_pbar_metrics)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 57641b3a74..7ea12a74e6 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -211,10 +211,6 @@ class TrainerEvaluationLoopMixin(ABC):
     def add_progress_bar_metrics(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def log_metrics(self, *args, **kwargs):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def reset_test_dataloader(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -337,7 +333,7 @@
         self.add_progress_bar_metrics(prog_bar_metrics)
 
         # log metrics
-        self.log_metrics(log_metrics, {})
+        self.logger_connector.log_metrics(log_metrics, {})
 
         # track metrics for callbacks
         self.logger_connector.callback_metrics.update(callback_metrics)
diff --git a/pytorch_lightning/trainer/logger_connector.py b/pytorch_lightning/trainer/logger_connector.py
index cd6f32dcdc..0d8dd44275 100644
--- a/pytorch_lightning/trainer/logger_connector.py
+++ b/pytorch_lightning/trainer/logger_connector.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pytorch_lightning.core import memory
 
 
 class LoggerConnector:
@@ -18,3 +19,42 @@ class LoggerConnector:
     def __init__(self, trainer):
         self.trainer = trainer
         self.callback_metrics = {}
+        self.logged_metrics = {}
+
+    def log_metrics(self, metrics, grad_norm_dic, step=None):
+        """Logs the metric dict passed in.
+        If `step` parameter is None and `step` key is present in metrics,
+        uses metrics["step"] as a step
+
+        Args:
+            metrics (dict): Metric values
+            grad_norm_dic (dict): Gradient norms
+            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
+        """
+        # add gpu memory
+        if self.trainer.on_gpu and self.trainer.log_gpu_memory:
+            mem_map = memory.get_memory_profile(self.trainer.log_gpu_memory)
+            metrics.update(mem_map)
+
+        # add norms
+        metrics.update(grad_norm_dic)
+
+        # turn all tensors to scalars
+        scalar_metrics = self.trainer.metrics_to_scalars(metrics)
+
+        if "step" in scalar_metrics and step is None:
+            step = scalar_metrics.pop("step")
+
+        elif step is None:
+            # added metrics by Lightning for convenience
+            scalar_metrics['epoch'] = self.trainer.current_epoch
+            step = step if step is not None else self.trainer.global_step
+
+        # log actual metrics
+        if self.trainer.is_global_zero and self.trainer.logger is not None:
+            self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
+            self.trainer.logger.save()
+
+        # track the logged metrics
+        self.logged_metrics = scalar_metrics
+        self.trainer.dev_debugger.track_logged_metrics_history(scalar_metrics)
diff --git a/pytorch_lightning/trainer/logging.py b/pytorch_lightning/trainer/logging.py
index 41346ac951..9811e5952a 100644
--- a/pytorch_lightning/trainer/logging.py
+++ b/pytorch_lightning/trainer/logging.py
@@ -56,44 +56,6 @@ class TrainerLoggingMixin(ABC):
         else:
             self.logger = logger
 
-    def log_metrics(self, metrics, grad_norm_dic, step=None):
-        """Logs the metric dict passed in.
-        If `step` parameter is None and `step` key is presented is metrics,
-        uses metrics["step"] as a step
-
-        Args:
-            metrics (dict): Metric values
-            grad_norm_dic (dict): Gradient norms
-            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
-        """
-        # add gpu memory
-        if self.on_gpu and self.log_gpu_memory:
-            mem_map = memory.get_memory_profile(self.log_gpu_memory)
-            metrics.update(mem_map)
-
-        # add norms
-        metrics.update(grad_norm_dic)
-
-        # turn all tensors to scalars
-        scalar_metrics = self.metrics_to_scalars(metrics)
-
-        if "step" in scalar_metrics and step is None:
-            step = scalar_metrics.pop("step")
-
-        elif step is None:
-            # added metrics by Lightning for convenience
-            scalar_metrics['epoch'] = self.current_epoch
-            step = step if step is not None else self.global_step
-
-        # log actual metrics
-        if self.is_global_zero and self.logger is not None:
-            self.logger.agg_and_log_metrics(scalar_metrics, step=step)
-            self.logger.save()
-
-        # track the logged metrics
-        self.logged_metrics = scalar_metrics
-        self.dev_debugger.track_logged_metrics_history(scalar_metrics)
-
     def add_progress_bar_metrics(self, metrics):
         for k, v in metrics.items():
             if isinstance(v, torch.Tensor):
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index e61161918e..a78ef5fd2b 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -380,7 +380,6 @@ class Trainer(
         self.running_loss = TensorRunningAccum(window_length=20)
         self.batch_idx = 0
         self.progress_bar_metrics = {}
-        self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
         self.num_sanity_val_batches = []
diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py
index 8c5d29e2b2..ce2a6d30f3 100644
--- a/pytorch_lightning/trainer/training_loop.py
+++ b/pytorch_lightning/trainer/training_loop.py
@@ -76,10 +76,6 @@ class TrainerTrainLoopMixin(ABC):
     def add_progress_bar_metrics(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
 
-    @abstractmethod
-    def log_metrics(self, *args):
-        """Warning: this is just empty shell for code implemented in other class."""
-
     @abstractmethod
     def process_output(self, *args):
         """Warning: this is just empty shell for code implemented in other class."""
@@ -186,7 +182,7 @@ class TrainerTrainLoopMixin(ABC):
         # --------------------------
         # add the metrics to the loggers
        if epoch_log_metrics and len(epoch_log_metrics) > 0:
-            self.log_metrics(epoch_log_metrics, {})
+            self.logger_connector.log_metrics(epoch_log_metrics, {})
 
         # add metrics to callbacks
         self.logger_connector.callback_metrics.update(epoch_callback_metrics)
@@ -246,7 +242,7 @@ class TrainerTrainLoopMixin(ABC):
         metrics = batch_output.batch_log_metrics
         grad_norm_dic = batch_output.grad_norm_dic
         if len(metrics) > 0 or len(grad_norm_dic) > 0:
-            self.log_metrics(metrics, grad_norm_dic)
+            self.logger_connector.log_metrics(metrics, grad_norm_dic)
 
     def save_loggers_in_training_loop(self, batch_idx):
         # when loggers should save to disk
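
For context, the refactor above moves metric logging off the Trainer mixins and onto a LoggerConnector object that the trainer owns: call sites change from self.log_metrics(...) to self.logger_connector.log_metrics(...), and logged_metrics is now tracked on the connector instead of on the Trainer. The sketch below is a minimal, self-contained illustration of that connector pattern; MiniTrainer and MiniLoggerConnector are hypothetical names, not Lightning classes, and the step-resolution/logging logic is deliberately simplified relative to the method shown in the diff.

# Minimal sketch of the connector pattern introduced by this patch.
# MiniTrainer / MiniLoggerConnector are hypothetical, illustrative stand-ins,
# not Lightning classes; logging is replaced by a print call.
import torch


class MiniLoggerConnector:
    def __init__(self, trainer):
        self.trainer = trainer
        self.logged_metrics = {}  # state that previously lived on the trainer

    def log_metrics(self, metrics, grad_norm_dic, step=None):
        # merge gradient norms and convert tensors to plain Python scalars
        merged = {**metrics, **grad_norm_dic}
        scalars = {k: (v.item() if isinstance(v, torch.Tensor) else v)
                   for k, v in merged.items()}

        # resolve the step: an explicit argument wins, then a "step" key in the
        # metrics, otherwise fall back to the trainer's global step
        if step is None:
            step = scalars.pop("step", self.trainer.global_step)

        print(f"step={step}: {scalars}")  # stand-in for logger.agg_and_log_metrics
        self.logged_metrics = scalars     # tracked on the connector, not the trainer


class MiniTrainer:
    def __init__(self):
        self.global_step = 0
        self.logger_connector = MiniLoggerConnector(self)  # trainer owns the connector


if __name__ == "__main__":
    trainer = MiniTrainer()
    trainer.global_step = 7
    # call sites now go through the connector instead of a trainer mixin method
    trainer.logger_connector.log_metrics({"loss": torch.tensor(0.25)}, {"grad_norm_total": 1.3})
    print(trainer.logger_connector.logged_metrics)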