# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC

import torch


class ValidationEpochEndVariations(ABC):
    """Alternative ``validation_epoch_end`` implementations.

    The methods read ``self.current_epoch`` and call ``self.log``; neither is
    defined here, so they must be supplied by the class this one is combined with.
    """

    def validation_epoch_end(self, outputs):
        """Average ``val_loss`` and ``val_acc`` across validation steps and log them.

        Args:
            outputs: per-step validation outputs. Each element is either a dict
                holding ``"val_loss"`` and ``"val_acc"`` tensors, or (recursively)
                a list of such elements.

        Returns:
            None. Results are reported only through ``self.log``.
        """

        def _nested_mean(items, name):
            # A dict contributes its value directly; any other element is taken
            # to be a nested list and is first reduced to its own mean.
            per_item = []
            for item in items:
                if isinstance(item, dict):
                    per_item.append(item[name])
                else:
                    per_item.append(_nested_mean(item, name))
            return torch.stack(per_item).mean()

        loss_avg = _nested_mean(outputs, "val_loss")
        acc_avg = _nested_mean(outputs, "val_acc")

        # Alternate the logged value type: Python scalars on even epochs,
        # tensors on odd epochs.
        if self.current_epoch % 2 == 0:
            loss_avg = loss_avg.item()
            acc_avg = acc_avg.item()

        self.log("early_stop_on", loss_avg, prog_bar=True)
        self.log("val_acc", acc_avg, prog_bar=True)

    def validation_epoch_end__multiple_dataloaders(self, outputs):
        """Aggregate validation outputs produced by several dataloaders.

        Args:
            outputs: one list per dataloader, each holding that dataloader's
                per-step output dicts.

        Returns:
            dict with:
              * ``"val_loss"``: mean of every aggregated metric whose name starts
                with ``"val_loss"``;
              * ``"progress_bar"`` and ``"log"``: two separate dicts mapping every
                key that contains ``"val_"`` to its mean over that dataloader's steps.

            If the same key appears for more than one dataloader, the later
            dataloader's mean overwrites the earlier one.
        """
        progress_metrics = {}
        logged_metrics = {}

        for step_outputs in outputs:
            # The first step's dict decides which keys are aggregated for this
            # dataloader ("val_" may appear anywhere in the key, not only as prefix).
            tracked = filter(lambda name: "val_" in name, step_outputs[0].keys())
            for name in tracked:
                averaged = torch.stack([step[name] for step in step_outputs]).mean()
                progress_metrics[name] = averaged
                logged_metrics[name] = averaged

        loss_terms = [
            value
            for name, value in progress_metrics.items()
            if name.startswith("val_loss")
        ]
        return {
            "val_loss": torch.stack(loss_terms).mean(),
            "progress_bar": progress_metrics,
            "log": logged_metrics,
        }