expand eval loop out (#3165)
parent 9adf7dfed0
commit a7705c8677
@@ -325,8 +325,70 @@ class TrainerEvaluationLoopMixin(ABC):
        # TODO: deprecate
        self.evaluation_loop.on_evaluation_start()

        # run evaluation (val_step + val_step_end + val_epoch_end)
        eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)

        # ------------------------------
        # ------------------------------
        # ------------------------------
        # enable eval mode + no grads
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            dl_outputs = []

            # certain accelerators need to process the dataloader
            dataloader = self.accelerator_backend.process_dataloader(dataloader)

            # each dataloader has a max num batches
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                # lightning module methods
                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
                self.evaluation_loop.log_step_metrics(output, batch_idx)

                # track epoch level metrics
                if output is not None:
                    dl_outputs.append(output)

            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

        # hook
        self.evaluation_loop.on_evaluation_epoch_end(eval_results)

        # enable train mode again
        model.train()
        torch.set_grad_enabled(True)
        # ------------------------------
        # ------------------------------
        # ------------------------------

        # log the final eval loop metrics
        eval_loop_results = self.__log_evaluation_epoch_metrics(eval_results, test_mode)
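
For orientation, the control flow added above boils down to the standalone sketch below. It is plain PyTorch with none of the Trainer or EvaluationLoop state; run_eval_loop, eval_step, and the per-dataloader max_batches list are illustrative names for this sketch, not part of the Lightning API. The shape is the same: switch the model to eval mode with gradients disabled, walk each dataloader up to its batch cap, collect per-batch outputs, then restore training state.

import torch


def run_eval_loop(model, dataloaders, max_batches, eval_step):
    # hypothetical helper mirroring the expanded loop: eval mode + no grads
    model.zero_grad()
    model.eval()
    torch.set_grad_enabled(False)

    outputs = []  # one list of step outputs per dataloader
    for dataloader_idx, dataloader in enumerate(dataloaders):
        dl_outputs = []
        dl_max_batches = max_batches[dataloader_idx]  # each dataloader has its own cap

        for batch_idx, batch in enumerate(dataloader):
            if batch is None:
                continue
            # stop short when running on limited batches
            if batch_idx >= dl_max_batches:
                break
            output = eval_step(batch, batch_idx, dataloader_idx)
            if output is not None:
                dl_outputs.append(output)

        outputs.append(dl_outputs)

    # enable train mode again
    model.train()
    torch.set_grad_enabled(True)
    return outputs

A toy call such as run_eval_loop(model, [loader], [2], lambda batch, i, d: model(batch).mean().item()) returns one list of step outputs per dataloader, mirroring how the expanded loop appends each dl_outputs list to self.evaluation_loop.outputs before evaluation_epoch_end runs.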