diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py
new file mode 100644
index 0000000000..d1d8977843
--- /dev/null
+++ b/pytorch_lightning/trainer/evaluate_loop.py
@@ -0,0 +1,76 @@
+import torch
+from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.core.step_result import EvalResult
+
+
+class EvaluationLoop(object):
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.testing = False
+        self.outputs = []
+        self.predictions = None
+        self.max_batches = None
+
+    def is_using_eval_results(self):
+        outputs = self.outputs
+        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
+        return using_eval_result
+
+    def setup(self, model, max_batches, dataloaders):
+        # enable eval mode
+        model.zero_grad()
+        model.eval()
+
+        # copy properties for forward overrides
+        self.trainer.copy_trainer_model_properties(model)
+
+        # disable gradients to save memory
+        torch.set_grad_enabled(False)
+
+        # bookkeeping
+        self.outputs = []
+        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
+
+        # convert max_batches to list
+        if isinstance(max_batches, int):
+            max_batches = [max_batches] * len(dataloaders)
+
+        self.max_batches = max_batches
+
+    def on_evaluation_epoch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
+
+    def evaluation_step(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+        else:
+            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+        return output
+
+    def evaluation_step_end(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
+        else:
+            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
+        return output
+
+    def on_evaluation_batch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)
+
+    def on_evaluation_batch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
+
+    def on_evaluation_epoch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py
index 8843e03917..e7dca236b7 100644
--- a/pytorch_lightning/trainer/evaluation_loop.py
+++ b/pytorch_lightning/trainer/evaluation_loop.py
@@ -134,7 +134,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 
 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -192,6 +192,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_test_start: Callable
     on_test_end: Callable
     accelerator_backend: ...
+    evaluation_loop: EvaluationLoop
 
     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -245,31 +246,14 @@ class TrainerEvaluationLoopMixin(ABC):
                 entry is the number of batches to process in the corresponding dataloader.
             test_mode:
         """
-        # enable eval mode
-        model.zero_grad()
-        model.eval()
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode
 
-        # copy properties for forward overrides
-        self.copy_trainer_model_properties(model)
+        # set up the eval loop
+        self.evaluation_loop.setup(model, max_batches, dataloaders)
 
-        # disable gradients to save memory
-        torch.set_grad_enabled(False)
-
-        # bookkeeping
-        outputs = []
-        predictions = PredictionCollection(self.global_rank, self.world_size)
-
-        # convert max_batches to list
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(dataloaders)
-
-        # --------------------------
-        # ON_EVAL_EPOCH_START hook
-        # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_start')
-        else:
-            self.call_hook('on_validation_epoch_start')
+        # hook
+        self.evaluation_loop.on_evaluation_epoch_start()
 
         # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
@@ -282,7 +266,7 @@ class TrainerEvaluationLoopMixin(ABC):
                 dataloader = dataloader.per_device_loader(device)
 
             # each dataloader has a max num batches
-            dl_max_batches = max_batches[dataloader_idx]
+            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
 
             for batch_idx, batch in enumerate(dataloader):
                 if batch is None:
@@ -292,25 +276,19 @@ class TrainerEvaluationLoopMixin(ABC):
                 if batch_idx >= dl_max_batches:
                     break
 
-                # callbacks
-                if test_mode:
-                    self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
+                # -----------------
+                # eval_batch_start
+                # -----------------
+                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
 
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
                 args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-
-                if test_mode:
-                    output = self.accelerator_backend.test_step(args)
-                else:
-                    output = self.accelerator_backend.validation_step(args)
-
-                is_result_obj = isinstance(output, Result)
+                output = self.evaluation_loop.evaluation_step(args)
 
                 # track batch size for weighted average
+                is_result_obj = isinstance(output, Result)
                 if is_result_obj:
                     output.track_batch_size(len(batch))
 
@@ -322,19 +300,12 @@ class TrainerEvaluationLoopMixin(ABC):
                 # ------------------
                 # EVAL STEP END
                 # ------------------
-                if test_mode:
-                    output = self.call_hook('test_step_end', output)
-                else:
-                    output = self.call_hook('validation_step_end', output)
+                output = self.evaluation_loop.evaluation_step_end(output)
 
                 # ------------------
                 # Hook: on_eval_batch_end
                 # ------------------
-                # callbacks (on __batch_end)
-                if test_mode:
-                    self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
+                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
 
                 # ----------------------
                 # Post processing
@@ -345,7 +316,7 @@ class TrainerEvaluationLoopMixin(ABC):
                 # Add step predictions to prediction collection to write later
                 do_write_predictions = is_result_obj and test_mode
                 if do_write_predictions:
-                    predictions.add(output.pop('predictions', None))
+                    self.evaluation_loop.predictions.add(output.pop('predictions', None))
 
                 dl_outputs.append(output)
 
@@ -354,19 +325,24 @@ class TrainerEvaluationLoopMixin(ABC):
                 # track debug metrics
                 self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
 
-            outputs.append(dl_outputs)
+            self.evaluation_loop.outputs.append(dl_outputs)
 
         # ---------------------
         # EVAL_EPOCH_END
         # ---------------------
-        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
-        eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)
+        using_eval_result = self.evaluation_loop.is_using_eval_results()
+        eval_results = self.__run_eval_epoch_end(
+            test_mode,
+            self.evaluation_loop.outputs,
+            dataloaders,
+            using_eval_result
+        )
 
         # log callback metrics
         self.__update_callback_metrics(eval_results, using_eval_result)
 
         # Write predictions to disk if they're available.
-        predictions.to_disk()
+        self.evaluation_loop.predictions.to_disk()
 
         # enable train mode again
         model.train()
@@ -377,10 +353,7 @@ class TrainerEvaluationLoopMixin(ABC):
         # --------------------------
         # ON_EVAL_EPOCH_END hook
         # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_end')
-        else:
-            self.call_hook('on_validation_epoch_end')
+        self.evaluation_loop.on_evaluation_epoch_end()
 
         return eval_results
 
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 75bafe6cf5..b0796f9c18 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -54,6 +54,7 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only,
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -608,6 +609,9 @@ class Trainer(
         self.config_validator = ConfigValidator(self)
         self.accelerator_backend = None
 
+        # loops
+        self.evaluation_loop = EvaluationLoop(self)
+
         # Callback system
         self.on_init_end()
 
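A minimal, self-contained sketch (not part of the patch above) of the delegation pattern the new EvaluationLoop relies on: a single `testing` flag selects between the test and validation hook names, and every call is forwarded back to the owning trainer. `MinimalTrainer` and `EvaluationLoopSketch` are hypothetical stand-ins used only to make the sketch runnable; the real classes live in pytorch_lightning/trainer/trainer.py and pytorch_lightning/trainer/evaluate_loop.py.

# illustration only -- not part of the pytorch_lightning package
class MinimalTrainer:
    def call_hook(self, hook_name, *args, **kwargs):
        # the real Trainer dispatches to callbacks and the LightningModule here
        print(f"called hook: {hook_name}")


class EvaluationLoopSketch:
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False  # False -> validation hooks, True -> test hooks

    def on_evaluation_epoch_start(self, *args, **kwargs):
        # one generic entry point; the flag decides which concrete hook fires
        if self.testing:
            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
        else:
            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)


if __name__ == '__main__':
    loop = EvaluationLoopSketch(MinimalTrainer())
    loop.on_evaluation_epoch_start()   # prints: called hook: on_validation_epoch_start
    loop.testing = True
    loop.on_evaluation_epoch_start()   # prints: called hook: on_test_epoch_start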