ref: add eval loop object to streamline eval loop (#3138)

* added eval loop
William Falcon 2020-08-24 21:27:11 -04:00 committed by GitHub
parent 82d1128966
commit 229b87655a
3 changed files with 108 additions and 55 deletions
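
The refactor pulls the repeated if test_mode / else hook dispatch out of TrainerEvaluationLoopMixin._evaluate and into a new EvaluationLoop object owned by the Trainer. The toy classes below only illustrate that delegation pattern and are not part of the commit; the real interfaces are in the diffs that follow.

class ToyEvaluationLoop:
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False  # set by the trainer before a test run

    def on_evaluation_epoch_start(self):
        # one call site instead of an if/else at every hook in _evaluate
        hook = 'on_test_epoch_start' if self.testing else 'on_validation_epoch_start'
        self.trainer.call_hook(hook)


class ToyTrainer:
    def __init__(self):
        self.evaluation_loop = ToyEvaluationLoop(self)

    def call_hook(self, name):
        print('hook:', name)


trainer = ToyTrainer()
trainer.evaluation_loop.on_evaluation_epoch_start()   # hook: on_validation_epoch_start
trainer.evaluation_loop.testing = True
trainer.evaluation_loop.on_evaluation_epoch_start()   # hook: on_test_epoch_start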

pytorch_lightning/trainer/evaluate_loop.py (new file)

@@ -0,0 +1,76 @@
+import torch
+
+from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.core.step_result import EvalResult
+
+
+class EvaluationLoop(object):
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.testing = False
+        self.outputs = []
+        self.predictions = None
+        self.max_batches = None
+
+    def is_using_eval_results(self):
+        outputs = self.outputs
+        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
+        return using_eval_result
+
+    def setup(self, model, max_batches, dataloaders):
+        # enable eval mode
+        model.zero_grad()
+        model.eval()
+
+        # copy properties for forward overrides
+        self.trainer.copy_trainer_model_properties(model)
+
+        # disable gradients to save memory
+        torch.set_grad_enabled(False)
+
+        # bookkeeping
+        self.outputs = []
+        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
+
+        # convert max_batches to list
+        if isinstance(max_batches, int):
+            max_batches = [max_batches] * len(dataloaders)
+        self.max_batches = max_batches
+
+    def on_evaluation_epoch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
+
+    def evaluation_step(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+        else:
+            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+        return output
+
+    def evaluation_step_end(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
+        else:
+            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
+        return output
+
+    def on_evaluation_batch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)
+
+    def on_evaluation_batch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
+
+    def on_evaluation_epoch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
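
The next diff rewires TrainerEvaluationLoopMixin._evaluate to drive this object. The call sequence it sets up looks roughly like the sketch below; this is a simplified outline that omits Result handling, TPU loaders and prediction writing, and it assumes trainer, model, dataloaders, max_batches and test_mode are already in scope.

# Simplified sketch of how _evaluate drives the new EvaluationLoop.
loop = trainer.evaluation_loop
loop.testing = test_mode                       # choose val vs test hooks
loop.setup(model, max_batches, dataloaders)    # eval mode, no-grad, bookkeeping
loop.on_evaluation_epoch_start()

for dataloader_idx, dataloader in enumerate(dataloaders):
    dl_outputs = []
    for batch_idx, batch in enumerate(dataloader):
        if batch_idx >= loop.max_batches[dataloader_idx]:
            break
        loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
        args = trainer.build_args(test_mode, batch, batch_idx, dataloader_idx)
        output = loop.evaluation_step(args)
        output = loop.evaluation_step_end(output)
        loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
        dl_outputs.append(output)
    loop.outputs.append(dl_outputs)

loop.on_evaluation_epoch_end()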

pytorch_lightning/trainer/evaluation_loop.py

@@ -134,7 +134,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 
 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -192,6 +192,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_test_start: Callable
     on_test_end: Callable
     accelerator_backend: ...
+    evaluation_loop: EvaluationLoop
 
     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -245,31 +246,14 @@ class TrainerEvaluationLoopMixin(ABC):
                 entry is the number of batches to process in the corresponding dataloader.
             test_mode:
         """
-        # enable eval mode
-        model.zero_grad()
-        model.eval()
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode
 
-        # copy properties for forward overrides
-        self.copy_trainer_model_properties(model)
+        # set up the eval loop
+        self.evaluation_loop.setup(model, max_batches, dataloaders)
 
-        # disable gradients to save memory
-        torch.set_grad_enabled(False)
-
-        # bookkeeping
-        outputs = []
-        predictions = PredictionCollection(self.global_rank, self.world_size)
-
-        # convert max_batches to list
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(dataloaders)
-
-        # --------------------------
-        # ON_EVAL_EPOCH_START hook
-        # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_start')
-        else:
-            self.call_hook('on_validation_epoch_start')
+        # hook
+        self.evaluation_loop.on_evaluation_epoch_start()
 
         # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
@@ -282,7 +266,7 @@ class TrainerEvaluationLoopMixin(ABC):
                 dataloader = dataloader.per_device_loader(device)
 
             # each dataloader has a max num batches
-            dl_max_batches = max_batches[dataloader_idx]
+            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
 
             for batch_idx, batch in enumerate(dataloader):
                 if batch is None:
@@ -292,25 +276,19 @@ class TrainerEvaluationLoopMixin(ABC):
                 if batch_idx >= dl_max_batches:
                     break
 
-                # callbacks
-                if test_mode:
-                    self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
+                # -----------------
+                # eval_batch_start
+                # -----------------
+                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
 
                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
                 args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-                if test_mode:
-                    output = self.accelerator_backend.test_step(args)
-                else:
-                    output = self.accelerator_backend.validation_step(args)
-                is_result_obj = isinstance(output, Result)
+                output = self.evaluation_loop.evaluation_step(args)
 
                 # track batch size for weighted average
+                is_result_obj = isinstance(output, Result)
                 if is_result_obj:
                     output.track_batch_size(len(batch))
@@ -322,19 +300,12 @@ class TrainerEvaluationLoopMixin(ABC):
                 # ------------------
                 # EVAL STEP END
                 # ------------------
-                if test_mode:
-                    output = self.call_hook('test_step_end', output)
-                else:
-                    output = self.call_hook('validation_step_end', output)
+                output = self.evaluation_loop.evaluation_step_end(output)
 
                 # ------------------
                 # Hook: on_eval_batch_end
                 # ------------------
-                # callbacks (on __batch_end)
-                if test_mode:
-                    self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
+                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
 
                 # ----------------------
                 # Post processing
@@ -345,7 +316,7 @@ class TrainerEvaluationLoopMixin(ABC):
 
                 # Add step predictions to prediction collection to write later
                 do_write_predictions = is_result_obj and test_mode
                 if do_write_predictions:
-                    predictions.add(output.pop('predictions', None))
+                    self.evaluation_loop.predictions.add(output.pop('predictions', None))
 
                 dl_outputs.append(output)
@@ -354,19 +325,24 @@ class TrainerEvaluationLoopMixin(ABC):
                 # track debug metrics
                 self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)
 
-            outputs.append(dl_outputs)
+            self.evaluation_loop.outputs.append(dl_outputs)
 
         # ---------------------
         # EVAL_EPOCH_END
         # ---------------------
-        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
-        eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)
+        using_eval_result = self.evaluation_loop.is_using_eval_results()
+        eval_results = self.__run_eval_epoch_end(
+            test_mode,
+            self.evaluation_loop.outputs,
+            dataloaders,
+            using_eval_result
+        )
 
         # log callback metrics
         self.__update_callback_metrics(eval_results, using_eval_result)
 
         # Write predictions to disk if they're available.
-        predictions.to_disk()
+        self.evaluation_loop.predictions.to_disk()
 
         # enable train mode again
         model.train()
@@ -377,10 +353,7 @@ class TrainerEvaluationLoopMixin(ABC):
 
         # --------------------------
         # ON_EVAL_EPOCH_END hook
         # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_end')
-        else:
-            self.call_hook('on_validation_epoch_end')
+        self.evaluation_loop.on_evaluation_epoch_end()
 
         return eval_results
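
Note that evaluation_loop.outputs is a list with one entry per dataloader, each entry holding that dataloader's step outputs, which is why is_using_eval_results() inspects outputs[0][0]. The snippet below is a toy illustration of that check, assuming EvalResult can be constructed without arguments as in this release.

from pytorch_lightning.core.step_result import EvalResult

outputs = []                                  # one entry per eval dataloader
dl_outputs = [EvalResult(), EvalResult()]     # step outputs for dataloader 0
outputs.append(dl_outputs)

# mirrors EvaluationLoop.is_using_eval_results()
using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
print(using_eval_result)                      # True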

pytorch_lightning/trainer/trainer.py

@@ -54,6 +54,7 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only,
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -608,6 +609,9 @@ class Trainer(
         self.config_validator = ConfigValidator(self)
         self.accelerator_backend = None
 
+        # loops
+        self.evaluation_loop = EvaluationLoop(self)
+
         # Callback system
         self.on_init_end()
 
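
From the user's side nothing changes: validation and test still run through the usual Trainer entry points, which now route through evaluation_loop internally. A usage sketch, assuming a LightningModule subclass MyModel that defines its own dataloaders (the Trainer arguments are ordinary ones, not specific to this commit):

from pytorch_lightning import Trainer

model = MyModel()                         # any LightningModule with validation/test steps
trainer = Trainer(max_epochs=1, limit_val_batches=2)

trainer.fit(model)        # validation batches now run through EvaluationLoop
trainer.test(model)       # test runs set evaluation_loop.testing = True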