ref: trainer 1/n (#3412)
* ref: moved eval loop 2/n * ref: moved eval loop 2/n * ref: trainer 1/n * ref: trainer 1/n * ref: trainer 1/n
This commit is contained in:
parent
cd40cb2fad
commit
10cf86b94b
|
@ -1,264 +0,0 @@
|
|||
from pytorch_lightning.trainer.supporters import PredictionCollection
|
||||
from pytorch_lightning.core.step_result import Result, EvalResult
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
|
||||
|
||||
class EvaluationLoop(object):
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
self.testing = False
|
||||
self.outputs = []
|
||||
self.predictions = None
|
||||
self.max_batches = None
|
||||
|
||||
def get_evaluation_dataloaders(self, max_batches):
|
||||
# select dataloaders
|
||||
model = self.trainer.get_model()
|
||||
|
||||
# select dataloaders
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
|
||||
dataloaders = self.trainer.test_dataloaders
|
||||
new_max_batches = self.trainer.num_test_batches
|
||||
else:
|
||||
# val
|
||||
in_sanity_check = self.trainer.running_sanity_check
|
||||
should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
|
||||
if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
dataloaders = self.trainer.val_dataloaders
|
||||
new_max_batches = self.trainer.num_val_batches
|
||||
|
||||
if max_batches is None:
|
||||
max_batches = new_max_batches
|
||||
|
||||
return dataloaders, max_batches
|
||||
|
||||
def should_skip_evaluation(self, dataloaders, max_batches):
|
||||
# skip when dataloaders aren't defined
|
||||
if dataloaders is None:
|
||||
return True
|
||||
|
||||
# enable disabling validation step with limit_val_batches = 0
|
||||
should_skip = sum(max_batches) == 0
|
||||
if should_skip:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def on_evaluation_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_end', *args, **kwargs)
|
||||
|
||||
def reload_evaluation_dataloaders(self):
|
||||
model = self.trainer.get_model()
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
else:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
def is_using_eval_results(self):
|
||||
outputs = self.outputs
|
||||
using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
|
||||
return using_eval_result
|
||||
|
||||
def setup(self, model, max_batches, dataloaders):
|
||||
# copy properties for forward overrides
|
||||
self.trainer.model_connector.copy_trainer_model_properties(model)
|
||||
|
||||
# bookkeeping
|
||||
self.outputs = []
|
||||
self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
|
||||
|
||||
# convert max_batches to list
|
||||
if isinstance(max_batches, int):
|
||||
max_batches = [max_batches] * len(dataloaders)
|
||||
|
||||
self.max_batches = max_batches
|
||||
|
||||
def on_evaluation_epoch_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
|
||||
|
||||
def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||
# make dataloader_idx arg in validation_step optional
|
||||
args = [batch, batch_idx]
|
||||
|
||||
multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
|
||||
multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
|
||||
|
||||
if multiple_test_loaders or multiple_val_loaders:
|
||||
args.append(dataloader_idx)
|
||||
|
||||
return args
|
||||
|
||||
def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||
# configure args
|
||||
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
|
||||
|
||||
# run actual test step
|
||||
if self.testing:
|
||||
output = self.trainer.accelerator_backend.test_step(args)
|
||||
else:
|
||||
output = self.trainer.accelerator_backend.validation_step(args)
|
||||
|
||||
# track batch size for weighted average
|
||||
is_result_obj = isinstance(output, Result)
|
||||
if is_result_obj:
|
||||
output.track_batch_size(len(batch))
|
||||
|
||||
# allow only EvalResult when using structured results (from val_step)
|
||||
if is_result_obj and not isinstance(output, EvalResult):
|
||||
m = 'only EvalResults or dicts are allowed from validation_step'
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
return output
|
||||
|
||||
def evaluation_step_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
output = self.trainer.call_hook('test_step_end', *args, **kwargs)
|
||||
else:
|
||||
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
|
||||
return output
|
||||
|
||||
def evaluation_epoch_end(self, num_dataloaders):
|
||||
using_eval_result = self.is_using_eval_results()
|
||||
|
||||
# call the model epoch end
|
||||
eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
|
||||
return eval_results
|
||||
|
||||
def log_epoch_metrics(self, eval_results, test_mode):
|
||||
using_eval_result = self.is_using_eval_results()
|
||||
eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
|
||||
eval_results,
|
||||
using_eval_result,
|
||||
test_mode
|
||||
)
|
||||
return eval_loop_results
|
||||
|
||||
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
|
||||
model = self.trainer.get_model()
|
||||
|
||||
# with a single dataloader don't pass an array
|
||||
outputs = self.outputs
|
||||
eval_results = outputs
|
||||
if num_dataloaders == 1:
|
||||
eval_results = outputs[0]
|
||||
|
||||
user_reduced = False
|
||||
|
||||
if self.testing:
|
||||
if is_overridden('test_epoch_end', model=model):
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
eval_results = model.test_epoch_end(eval_results)
|
||||
user_reduced = True
|
||||
|
||||
else:
|
||||
if is_overridden('validation_epoch_end', model=model):
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
eval_results = model.validation_epoch_end(eval_results)
|
||||
user_reduced = True
|
||||
|
||||
if using_eval_result and not user_reduced:
|
||||
eval_results = self.__auto_reduce_result_objs(outputs)
|
||||
|
||||
if not isinstance(eval_results, list):
|
||||
eval_results = [eval_results]
|
||||
|
||||
return eval_results
|
||||
|
||||
def __gather_epoch_end_eval_results(self, outputs):
|
||||
eval_results = []
|
||||
for epoch_output in outputs:
|
||||
result = epoch_output[0].__class__.gather(epoch_output)
|
||||
if 'checkpoint_on' in result:
|
||||
result.checkpoint_on = result.checkpoint_on.mean()
|
||||
if 'early_stop_on' in result:
|
||||
result.early_stop_on = result.early_stop_on.mean()
|
||||
|
||||
eval_results.append(result)
|
||||
|
||||
# with 1 dataloader don't pass in a list
|
||||
if len(eval_results) == 1:
|
||||
eval_results = eval_results[0]
|
||||
return eval_results
|
||||
|
||||
def __auto_reduce_result_objs(self, outputs):
|
||||
# outputs has a list of results per dataloader
|
||||
eval_results = []
|
||||
for dl_output in outputs:
|
||||
result = dl_output[0]
|
||||
result = result.__class__.reduce_on_epoch_end(dl_output)
|
||||
if 'checkpoint_on' in result:
|
||||
result.checkpoint_on = result.checkpoint_on.mean()
|
||||
if 'early_stop_on' in result:
|
||||
result.early_stop_on = result.early_stop_on.mean()
|
||||
eval_results.append(result)
|
||||
|
||||
return eval_results
|
||||
|
||||
def on_evaluation_batch_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_batch_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
|
||||
|
||||
def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
|
||||
# Add step predictions to prediction collection to write later
|
||||
if output is not None:
|
||||
do_write_predictions = isinstance(output, Result) and self.testing
|
||||
if do_write_predictions:
|
||||
self.predictions.add(output.pop('predictions', None))
|
||||
|
||||
# track debug metrics
|
||||
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
|
||||
|
||||
def on_evaluation_epoch_end(self, *args, **kwargs):
|
||||
# call the callback hook
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
||||
|
||||
def log_step_metrics(self, output, batch_idx):
|
||||
if self.trainer.running_sanity_check:
|
||||
return
|
||||
|
||||
if isinstance(output, EvalResult):
|
||||
step_log_metrics = output.batch_log_metrics
|
||||
step_pbar_metrics = output.batch_pbar_metrics
|
||||
|
||||
if len(step_log_metrics) > 0:
|
||||
# make the metrics appear as a different line in the same graph
|
||||
metrics_by_epoch = {}
|
||||
for k, v in step_log_metrics.items():
|
||||
metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
|
||||
|
||||
self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
|
||||
|
||||
if len(step_pbar_metrics) > 0:
|
||||
self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
|
|
@ -1,264 +1,278 @@
|
|||
"""
|
||||
Validation loop
|
||||
===============
|
||||
|
||||
The lightning validation loop handles everything except the actual computations of your model.
|
||||
To decide what will happen in your validation loop, define the `validation_step` function.
|
||||
Below are all the things lightning automates for you in the validation loop.
|
||||
|
||||
.. note:: Lightning will run 5 steps of validation in the beginning of training as a sanity
|
||||
check so you don't have to wait until a full epoch to catch possible validation issues.
|
||||
|
||||
Check validation every n epochs
|
||||
-------------------------------
|
||||
|
||||
If you have a small dataset you might want to check validation every n epochs
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT
|
||||
trainer = Trainer(check_val_every_n_epoch=1)
|
||||
|
||||
Set how much of the validation set to check
|
||||
-------------------------------------------
|
||||
|
||||
If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag.
|
||||
|
||||
limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT
|
||||
trainer = Trainer(limit_val_batches=1.0)
|
||||
|
||||
# check 10% only
|
||||
trainer = Trainer(limit_val_batches=0.1)
|
||||
|
||||
Set how much of the test set to check
|
||||
-------------------------------------
|
||||
|
||||
If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag.
|
||||
|
||||
limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0`
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT
|
||||
trainer = Trainer(limit_test_batches=1.0)
|
||||
|
||||
# check 10% only
|
||||
trainer = Trainer(limit_test_batches=0.1)
|
||||
|
||||
Set validation check frequency within 1 training epoch
|
||||
------------------------------------------------------
|
||||
|
||||
For large datasets it's often desirable to check validation multiple times within a training loop.
|
||||
Pass in a float to check that often within 1 training epoch.
|
||||
Pass in an int k to check every k training batches. Must use an int if using an IterableDataset.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT
|
||||
trainer = Trainer(val_check_interval=0.95)
|
||||
|
||||
# check every .25 of an epoch
|
||||
trainer = Trainer(val_check_interval=0.25)
|
||||
|
||||
# check every 100 train batches (ie: for IterableDatasets or fixed frequency)
|
||||
trainer = Trainer(val_check_interval=100)
|
||||
|
||||
|
||||
Set the number of validation sanity steps
|
||||
-----------------------------------------
|
||||
|
||||
Lightning runs a few steps of validation in the beginning of training.
|
||||
This avoids crashing in the validation loop sometime deep into a lengthy training loop.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# DEFAULT
|
||||
trainer = Trainer(num_sanity_val_steps=2)
|
||||
|
||||
|
||||
You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)`
|
||||
to check all the validation data.
|
||||
|
||||
# Testing loop
|
||||
|
||||
To ensure you don't accidentally use test data to guide training decisions Lightning
|
||||
makes running the test set deliberate.
|
||||
|
||||
**test**
|
||||
|
||||
You have two options to run the test set.
|
||||
First case is where you test right after a full training routine.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
# run full training
|
||||
trainer.fit(model)
|
||||
|
||||
# run test set
|
||||
trainer.test()
|
||||
|
||||
|
||||
Second case is where you load a model and run the test set
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
model = MyLightningModule.load_from_checkpoint(
|
||||
checkpoint_path='/path/to/pytorch_checkpoint.ckpt',
|
||||
hparams_file='/path/to/test_tube/experiment/version/hparams.yaml',
|
||||
map_location=None
|
||||
)
|
||||
|
||||
# init trainer with whatever options
|
||||
trainer = Trainer(...)
|
||||
|
||||
# test (pass in the model)
|
||||
trainer.test(model)
|
||||
|
||||
In this second case, the options you pass to trainer will be used when running
|
||||
the test set (ie: 16-bit, dp, ddp, etc...)
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, List
|
||||
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from pytorch_lightning.core.lightning import LightningModule
|
||||
from pytorch_lightning.utilities import AMPType
|
||||
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
|
||||
from pytorch_lightning.trainer.logger_connector import LoggerConnector
|
||||
|
||||
|
||||
class TrainerEvaluationLoopMixin(ABC):
|
||||
|
||||
# this is just a summary on variables used in this abstract class,
|
||||
# the proper values/initialisation should be done in child class
|
||||
on_gpu: bool
|
||||
use_ddp: bool
|
||||
use_dp: bool
|
||||
use_ddp2: bool
|
||||
use_horovod: bool
|
||||
use_single_gpu: bool
|
||||
data_parallel_device_ids: ...
|
||||
model: LightningModule
|
||||
num_test_batches: List[int]
|
||||
num_val_batches: int
|
||||
world_size: int
|
||||
fast_dev_run: ...
|
||||
process_output: ...
|
||||
progress_bar_dict: ...
|
||||
global_rank: int
|
||||
current_epoch: int
|
||||
callback_metrics: ...
|
||||
test_dataloaders: DataLoader
|
||||
val_dataloaders: DataLoader
|
||||
use_tpu: bool
|
||||
reload_dataloaders_every_epoch: ...
|
||||
tpu_id: int
|
||||
verbose_test: bool
|
||||
running_sanity_check: bool
|
||||
amp_backend: AMPType
|
||||
logger_connector: LoggerConnector
|
||||
|
||||
# Callback system
|
||||
on_validation_batch_start: Callable
|
||||
on_validation_batch_end: Callable
|
||||
on_test_batch_start: Callable
|
||||
on_test_batch_end: Callable
|
||||
on_validation_start: Callable
|
||||
on_validation_end: Callable
|
||||
on_test_start: Callable
|
||||
on_test_end: Callable
|
||||
accelerator_backend: ...
|
||||
evaluation_loop: EvaluationLoop
|
||||
|
||||
@abstractmethod
|
||||
def get_model(self) -> LightningModule:
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
@abstractmethod
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
"""Warning: this is just empty shell for code implemented in other class."""
|
||||
|
||||
def run_evaluation(self, test_mode: bool = False, max_batches=None):
|
||||
# bookkeeping
|
||||
self.evaluation_loop.testing = test_mode
|
||||
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
|
||||
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
|
||||
return [], []
|
||||
|
||||
# enable eval mode + no grads
|
||||
model = self.get_model()
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_start()
|
||||
|
||||
# set up the eval loop
|
||||
self.evaluation_loop.setup(model, max_batches, dataloaders)
|
||||
|
||||
# hook
|
||||
# TODO: should this be insider the dataloader loop?
|
||||
self.evaluation_loop.on_evaluation_epoch_start()
|
||||
|
||||
# run validation/testing
|
||||
for dataloader_idx, dataloader in enumerate(dataloaders):
|
||||
# bookkeeping
|
||||
dl_outputs = []
|
||||
dataloader = self.accelerator_backend.process_dataloader(dataloader)
|
||||
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
if batch is None:
|
||||
continue
|
||||
|
||||
# stop short when running on limited batches
|
||||
if batch_idx >= dl_max_batches:
|
||||
break
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
# lightning module methods
|
||||
output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
|
||||
output = self.evaluation_loop.evaluation_step_end(output)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
|
||||
|
||||
# clean up
|
||||
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
|
||||
self.evaluation_loop.log_step_metrics(output, batch_idx)
|
||||
|
||||
# track epoch level metrics
|
||||
if output is not None:
|
||||
dl_outputs.append(output)
|
||||
|
||||
self.evaluation_loop.outputs.append(dl_outputs)
|
||||
|
||||
# lightning module method
|
||||
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
|
||||
# Copyright The PyTorch Lightning team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pytorch_lightning.trainer.supporters import PredictionCollection
|
||||
from pytorch_lightning.core.step_result import Result, EvalResult
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
|
||||
|
||||
class EvaluationLoop(object):
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
self.testing = False
|
||||
self.outputs = []
|
||||
self.predictions = None
|
||||
self.max_batches = None
|
||||
|
||||
def get_evaluation_dataloaders(self, max_batches):
|
||||
# select dataloaders
|
||||
model = self.trainer.get_model()
|
||||
|
||||
# select dataloaders
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
|
||||
dataloaders = self.trainer.test_dataloaders
|
||||
new_max_batches = self.trainer.num_test_batches
|
||||
else:
|
||||
# val
|
||||
in_sanity_check = self.trainer.running_sanity_check
|
||||
should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch
|
||||
if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
dataloaders = self.trainer.val_dataloaders
|
||||
new_max_batches = self.trainer.num_val_batches
|
||||
|
||||
if max_batches is None:
|
||||
max_batches = new_max_batches
|
||||
|
||||
return dataloaders, max_batches
|
||||
|
||||
def should_skip_evaluation(self, dataloaders, max_batches):
|
||||
# skip when dataloaders aren't defined
|
||||
if dataloaders is None:
|
||||
return True
|
||||
|
||||
# enable disabling validation step with limit_val_batches = 0
|
||||
should_skip = sum(max_batches) == 0
|
||||
if should_skip:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def on_evaluation_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_end', *args, **kwargs)
|
||||
|
||||
def reload_evaluation_dataloaders(self):
|
||||
model = self.trainer.get_model()
|
||||
if self.testing:
|
||||
self.trainer.reset_test_dataloader(model)
|
||||
else:
|
||||
self.trainer.reset_val_dataloader(model)
|
||||
|
||||
def is_using_eval_results(self):
|
||||
outputs = self.outputs
|
||||
using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult)
|
||||
return using_eval_result
|
||||
|
||||
def setup(self, model, max_batches, dataloaders):
|
||||
# copy properties for forward overrides
|
||||
self.trainer.model_connector.copy_trainer_model_properties(model)
|
||||
|
||||
# bookkeeping
|
||||
eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
|
||||
self.evaluation_loop.predictions.to_disk()
|
||||
self.outputs = []
|
||||
self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
# convert max_batches to list
|
||||
if isinstance(max_batches, int):
|
||||
max_batches = [max_batches] * len(dataloaders)
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
torch.set_grad_enabled(True)
|
||||
self.max_batches = max_batches
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_end()
|
||||
def on_evaluation_epoch_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_epoch_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs)
|
||||
|
||||
return eval_loop_results, eval_results
|
||||
def build_args(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||
# make dataloader_idx arg in validation_step optional
|
||||
args = [batch, batch_idx]
|
||||
|
||||
multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1)
|
||||
multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1)
|
||||
|
||||
if multiple_test_loaders or multiple_val_loaders:
|
||||
args.append(dataloader_idx)
|
||||
|
||||
return args
|
||||
|
||||
def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx):
|
||||
# configure args
|
||||
args = self.build_args(test_mode, batch, batch_idx, dataloader_idx)
|
||||
|
||||
# run actual test step
|
||||
if self.testing:
|
||||
output = self.trainer.accelerator_backend.test_step(args)
|
||||
else:
|
||||
output = self.trainer.accelerator_backend.validation_step(args)
|
||||
|
||||
# track batch size for weighted average
|
||||
is_result_obj = isinstance(output, Result)
|
||||
if is_result_obj:
|
||||
output.track_batch_size(len(batch))
|
||||
|
||||
# allow only EvalResult when using structured results (from val_step)
|
||||
if is_result_obj and not isinstance(output, EvalResult):
|
||||
m = 'only EvalResults or dicts are allowed from validation_step'
|
||||
raise MisconfigurationException(m)
|
||||
|
||||
return output
|
||||
|
||||
def evaluation_step_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
output = self.trainer.call_hook('test_step_end', *args, **kwargs)
|
||||
else:
|
||||
output = self.trainer.call_hook('validation_step_end', *args, **kwargs)
|
||||
return output
|
||||
|
||||
def evaluation_epoch_end(self, num_dataloaders):
|
||||
using_eval_result = self.is_using_eval_results()
|
||||
|
||||
# call the model epoch end
|
||||
eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result)
|
||||
return eval_results
|
||||
|
||||
def log_epoch_metrics(self, eval_results, test_mode):
|
||||
using_eval_result = self.is_using_eval_results()
|
||||
eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end(
|
||||
eval_results,
|
||||
using_eval_result,
|
||||
test_mode
|
||||
)
|
||||
return eval_loop_results
|
||||
|
||||
def __run_eval_epoch_end(self, num_dataloaders, using_eval_result):
|
||||
model = self.trainer.get_model()
|
||||
|
||||
# with a single dataloader don't pass an array
|
||||
outputs = self.outputs
|
||||
eval_results = outputs
|
||||
if num_dataloaders == 1:
|
||||
eval_results = outputs[0]
|
||||
|
||||
user_reduced = False
|
||||
|
||||
if self.testing:
|
||||
if is_overridden('test_epoch_end', model=model):
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
eval_results = model.test_epoch_end(eval_results)
|
||||
user_reduced = True
|
||||
|
||||
else:
|
||||
if is_overridden('validation_epoch_end', model=model):
|
||||
if using_eval_result:
|
||||
eval_results = self.__gather_epoch_end_eval_results(outputs)
|
||||
|
||||
eval_results = model.validation_epoch_end(eval_results)
|
||||
user_reduced = True
|
||||
|
||||
if using_eval_result and not user_reduced:
|
||||
eval_results = self.__auto_reduce_result_objs(outputs)
|
||||
|
||||
if not isinstance(eval_results, list):
|
||||
eval_results = [eval_results]
|
||||
|
||||
return eval_results
|
||||
|
||||
def __gather_epoch_end_eval_results(self, outputs):
|
||||
eval_results = []
|
||||
for epoch_output in outputs:
|
||||
result = epoch_output[0].__class__.gather(epoch_output)
|
||||
if 'checkpoint_on' in result:
|
||||
result.checkpoint_on = result.checkpoint_on.mean()
|
||||
if 'early_stop_on' in result:
|
||||
result.early_stop_on = result.early_stop_on.mean()
|
||||
|
||||
eval_results.append(result)
|
||||
|
||||
# with 1 dataloader don't pass in a list
|
||||
if len(eval_results) == 1:
|
||||
eval_results = eval_results[0]
|
||||
return eval_results
|
||||
|
||||
def __auto_reduce_result_objs(self, outputs):
|
||||
# outputs has a list of results per dataloader
|
||||
eval_results = []
|
||||
for dl_output in outputs:
|
||||
result = dl_output[0]
|
||||
result = result.__class__.reduce_on_epoch_end(dl_output)
|
||||
if 'checkpoint_on' in result:
|
||||
result.checkpoint_on = result.checkpoint_on.mean()
|
||||
if 'early_stop_on' in result:
|
||||
result.early_stop_on = result.early_stop_on.mean()
|
||||
eval_results.append(result)
|
||||
|
||||
return eval_results
|
||||
|
||||
def on_evaluation_batch_start(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_batch_start', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_batch_start', *args, **kwargs)
|
||||
|
||||
def on_evaluation_batch_end(self, *args, **kwargs):
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_batch_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_batch_end', *args, **kwargs)
|
||||
|
||||
def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx):
|
||||
# Add step predictions to prediction collection to write later
|
||||
if output is not None:
|
||||
do_write_predictions = isinstance(output, Result) and self.testing
|
||||
if do_write_predictions:
|
||||
self.predictions.add(output.pop('predictions', None))
|
||||
|
||||
# track debug metrics
|
||||
self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output)
|
||||
|
||||
def on_evaluation_epoch_end(self, *args, **kwargs):
|
||||
# call the callback hook
|
||||
if self.testing:
|
||||
self.trainer.call_hook('on_test_epoch_end', *args, **kwargs)
|
||||
else:
|
||||
self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs)
|
||||
|
||||
def log_step_metrics(self, output, batch_idx):
|
||||
if self.trainer.running_sanity_check:
|
||||
return
|
||||
|
||||
if isinstance(output, EvalResult):
|
||||
step_log_metrics = output.batch_log_metrics
|
||||
step_pbar_metrics = output.batch_pbar_metrics
|
||||
|
||||
if len(step_log_metrics) > 0:
|
||||
# make the metrics appear as a different line in the same graph
|
||||
metrics_by_epoch = {}
|
||||
for k, v in step_log_metrics.items():
|
||||
metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
|
||||
|
||||
self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
|
||||
|
||||
if len(step_pbar_metrics) > 0:
|
||||
self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics)
|
||||
|
|
|
@ -12,22 +12,25 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC
|
||||
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType
|
||||
|
||||
|
||||
class TrainerAMPMixin(ABC):
|
||||
class Initializer:
|
||||
|
||||
# this is just a summary on variables used in this abstract class,
|
||||
# the proper values/initialisation should be done in child class
|
||||
precision: int
|
||||
def __init__(self, trainer):
|
||||
self.trainer = trainer
|
||||
|
||||
def init_amp(self, amp_type: str):
|
||||
assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported'
|
||||
self.trainer.amp_backend = None
|
||||
self._setup_amp_backend(amp_type)
|
||||
|
||||
def _setup_amp_backend(self, amp_type: str):
|
||||
if self.precision != 16:
|
||||
if self.trainer.precision != 16:
|
||||
# no AMP requested, so we can leave now
|
||||
return
|
||||
|
||||
amp_type = amp_type.lower()
|
||||
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
|
||||
if amp_type == 'native':
|
||||
|
@ -38,20 +41,16 @@ class TrainerAMPMixin(ABC):
|
|||
amp_type = 'apex'
|
||||
else:
|
||||
log.info('Using native 16bit precision.')
|
||||
self.amp_backend = AMPType.NATIVE
|
||||
self.trainer.amp_backend = AMPType.NATIVE
|
||||
if amp_type == 'apex':
|
||||
if not APEX_AVAILABLE:
|
||||
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
|
||||
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
|
||||
else:
|
||||
log.info('Using APEX 16bit precision.')
|
||||
self.amp_backend = AMPType.APEX
|
||||
if not self.amp_backend:
|
||||
self.trainer.amp_backend = AMPType.APEX
|
||||
if not self.trainer.amp_backend:
|
||||
raise ModuleNotFoundError(
|
||||
f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
|
||||
f' Consider installing torch >= 1.6 or NVIDIA Apex.'
|
||||
)
|
||||
|
||||
@property
|
||||
def use_amp(self) -> bool:
|
||||
return self.precision == 16
|
|
@ -29,7 +29,6 @@ from pytorch_lightning.core.memory import ModelSummary
|
|||
from pytorch_lightning.core.step_result import EvalResult
|
||||
from pytorch_lightning.loggers import LightningLoggerBase
|
||||
from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler
|
||||
from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
|
||||
from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
|
||||
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
|
||||
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
|
||||
|
@ -37,7 +36,6 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
|
|||
from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10
|
||||
from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
|
||||
from pytorch_lightning.utilities import device_parser
|
||||
from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
|
||||
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
|
||||
from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
|
||||
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
|
||||
|
@ -50,15 +48,16 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only,
|
|||
from pytorch_lightning.utilities.debugging import InternalDebugger
|
||||
from pytorch_lightning.utilities.exceptions import MisconfigurationException
|
||||
from pytorch_lightning.utilities.cloud_io import get_filesystem
|
||||
from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
|
||||
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
|
||||
from pytorch_lightning.trainer.training_loop import TrainLoop
|
||||
from pytorch_lightning.trainer.data_connector import DataConnector
|
||||
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
|
||||
from pytorch_lightning.trainer.logger_connector import LoggerConnector
|
||||
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
|
||||
from pytorch_lightning.trainer.training_loop import TrainLoop
|
||||
from pytorch_lightning.trainer.model_connector import ModelConnector
|
||||
from pytorch_lightning import _logger as log
|
||||
from pytorch_lightning.tuner.tuning import Tuner
|
||||
from pytorch_lightning.trainer.initializer import Initializer
|
||||
from pytorch_lightning.utilities.model_utils import is_overridden
|
||||
|
||||
# warnings to ignore in trainer
|
||||
|
@ -93,12 +92,10 @@ class Trainer(
|
|||
TrainerCallbackHookMixin,
|
||||
TrainerModelHooksMixin,
|
||||
TrainerOptimizersMixin,
|
||||
TrainerAMPMixin,
|
||||
TrainerDDPMixin,
|
||||
TrainerLoggingMixin,
|
||||
TrainerTrainingTricksMixin,
|
||||
TrainerDataLoadingMixin,
|
||||
TrainerEvaluationLoopMixin,
|
||||
TrainerCallbackConfigMixin,
|
||||
TrainerLRFinderMixin,
|
||||
TrainerDeprecatedAPITillVer0_10,
|
||||
|
@ -380,6 +377,7 @@ class Trainer(
|
|||
self.accelerator_connector = AcceleratorConnector(self)
|
||||
self.logger_connector = LoggerConnector(self)
|
||||
self.model_connector = ModelConnector(self)
|
||||
self.initializer = Initializer(self)
|
||||
self.tuner = Tuner(self)
|
||||
self.accelerator_backend = None
|
||||
|
||||
|
@ -615,13 +613,17 @@ class Trainer(
|
|||
self.scaler = None
|
||||
|
||||
self.amp_level = amp_level
|
||||
self.init_amp(amp_backend)
|
||||
self.initializer.init_amp(amp_backend)
|
||||
|
||||
self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')
|
||||
|
||||
# Callback system
|
||||
self.on_init_end()
|
||||
|
||||
@property
|
||||
def use_amp(self) -> bool:
|
||||
return self.precision == 16
|
||||
|
||||
@property
|
||||
def callback_metrics(self):
|
||||
return self.logger_connector.callback_metrics
|
||||
|
@ -1209,6 +1211,83 @@ class Trainer(
|
|||
# hook
|
||||
self.train_loop.on_train_end()
|
||||
|
||||
def run_evaluation(self, test_mode: bool = False, max_batches=None):
|
||||
# bookkeeping
|
||||
self.evaluation_loop.testing = test_mode
|
||||
dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
|
||||
if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
|
||||
return [], []
|
||||
|
||||
# enable eval mode + no grads
|
||||
model = self.get_model()
|
||||
model.zero_grad()
|
||||
model.eval()
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_start()
|
||||
|
||||
# set up the eval loop
|
||||
self.evaluation_loop.setup(model, max_batches, dataloaders)
|
||||
|
||||
# hook
|
||||
# TODO: should this be insider the dataloader loop?
|
||||
self.evaluation_loop.on_evaluation_epoch_start()
|
||||
|
||||
# run validation/testing
|
||||
for dataloader_idx, dataloader in enumerate(dataloaders):
|
||||
# bookkeeping
|
||||
dl_outputs = []
|
||||
dataloader = self.accelerator_backend.process_dataloader(dataloader)
|
||||
dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
|
||||
|
||||
for batch_idx, batch in enumerate(dataloader):
|
||||
if batch is None:
|
||||
continue
|
||||
|
||||
# stop short when running on limited batches
|
||||
if batch_idx >= dl_max_batches:
|
||||
break
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
|
||||
|
||||
# lightning module methods
|
||||
output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
|
||||
output = self.evaluation_loop.evaluation_step_end(output)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
|
||||
|
||||
# clean up
|
||||
self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
|
||||
self.evaluation_loop.log_step_metrics(output, batch_idx)
|
||||
|
||||
# track epoch level metrics
|
||||
if output is not None:
|
||||
dl_outputs.append(output)
|
||||
|
||||
self.evaluation_loop.outputs.append(dl_outputs)
|
||||
|
||||
# lightning module method
|
||||
eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
|
||||
|
||||
# bookkeeping
|
||||
eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
|
||||
self.evaluation_loop.predictions.to_disk()
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_epoch_end()
|
||||
|
||||
# enable train mode again
|
||||
model.train()
|
||||
torch.set_grad_enabled(True)
|
||||
|
||||
# hook
|
||||
self.evaluation_loop.on_evaluation_end()
|
||||
|
||||
return eval_loop_results, eval_results
|
||||
|
||||
def run_test(self):
|
||||
# only load test dataloader for testing
|
||||
# self.reset_test_dataloader(ref_model)
|
||||
|
@ -1420,15 +1499,6 @@ class Trainer(
|
|||
|
||||
return results
|
||||
|
||||
def barrier(self, name):
|
||||
if self.use_ddp or self.use_ddp2:
|
||||
pass
|
||||
# torch_distrib.barrier()
|
||||
|
||||
if self.on_tpu and XLA_AVAILABLE:
|
||||
# wait for all processes to catch up
|
||||
torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
|
||||
|
||||
def call_setup_hook(self, model):
|
||||
# call setup after the ddp process has connected
|
||||
stage_name = 'test' if self.testing else 'fit'
|
||||
|
@ -1439,11 +1509,6 @@ class Trainer(
|
|||
self.setup(stage_name)
|
||||
model.setup(stage_name)
|
||||
|
||||
def init_amp(self, amp_type: str):
|
||||
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
|
||||
self.amp_backend = None
|
||||
self._setup_amp_backend(amp_type)
|
||||
|
||||
def call_hook(self, hook_name, *args, **kwargs):
|
||||
# always profile hooks
|
||||
with self.profiler.profile(hook_name):
|
||||
|
|
Loading…
Reference in New Issue