ref: add eval loop object to streamline eval loop (#3138)
* added eval loop
parent 82d1128966
commit 229b87655a
pytorch_lightning/trainer/evaluate_loop.py
@@ -0,0 +1,76 @@
+import torch
+
+from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.core.step_result import EvalResult
+
+
+class EvaluationLoop(object):
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.testing = False
+        self.outputs = []
+        self.predictions = None
+        self.max_batches = None
+
+    def is_using_eval_results(self):
+        outputs = self.outputs
+        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
+        return using_eval_result
+
+    def setup(self, model, max_batches, dataloaders):
+        # enable eval mode
+        model.zero_grad()
+        model.eval()
+
+        # copy properties for forward overrides
+        self.trainer.copy_trainer_model_properties(model)
+
+        # disable gradients to save memory
+        torch.set_grad_enabled(False)
+
+        # bookkeeping
+        self.outputs = []
+        self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
+
+        # convert max_batches to list
+        if isinstance(max_batches, int):
+            max_batches = [max_batches] * len(dataloaders)
+
+        self.max_batches = max_batches
+
+    def on_evaluation_epoch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
+
+    def evaluation_step(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.accelerator_backend.test_step(*args, **kwargs)
+        else:
+            output = self.trainer.accelerator_backend.validation_step(*args, **kwargs)
+        return output
+
+    def evaluation_step_end(self, *args, **kwargs):
+        if self.testing:
+            output = self.trainer.call_hook('test_step_end', *args, **kwargs)
+        else:
+            output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
+        return output
+
+    def on_evaluation_batch_start(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)
+
+    def on_evaluation_batch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
+
+    def on_evaluation_epoch_end(self, *args, **kwargs):
+        if self.testing:
+            self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
+        else:
+            self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
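For context: the whole object is a dispatcher. Every method checks `self.testing` once and routes to the matching test or validation hook. A minimal, self-contained sketch of that pattern (the stub classes below are illustrative, not part of this commit):

class _TrainerStub:
    # hypothetical stand-in for the real Trainer's hook dispatcher
    def call_hook(self, name):
        print(f'hook: {name}')


class _LoopSketch:
    # miniature of EvaluationLoop's val/test routing
    def __init__(self, trainer):
        self.trainer = trainer
        self.testing = False

    def on_evaluation_epoch_start(self):
        hook = 'on_test_epoch_start' if self.testing else 'on_validation_epoch_start'
        self.trainer.call_hook(hook)


loop = _LoopSketch(_TrainerStub())
loop.on_evaluation_epoch_start()   # prints: hook: on_validation_epoch_start
loop.testing = True
loop.on_evaluation_epoch_start()   # prints: hook: on_test_epoch_start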
pytorch_lightning/trainer/evaluation_loop.py
@@ -134,7 +134,7 @@ from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn, flatten_dict, AMPType
 from pytorch_lightning.core.step_result import Result, EvalResult
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.trainer.supporters import PredictionCollection
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 try:
     import torch_xla.distributed.parallel_loader as xla_pl
@@ -192,6 +192,7 @@ class TrainerEvaluationLoopMixin(ABC):
     on_test_start: Callable
     on_test_end: Callable
     accelerator_backend: ...
+    evaluation_loop: EvaluationLoop

     @abstractmethod
     def copy_trainer_model_properties(self, *args):
@@ -245,31 +246,14 @@ class TrainerEvaluationLoopMixin(ABC):
             entry is the number of batches to process in the corresponding dataloader.
         test_mode:
         """
-        # enable eval mode
-        model.zero_grad()
-        model.eval()
+        # set up the loop for val/test
+        self.evaluation_loop.testing = test_mode

-        # copy properties for forward overrides
-        self.copy_trainer_model_properties(model)
+        # set up the eval loop
+        self.evaluation_loop.setup(model, max_batches, dataloaders)

-        # disable gradients to save memory
-        torch.set_grad_enabled(False)
-
-        # bookkeeping
-        outputs = []
-        predictions = PredictionCollection(self.global_rank, self.world_size)
-
-        # convert max_batches to list
-        if isinstance(max_batches, int):
-            max_batches = [max_batches] * len(dataloaders)
-
-        # --------------------------
-        # ON_EVAL_EPOCH_START hook
-        # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_start')
-        else:
-            self.call_hook('on_validation_epoch_start')
+        # hook
+        self.evaluation_loop.on_evaluation_epoch_start()

         # run validation
         for dataloader_idx, dataloader in enumerate(dataloaders):
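The `setup` call above also normalizes `max_batches` so the rest of the loop can index it per dataloader: a single int is broadcast into one limit per dataloader. A standalone sketch of that normalization (the helper name is illustrative, not from the diff):

def _normalize_max_batches(max_batches, num_dataloaders):
    # mirrors the isinstance check in EvaluationLoop.setup: a bare int
    # becomes one identical limit per dataloader
    if isinstance(max_batches, int):
        max_batches = [max_batches] * num_dataloaders
    return max_batches


print(_normalize_max_batches(5, 3))       # [5, 5, 5]
print(_normalize_max_batches([5, 2], 2))  # [5, 2] (already per-dataloader)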
@@ -282,7 +266,7 @@ class TrainerEvaluationLoopMixin(ABC):
                 dataloader = dataloader.per_device_loader(device)

             # each dataloader has a max num batches
-            dl_max_batches = max_batches[dataloader_idx]
+            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

             for batch_idx, batch in enumerate(dataloader):
                 if batch is None:
@@ -292,25 +276,19 @@ class TrainerEvaluationLoopMixin(ABC):
                 if batch_idx >= dl_max_batches:
                     break

-                # callbacks
-                if test_mode:
-                    self.call_hook('on_test_batch_start', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_start', batch, batch_idx, dataloader_idx)
+                # -----------------
+                # eval_batch_start
+                # -----------------
+                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                 # -----------------
                 # RUN EVALUATION STEP
                 # -----------------
                 args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
-
-                if test_mode:
-                    output = self.accelerator_backend.test_step(args)
-                else:
-                    output = self.accelerator_backend.validation_step(args)
-
-                is_result_obj = isinstance(output, Result)
+                output = self.evaluation_loop.evaluation_step(args)

                 # track batch size for weighted average
+                is_result_obj = isinstance(output, Result)
                 if is_result_obj:
                     output.track_batch_size(len(batch))
@@ -322,19 +300,12 @@ class TrainerEvaluationLoopMixin(ABC):
                 # ------------------
                 # EVAL STEP END
                 # ------------------
-                if test_mode:
-                    output = self.call_hook('test_step_end', output)
-                else:
-                    output = self.call_hook('validation_step_end', output)
+                output = self.evaluation_loop.evaluation_step_end(output)

                 # ------------------
                 # Hook: on_eval_batch_end
                 # ------------------
-                # callbacks (on __batch_end)
-                if test_mode:
-                    self.call_hook('on_test_batch_end', batch, batch_idx, dataloader_idx)
-                else:
-                    self.call_hook('on_validation_batch_end', batch, batch_idx, dataloader_idx)
+                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                 # ----------------------
                 # Post processing
@@ -345,7 +316,7 @@ class TrainerEvaluationLoopMixin(ABC):
                 # Add step predictions to prediction collection to write later
                 do_write_predictions = is_result_obj and test_mode
                 if do_write_predictions:
-                    predictions.add(output.pop('predictions', None))
+                    self.evaluation_loop.predictions.add(output.pop('predictions', None))

                 dl_outputs.append(output)
@@ -354,19 +325,24 @@ class TrainerEvaluationLoopMixin(ABC):
                 # track debug metrics
                 self.dev_debugger.track_eval_loss_history(test_mode, batch_idx, dataloader_idx, output)

-            outputs.append(dl_outputs)
+            self.evaluation_loop.outputs.append(dl_outputs)

         # ---------------------
         # EVAL_EPOCH_END
         # ---------------------
-        using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
-        eval_results = self.__run_eval_epoch_end(test_mode, outputs, dataloaders, using_eval_result)
+        using_eval_result = self.evaluation_loop.is_using_eval_results()
+        eval_results = self.__run_eval_epoch_end(
+            test_mode,
+            self.evaluation_loop.outputs,
+            dataloaders,
+            using_eval_result
+        )

         # log callback metrics
         self.__update_callback_metrics(eval_results, using_eval_result)

         # Write predictions to disk if they're available.
-        predictions.to_disk()
+        self.evaluation_loop.predictions.to_disk()

         # enable train mode again
         model.train()
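`evaluation_loop.outputs` ends up as one list per dataloader, each holding that dataloader's per-batch outputs, which is why `is_using_eval_results` can type-check `outputs[0][0]`. A sketch of the shape, with plain dicts standing in for step outputs:

outputs = []                                            # one slot per dataloader
for dataloader_idx in range(2):
    dl_outputs = [{'loss': 0.1 * b} for b in range(3)]  # per-batch outputs
    outputs.append(dl_outputs)

# same guard as is_using_eval_results, minus the EvalResult isinstance check
print(len(outputs) > 0 and len(outputs[0]) > 0)         # True
print(outputs[0][0])                                    # {'loss': 0.0}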
@@ -377,10 +353,7 @@ class TrainerEvaluationLoopMixin(ABC):
         # --------------------------
         # ON_EVAL_EPOCH_END hook
         # --------------------------
-        if test_mode:
-            self.call_hook('on_test_epoch_end')
-        else:
-            self.call_hook('on_validation_epoch_end')
+        self.evaluation_loop.on_evaluation_epoch_end()

         return eval_results
pytorch_lightning/trainer/trainer.py
@@ -54,6 +54,7 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only,
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
+from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -608,6 +609,9 @@ class Trainer(
         self.config_validator = ConfigValidator(self)
         self.accelerator_backend = None

+        # loops
+        self.evaluation_loop = EvaluationLoop(self)
+
         # Callback system
         self.on_init_end()
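With the loop constructed in `Trainer.__init__`, callers flip one flag instead of branching on test vs. validation at every hook site; a hypothetical illustration (not from the diff):

from pytorch_lightning import Trainer

trainer = Trainer()
# evaluation will now dispatch the test hooks rather than the validation
# ones, because every EvaluationLoop method checks this single flag
trainer.evaluation_loop.testing = True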