""" Validation loop =============== The lightning validation loop handles everything except the actual computations of your model. To decide what will happen in your validation loop, define the `validation_step` function. Below are all the things lightning automates for you in the validation loop. .. note:: Lightning will run 5 steps of validation in the beginning of training as a sanity check so you don't have to wait until a full epoch to catch possible validation issues. Check validation every n epochs ------------------------------- If you have a small dataset you might want to check validation every n epochs .. code-block:: python # DEFAULT trainer = Trainer(check_val_every_n_epoch=1) Set how much of the validation set to check ------------------------------------------- If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag. limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0` .. code-block:: python # DEFAULT trainer = Trainer(limit_val_batches=1.0) # check 10% only trainer = Trainer(limit_val_batches=0.1) Set how much of the test set to check ------------------------------------- If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag. limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0` .. code-block:: python # DEFAULT trainer = Trainer(limit_test_batches=1.0) # check 10% only trainer = Trainer(limit_test_batches=0.1) Set validation check frequency within 1 training epoch ------------------------------------------------------ For large datasets it's often desirable to check validation multiple times within a training loop. Pass in a float to check that often within 1 training epoch. Pass in an int k to check every k training batches. Must use an int if using an IterableDataset. .. code-block:: python # DEFAULT trainer = Trainer(val_check_interval=0.95) # check every .25 of an epoch trainer = Trainer(val_check_interval=0.25) # check every 100 train batches (ie: for IterableDatasets or fixed frequency) trainer = Trainer(val_check_interval=100) Set the number of validation sanity steps ----------------------------------------- Lightning runs a few steps of validation in the beginning of training. This avoids crashing in the validation loop sometime deep into a lengthy training loop. .. code-block:: python # DEFAULT trainer = Trainer(num_sanity_val_steps=2) You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)` to check all the validation data. # Testing loop To ensure you don't accidentally use test data to guide training decisions Lightning makes running the test set deliberate. **test** You have two options to run the test set. First case is where you test right after a full training routine. .. code-block:: python # run full training trainer.fit(model) # run test set trainer.test() Second case is where you load a model and run the test set .. code-block:: python model = MyLightningModule.load_from_checkpoint( checkpoint_path='/path/to/pytorch_checkpoint.ckpt', hparams_file='/path/to/test_tube/experiment/version/hparams.yaml', map_location=None ) # init trainer with whatever options trainer = Trainer(...) # test (pass in the model) trainer.test(model) In this second case, the options you pass to trainer will be used when running the test set (ie: 16-bit, dp, ddp, etc...) """ from abc import ABC, abstractmethod from pprint import pprint from typing import Callable, List, Union import torch from torch.utils.data import DataLoader from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType from pytorch_lightning.core.step_result import Result, EvalResult from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.trainer.supporters import PredictionCollection try: import torch_xla.distributed.parallel_loader as xla_pl import torch_xla.core.xla_model as xm except ImportError: XLA_AVAILABLE = False else: XLA_AVAILABLE = True try: import horovod.torch as hvd except (ModuleNotFoundError, ImportError): HOROVOD_AVAILABLE = False else: HOROVOD_AVAILABLE = True class TrainerEvaluationLoopMixin(ABC): # this is just a summary on variables used in this abstract class, # the proper values/initialisation should be done in child class on_gpu: bool use_ddp: bool use_dp: bool use_ddp2: bool use_horovod: bool use_single_gpu: bool data_parallel_device_ids: ... model: LightningModule num_test_batches: List[int] num_val_batches: int world_size: int fast_dev_run: ... process_output: ... progress_bar_dict: ... global_rank: int current_epoch: int callback_metrics: ... test_dataloaders: DataLoader val_dataloaders: DataLoader use_tpu: bool reload_dataloaders_every_epoch: ... tpu_id: int verbose_test: bool running_sanity_check: bool amp_backend: AMPType # Callback system on_validation_batch_start: Callable on_validation_batch_end: Callable on_test_batch_start: Callable on_test_batch_end: Callable on_validation_start: Callable on_validation_end: Callable on_test_start: Callable on_test_end: Callable @abstractmethod def copy_trainer_model_properties(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def get_model(self) -> LightningModule: """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def is_overridden(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def transfer_batch_to_tpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def transfer_batch_to_gpu(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def add_progress_bar_metrics(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def log_metrics(self, *args, **kwargs): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def reset_test_dataloader(self, *args): """Warning: this is just empty shell for code implemented in other class.""" @abstractmethod def reset_val_dataloader(self, *args): """Warning: this is just empty shell for code implemented in other class.""" def __call_eval_loop_hook_start(self, test_mode): """on_validation/test_epoch_start""" self.__call_eval_loop_hook_evt(test_mode, 'start') def __call_eval_loop_hook_end(self, test_mode): """on_validation/test_epoch_end""" self.__call_eval_loop_hook_evt(test_mode, 'end') def __call_eval_loop_hook_evt(self, test_mode, epoch_event): model = self.get_model() # on_[train/validation]_epoch_start hook hook_root_name = 'test' if test_mode else 'validation' hook_name = f'on_{hook_root_name}_epoch_{epoch_event}' with self.profiler.profile(hook_name): # call hook getattr(self, hook_name)() # model hooks if self.is_function_implemented(hook_name): getattr(model, hook_name)() def _evaluate( self, model: LightningModule, dataloaders: List[DataLoader], max_batches: Union[int, List[int]], test_mode: bool = False ): """Run evaluation code. Args: model: The model to evaluate. dataloaders: A list of PyTorch dataloaders. max_batches: An integer or list of integers with length of the number of dataloaders. Each entry is the number of batches to process in the corresponding dataloader. test_mode: """ # enable eval mode model.zero_grad() model.eval() # copy properties for forward overrides self.copy_trainer_model_properties(model) # disable gradients to save memory torch.set_grad_enabled(False) # bookkeeping outputs = [] predictions = PredictionCollection(self.global_rank, self.world_size) # convert max_batches to list if isinstance(max_batches, int): max_batches = [max_batches] * len(dataloaders) # -------------------------- # ON_EVAL_EPOCH_START hook # -------------------------- self.__call_eval_loop_hook_start(test_mode) # run validation for dataloader_idx, dataloader in enumerate(dataloaders): dl_outputs = [] # on TPU we have to wrap it under the ParallelLoader if self.use_tpu: device = xm.xla_device(self.tpu_id) dataloader = xla_pl.ParallelLoader(dataloader, [device]) dataloader = dataloader.per_device_loader(device) # each dataloader has a max num batches dl_max_batches = max_batches[dataloader_idx] for batch_idx, batch in enumerate(dataloader): if batch is None: continue # stop short when running on limited batches if batch_idx >= dl_max_batches: break # callbacks if test_mode: self.on_test_batch_start(batch, batch_idx, dataloader_idx) if self.is_overridden('on_test_batch_start'): model_ref = self.get_model() with self.profiler.profile('on_test_batch_start'): model_ref.on_test_batch_start(output) else: self.on_validation_batch_start(batch, batch_idx, dataloader_idx) if self.is_overridden('on_validation_batch_start'): model_ref = self.get_model() with self.profiler.profile('on_validation_batch_start'): model_ref.on_validation_batch_start(output) # ----------------- # RUN EVALUATION STEP # ----------------- if self.amp_backend == AMPType.NATIVE and not self.use_tpu: with torch.cuda.amp.autocast(): output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) else: output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode) is_result_obj = isinstance(output, Result) # track batch size for weighted average if is_result_obj: output.track_batch_size(len(batch)) # allow only EvalResult when using structured results (from val_step) if is_result_obj and not isinstance(output, EvalResult): m = 'only EvalResults or dicts are allowed from validation_step' raise MisconfigurationException(m) # ------------------ # EVAL STEP END # ------------------ # on dp / ddp2 might still want to do something with the batch parts eval_step_end_hook_name = 'test_step_end' if test_mode else 'validation_step_end' if self.is_overridden(eval_step_end_hook_name): model_ref = self.get_model() with self.profiler.profile(eval_step_end_hook_name): eval_step_end = getattr(model_ref, eval_step_end_hook_name) output = eval_step_end(output) elif is_result_obj and (self.use_dp or self.use_ddp2): # result auto reduce output.dp_reduce() # callbacks (on __batch_end) if test_mode: self.on_test_batch_end(batch, batch_idx, dataloader_idx) if self.is_overridden('on_test_batch_end'): model_ref = self.get_model() with self.profiler.profile('on_test_batch_end'): model_ref.on_test_batch_end(output) else: self.on_validation_batch_end(batch, batch_idx, dataloader_idx) if self.is_overridden('on_validation_batch_end'): model_ref = self.get_model() with self.profiler.profile('on_validation_batch_end'): model_ref.on_validation_batch_end(output) # track outputs for collation if output is not None: # Add step predictions to prediction collection to write later do_write_predictions = is_result_obj and test_mode if do_write_predictions: predictions.add(output.pop('predictions', None)) dl_outputs.append(output) self.__eval_add_step_metrics(output, batch_idx) # track debug metrics self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output) outputs.append(dl_outputs) # --------------------- # EVAL_EPOCH_END # --------------------- using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult) eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result) # log callback metrics self.__update_callback_metrics(eval_results, using_eval_result) # Write predictions to disk if they're available. predictions.to_disk() # enable train mode again model.train() # enable gradients to save memory torch.set_grad_enabled(True) # -------------------------- # ON_EVAL_EPOCH_END hook # -------------------------- self.__call_eval_loop_hook_end(test_mode) return eval_results def __update_callback_metrics(self, eval_results, using_eval_result): if using_eval_result: if isinstance(eval_results, list): for eval_result in eval_results: self.callback_metrics = eval_result.callback_metrics else: self.callback_metrics = eval_results.callback_metrics else: if isinstance(eval_results, list): for eval_result in eval_results: # with a scalar return, auto set it to "val_loss" for callbacks if isinstance(eval_result, torch.Tensor): flat = {'val_loss': eval_result} else: flat = flatten_dict(eval_result) self.callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks if isinstance(eval_results, torch.Tensor): flat = {'val_loss': eval_results} else: flat = flatten_dict(eval_results) self.callback_metrics.update(flat) def __run_eval_epoch_end(self, test_mode, outputs, dataloaders, using_eval_result): model = self.get_model() # with a single dataloader don't pass an array eval_results = outputs if len(dataloaders) == 1: eval_results = outputs[0] user_reduced = False if test_mode: if self.is_overridden('test_end', model=model): # TODO: remove in v1.0.0 if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) eval_results = model.test_end(eval_results) user_reduced = True rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.' ' Use `test_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('test_epoch_end', model=model): if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) eval_results = model.test_epoch_end(eval_results) user_reduced = True else: if self.is_overridden('validation_end', model=model): # TODO: remove in v1.0.0 if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) eval_results = model.validation_end(eval_results) user_reduced = True rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.' ' Use `validation_epoch_end` instead.', DeprecationWarning) elif self.is_overridden('validation_epoch_end', model=model): if using_eval_result: eval_results = self.__gather_epoch_end_eval_results(outputs) eval_results = model.validation_epoch_end(eval_results) user_reduced = True if using_eval_result and not user_reduced: eval_results = self.__auto_reduce_result_objs(outputs) if not isinstance(eval_results, list): eval_results = [eval_results] return eval_results def __gather_epoch_end_eval_results(self, outputs): eval_results = [] for epoch_output in outputs: result = epoch_output[0].__class__.gather(epoch_output) if 'checkpoint_on' in result: result.checkpoint_on = result.checkpoint_on.mean() if 'early_stop_on' in result: result.early_stop_on = result.early_stop_on.mean() eval_results.append(result) # with 1 dataloader don't pass in a list if len(eval_results) == 1: eval_results = eval_results[0] return eval_results def __eval_add_step_metrics(self, output, batch_idx): # track step level metrics if isinstance(output, EvalResult) and not self.running_sanity_check: step_log_metrics = output.batch_log_metrics step_pbar_metrics = output.batch_pbar_metrics if len(step_log_metrics) > 0: # make the metrics appear as a different line in the same graph metrics_by_epoch = {} for k, v in step_log_metrics.items(): metrics_by_epoch[f'{k}/epoch_{self.current_epoch}'] = v self.log_metrics(metrics_by_epoch, {}, step=batch_idx) if len(step_pbar_metrics) > 0: self.add_progress_bar_metrics(step_pbar_metrics) def __auto_reduce_result_objs(self, outputs): # outputs has a list of results per dataloader eval_results = [] for dl_output in outputs: result = dl_output[0] result = result.__class__.reduce_on_epoch_end(dl_output) if 'checkpoint_on' in result: result.checkpoint_on = result.checkpoint_on.mean() if 'early_stop_on' in result: result.early_stop_on = result.early_stop_on.mean() eval_results.append(result) return eval_results def run_evaluation(self, test_mode: bool = False): # hook model = self.get_model() model.on_pre_performance_check() # select dataloaders if test_mode: self.reset_test_dataloader(model) dataloaders = self.test_dataloaders max_batches = self.num_test_batches else: # val if self.val_dataloaders is None: self.reset_val_dataloader(model) dataloaders = self.val_dataloaders max_batches = self.num_val_batches if dataloaders is None: return [], [] # Validation/Test begin callbacks if test_mode: self.on_test_start() else: self.on_validation_start() # enable disabling validation step with limit_val_batches = 0 should_skip = sum(max_batches) == 0 if should_skip: return [], [] # run evaluation (val_step + val_step_end + val_epoch_end) eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode) # log the final eval loop metrics eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode) # hook model.on_post_performance_check() # eventual dataset reloading if test_mode: if self.reload_dataloaders_every_epoch: self.reset_test_dataloader(model) else: # val if self.reload_dataloaders_every_epoch: self.reset_val_dataloader(model) # Validation/Test end callbacks if test_mode: self.on_test_end() else: self.on_validation_end() return eval_loop_results, eval_results def __log_evaluation_epoch_metrics(self, eval_results, test_mode): eval_loop_results = [] if eval_results is not None and len(eval_results) > 0: # in eval, the user may return something at every validation step without final reduction if not isinstance(eval_results, list): eval_results = [eval_results] for result_idx, result in enumerate(eval_results): if isinstance(result, EvalResult): prog_bar_metrics = result.epoch_pbar_metrics log_metrics = result.epoch_log_metrics callback_metrics = result.callback_metrics # in testing we don't need the callback metrics if test_mode: callback_metrics = {} else: _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(result) # eval loop returns all metrics dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics} # add metrics to prog bar self.add_progress_bar_metrics(prog_bar_metrics) # log metrics self.log_metrics(log_metrics, {}) # track metrics for callbacks self.callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: eval_loop_results.append(dataloader_result_metrics) # log results of test if test_mode and self.is_global_zero and self.verbose_test: print('-' * 80) for result_idx, results in enumerate(eval_loop_results): print(f'DATALOADER:{result_idx} TEST RESULTS') pprint(results) print('-' * 80) return eval_loop_results def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False): # make dataloader_idx arg in validation_step optional args = [batch, batch_idx] if (test_mode and len(self.test_dataloaders) > 1) \ or (not test_mode and len(self.val_dataloaders) > 1): args.append(dataloader_idx) # handle DP, DDP forward if self.use_ddp or self.use_dp or self.use_ddp2: output = model(*args) return output # Horovod if self.use_horovod and self.on_gpu: batch = self.transfer_batch_to_gpu(batch, hvd.local_rank()) args[0] = batch # single GPU data transfer if self.use_single_gpu: # for single GPU put inputs on gpu manually root_gpu = 0 if isinstance(self.data_parallel_device_ids, list): root_gpu = self.data_parallel_device_ids[0] batch = self.transfer_batch_to_gpu(batch, root_gpu) args[0] = batch # TPU data transfer if self.use_tpu: batch = self.transfer_batch_to_tpu(batch, self.tpu_id) args[0] = batch # CPU, TPU or gpu step if test_mode: output = model.test_step(*args) else: output = model.validation_step(*args) return output