expand eval loop out (#3165)

This commit is contained in:
William Falcon 2020-08-25 12:28:00 -04:00 committed by GitHub
parent 9adf7dfed0
commit a7705c8677
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 64 additions and 2 deletions

View File

@ -325,8 +325,70 @@ class TrainerEvaluationLoopMixin(ABC):
# TODO: deprecate
self.evaluation_loop.on_evaluation_start()
# run evaluation (val_step + val_step_end + val_epoch_end)
eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
# ------------------------------
# ------------------------------
# ------------------------------
# enable eval mode + no grads
model.zero_grad()
model.eval()
torch.set_grad_enabled(False)
# set up the eval loop
self.evaluation_loop.setup(model, max_batches, dataloaders)
# hook
self.evaluation_loop.on_evaluation_epoch_start()
# run validation/testing
for dataloader_idx, dataloader in enumerate(dataloaders):
dl_outputs = []
# certain accelerators need to process the dataloader
dataloader = self.accelerator_backend.process_dataloader(dataloader)
# each dataloader has a max num batches
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
for batch_idx, batch in enumerate(dataloader):
if batch is None:
continue
# stop short when running on limited batches
if batch_idx >= dl_max_batches:
break
# hook
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
# lightning module methods
output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
output = self.evaluation_loop.evaluation_step_end(output)
# hook
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
# clean up
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
self.evaluation_loop.log_step_metrics(output, batch_idx)
# track epoch level metrics
if output is not None:
dl_outputs.append(output)
self.evaluation_loop.outputs.append(dl_outputs)
# lightning module method
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
# hook
self.evaluation_loop.on_evaluation_epoch_end(eval_results)
# enable train mode again
model.train()
torch.set_grad_enabled(True)
# ------------------------------
# ------------------------------
# ------------------------------
# log the final eval loop metrics
eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)