diff --git a/pytorch_lightning/trainer/evaluate_loop.py b/pytorch_lightning/trainer/evaluate_loop.py deleted file mode 100644 index 824cd36512..0000000000 --- a/pytorch_lightning/trainer/evaluate_loop.py +++ /dev/null @@ -1,264 +0,0 @@ -from pytorch_lightning.trainer.supporters import PredictionCollection -from pytorch_lightning.core.step_result import Result, EvalResult -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.model_utils import is_overridden - - -class EvaluationLoop(object): - def __init__(self, trainer): - self.trainer = trainer - self.testing = False - self.outputs = [] - self.predictions = None - self.max_batches = None - - def get_evaluation_dataloaders(self, max_batches): - # select dataloaders - model = self.trainer.get_model() - - # select dataloaders - if self.testing: - self.trainer.reset_test_dataloader(model) - - dataloaders = self.trainer.test_dataloaders - new_max_batches = self.trainer.num_test_batches - else: - # val - in_sanity_check = self.trainer.running_sanity_check - should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch - if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check: - self.trainer.reset_val_dataloader(model) - - dataloaders = self.trainer.val_dataloaders - new_max_batches = self.trainer.num_val_batches - - if max_batches is None: - max_batches = new_max_batches - - return dataloaders, max_batches - - def should_skip_evaluation(self, dataloaders, max_batches): - # skip when dataloaders aren't defined - if dataloaders is None: - return True - - # enable disabling validation step with limit_val_batches = 0 - should_skip = sum(max_batches) == 0 - if should_skip: - return True - - return False - - def on_evaluation_start(self, *args, **kwargs): - if self.testing: - self.trainer.call_hook('on_test_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_start', *args, **kwargs) - - def on_evaluation_end(self, *args, **kwargs): - if self.testing: - self.trainer.call_hook('on_test_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_end', *args, **kwargs) - - def reload_evaluation_dataloaders(self): - model = self.trainer.get_model() - if self.testing: - self.trainer.reset_test_dataloader(model) - else: - self.trainer.reset_val_dataloader(model) - - def is_using_eval_results(self): - outputs = self.outputs - using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult) - return using_eval_result - - def setup(self, model, max_batches, dataloaders): - # copy properties for forward overrides - self.trainer.model_connector.copy_trainer_model_properties(model) - - # bookkeeping - self.outputs = [] - self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - - # convert max_batches to list - if isinstance(max_batches, int): - max_batches = [max_batches] * len(dataloaders) - - self.max_batches = max_batches - - def on_evaluation_epoch_start(self, *args, **kwargs): - if self.testing: - self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - - def build_args(self, test_mode, batch, batch_idx, dataloader_idx): - # make dataloader_idx arg in validation_step optional - args = [batch, batch_idx] - - multiple_val_loaders = (not test_mode and len(self.trainer.val_dataloaders) > 1) - multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 
1) - - if multiple_test_loaders or multiple_val_loaders: - args.append(dataloader_idx) - - return args - - def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): - # configure args - args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) - - # run actual test step - if self.testing: - output = self.trainer.accelerator_backend.test_step(args) - else: - output = self.trainer.accelerator_backend.validation_step(args) - - # track batch size for weighted average - is_result_obj = isinstance(output, Result) - if is_result_obj: - output.track_batch_size(len(batch)) - - # allow only EvalResult when using structured results (from val_step) - if is_result_obj and not isinstance(output, EvalResult): - m = 'only EvalResults or dicts are allowed from validation_step' - raise MisconfigurationException(m) - - return output - - def evaluation_step_end(self, *args, **kwargs): - if self.testing: - output = self.trainer.call_hook('test_step_end', *args, **kwargs) - else: - output = self.trainer.call_hook('validation_step_end', *args, **kwargs) - return output - - def evaluation_epoch_end(self, num_dataloaders): - using_eval_result = self.is_using_eval_results() - - # call the model epoch end - eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) - return eval_results - - def log_epoch_metrics(self, eval_results, test_mode): - using_eval_result = self.is_using_eval_results() - eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( - eval_results, - using_eval_result, - test_mode - ) - return eval_loop_results - - def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): - model = self.trainer.get_model() - - # with a single dataloader don't pass an array - outputs = self.outputs - eval_results = outputs - if num_dataloaders == 1: - eval_results = outputs[0] - - user_reduced = False - - if self.testing: - if is_overridden('test_epoch_end', model=model): - if using_eval_result: - eval_results = self.__gather_epoch_end_eval_results(outputs) - - eval_results = model.test_epoch_end(eval_results) - user_reduced = True - - else: - if is_overridden('validation_epoch_end', model=model): - if using_eval_result: - eval_results = self.__gather_epoch_end_eval_results(outputs) - - eval_results = model.validation_epoch_end(eval_results) - user_reduced = True - - if using_eval_result and not user_reduced: - eval_results = self.__auto_reduce_result_objs(outputs) - - if not isinstance(eval_results, list): - eval_results = [eval_results] - - return eval_results - - def __gather_epoch_end_eval_results(self, outputs): - eval_results = [] - for epoch_output in outputs: - result = epoch_output[0].__class__.gather(epoch_output) - if 'checkpoint_on' in result: - result.checkpoint_on = result.checkpoint_on.mean() - if 'early_stop_on' in result: - result.early_stop_on = result.early_stop_on.mean() - - eval_results.append(result) - - # with 1 dataloader don't pass in a list - if len(eval_results) == 1: - eval_results = eval_results[0] - return eval_results - - def __auto_reduce_result_objs(self, outputs): - # outputs has a list of results per dataloader - eval_results = [] - for dl_output in outputs: - result = dl_output[0] - result = result.__class__.reduce_on_epoch_end(dl_output) - if 'checkpoint_on' in result: - result.checkpoint_on = result.checkpoint_on.mean() - if 'early_stop_on' in result: - result.early_stop_on = result.early_stop_on.mean() - eval_results.append(result) - - return eval_results - - def on_evaluation_batch_start(self, *args, 
**kwargs): - if self.testing: - self.trainer.call_hook('on_test_batch_start', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_batch_start', *args, **kwargs) - - def on_evaluation_batch_end(self, *args, **kwargs): - if self.testing: - self.trainer.call_hook('on_test_batch_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_batch_end', *args, **kwargs) - - def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx): - # Add step predictions to prediction collection to write later - if output is not None: - do_write_predictions = isinstance(output, Result) and self.testing - if do_write_predictions: - self.predictions.add(output.pop('predictions', None)) - - # track debug metrics - self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) - - def on_evaluation_epoch_end(self, *args, **kwargs): - # call the callback hook - if self.testing: - self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) - else: - self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) - - def log_step_metrics(self, output, batch_idx): - if self.trainer.running_sanity_check: - return - - if isinstance(output, EvalResult): - step_log_metrics = output.batch_log_metrics - step_pbar_metrics = output.batch_pbar_metrics - - if len(step_log_metrics) > 0: - # make the metrics appear as a different line in the same graph - metrics_by_epoch = {} - for k, v in step_log_metrics.items(): - metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v - - self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx) - - if len(step_pbar_metrics) > 0: - self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index c8b5ea3312..161c3ad599 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -1,264 +1,278 @@ -""" -Validation loop -=============== - -The lightning validation loop handles everything except the actual computations of your model. -To decide what will happen in your validation loop, define the `validation_step` function. -Below are all the things lightning automates for you in the validation loop. - -.. note:: Lightning will run 5 steps of validation in the beginning of training as a sanity - check so you don't have to wait until a full epoch to catch possible validation issues. - -Check validation every n epochs -------------------------------- - -If you have a small dataset you might want to check validation every n epochs - -.. code-block:: python - - # DEFAULT - trainer = Trainer(check_val_every_n_epoch=1) - -Set how much of the validation set to check -------------------------------------------- - -If you don't want to check 100% of the validation set (for debugging or if it's huge), set this flag. - -limit_val_batches will be overwritten by overfit_batches if `overfit_batches > 0` - -.. code-block:: python - - # DEFAULT - trainer = Trainer(limit_val_batches=1.0) - - # check 10% only - trainer = Trainer(limit_val_batches=0.1) - -Set how much of the test set to check -------------------------------------- - -If you don't want to check 100% of the test set (for debugging or if it's huge), set this flag. - -limit_test_batches will be overwritten by overfit_batches if `overfit_batches > 0` - -.. 
code-block:: python - - # DEFAULT - trainer = Trainer(limit_test_batches=1.0) - - # check 10% only - trainer = Trainer(limit_test_batches=0.1) - -Set validation check frequency within 1 training epoch ------------------------------------------------------- - -For large datasets it's often desirable to check validation multiple times within a training loop. - Pass in a float to check that often within 1 training epoch. - Pass in an int k to check every k training batches. Must use an int if using an IterableDataset. - -.. code-block:: python - - # DEFAULT - trainer = Trainer(val_check_interval=0.95) - - # check every .25 of an epoch - trainer = Trainer(val_check_interval=0.25) - - # check every 100 train batches (ie: for IterableDatasets or fixed frequency) - trainer = Trainer(val_check_interval=100) - - -Set the number of validation sanity steps ------------------------------------------ - -Lightning runs a few steps of validation in the beginning of training. -This avoids crashing in the validation loop sometime deep into a lengthy training loop. - -.. code-block:: python - - # DEFAULT - trainer = Trainer(num_sanity_val_steps=2) - - -You can use `Trainer(num_sanity_val_steps=0)` to skip the sanity check or `Trainer(num_sanity_val_steps=-1)` -to check all the validation data. - -# Testing loop - -To ensure you don't accidentally use test data to guide training decisions Lightning - makes running the test set deliberate. - -**test** - -You have two options to run the test set. -First case is where you test right after a full training routine. - -.. code-block:: python - - # run full training - trainer.fit(model) - - # run test set - trainer.test() - - -Second case is where you load a model and run the test set - -.. code-block:: python - - model = MyLightningModule.load_from_checkpoint( - checkpoint_path='/path/to/pytorch_checkpoint.ckpt', - hparams_file='/path/to/test_tube/experiment/version/hparams.yaml', - map_location=None - ) - - # init trainer with whatever options - trainer = Trainer(...) - - # test (pass in the model) - trainer.test(model) - -In this second case, the options you pass to trainer will be used when running - the test set (ie: 16-bit, dp, ddp, etc...) - -""" - -from abc import ABC, abstractmethod -from typing import Callable, List - -import torch -from torch.utils.data import DataLoader - -from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities import AMPType -from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop -from pytorch_lightning.trainer.logger_connector import LoggerConnector - - -class TrainerEvaluationLoopMixin(ABC): - - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - on_gpu: bool - use_ddp: bool - use_dp: bool - use_ddp2: bool - use_horovod: bool - use_single_gpu: bool - data_parallel_device_ids: ... - model: LightningModule - num_test_batches: List[int] - num_val_batches: int - world_size: int - fast_dev_run: ... - process_output: ... - progress_bar_dict: ... - global_rank: int - current_epoch: int - callback_metrics: ... - test_dataloaders: DataLoader - val_dataloaders: DataLoader - use_tpu: bool - reload_dataloaders_every_epoch: ... 
- tpu_id: int - verbose_test: bool - running_sanity_check: bool - amp_backend: AMPType - logger_connector: LoggerConnector - - # Callback system - on_validation_batch_start: Callable - on_validation_batch_end: Callable - on_test_batch_start: Callable - on_test_batch_end: Callable - on_validation_start: Callable - on_validation_end: Callable - on_test_start: Callable - on_test_end: Callable - accelerator_backend: ... - evaluation_loop: EvaluationLoop - - @abstractmethod - def get_model(self) -> LightningModule: - """Warning: this is just empty shell for code implemented in other class.""" - - @abstractmethod - def call_hook(self, hook_name, *args, **kwargs): - """Warning: this is just empty shell for code implemented in other class.""" - - def run_evaluation(self, test_mode: bool = False, max_batches=None): - # bookkeeping - self.evaluation_loop.testing = test_mode - dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) - if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): - return [], [] - - # enable eval mode + no grads - model = self.get_model() - model.zero_grad() - model.eval() - torch.set_grad_enabled(False) - - # hook - self.evaluation_loop.on_evaluation_start() - - # set up the eval loop - self.evaluation_loop.setup(model, max_batches, dataloaders) - - # hook - # TODO: should this be insider the dataloader loop? - self.evaluation_loop.on_evaluation_epoch_start() - - # run validation/testing - for dataloader_idx, dataloader in enumerate(dataloaders): - # bookkeeping - dl_outputs = [] - dataloader = self.accelerator_backend.process_dataloader(dataloader) - dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx] - - for batch_idx, batch in enumerate(dataloader): - if batch is None: - continue - - # stop short when running on limited batches - if batch_idx >= dl_max_batches: - break - - # hook - self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx) - - # lightning module methods - output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx) - output = self.evaluation_loop.evaluation_step_end(output) - - # hook - self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx) - - # clean up - self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx) - self.evaluation_loop.log_step_metrics(output, batch_idx) - - # track epoch level metrics - if output is not None: - dl_outputs.append(output) - - self.evaluation_loop.outputs.append(dl_outputs) - - # lightning module method - eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders)) +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
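[Review note] The `EvaluationLoop` implementation that follows keeps one code path for validation and testing: every hook call branches on `self.testing` and dispatches to either the `on_test_*` or `on_validation_*` variant. A minimal, hypothetical sketch of that dispatch pattern (standalone toy code, not the actual class):

```python
class EvalHookDispatcher:
    """Toy illustration of the test/validation dispatch used by EvaluationLoop."""

    def __init__(self, trainer, testing: bool = False):
        self.trainer = trainer
        self.testing = testing  # True -> test hooks, False -> validation hooks

    def call(self, suffix: str, *args, **kwargs):
        # e.g. suffix='epoch_start' -> 'on_test_epoch_start' or 'on_validation_epoch_start'
        prefix = 'on_test' if self.testing else 'on_validation'
        return self.trainer.call_hook(f'{prefix}_{suffix}', *args, **kwargs)
```

The real class below spells out each pair explicitly (`on_evaluation_start`, `on_evaluation_epoch_start`, and so on), which keeps the hook names greppable at the cost of some repetition.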
+ +from pytorch_lightning.trainer.supporters import PredictionCollection +from pytorch_lightning.core.step_result import Result, EvalResult +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_utils import is_overridden + + +class EvaluationLoop(object): + def __init__(self, trainer): + self.trainer = trainer + self.testing = False + self.outputs = [] + self.predictions = None + self.max_batches = None + + def get_evaluation_dataloaders(self, max_batches): + # select dataloaders + model = self.trainer.get_model() + + # select dataloaders + if self.testing: + self.trainer.reset_test_dataloader(model) + + dataloaders = self.trainer.test_dataloaders + new_max_batches = self.trainer.num_test_batches + else: + # val + in_sanity_check = self.trainer.running_sanity_check + should_reload_every_epoch = self.trainer.reload_dataloaders_every_epoch + if (self.trainer.val_dataloaders is None or should_reload_every_epoch) and not in_sanity_check: + self.trainer.reset_val_dataloader(model) + + dataloaders = self.trainer.val_dataloaders + new_max_batches = self.trainer.num_val_batches + + if max_batches is None: + max_batches = new_max_batches + + return dataloaders, max_batches + + def should_skip_evaluation(self, dataloaders, max_batches): + # skip when dataloaders aren't defined + if dataloaders is None: + return True + + # enable disabling validation step with limit_val_batches = 0 + should_skip = sum(max_batches) == 0 + if should_skip: + return True + + return False + + def on_evaluation_start(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_start', *args, **kwargs) + + def on_evaluation_end(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_end', *args, **kwargs) + + def reload_evaluation_dataloaders(self): + model = self.trainer.get_model() + if self.testing: + self.trainer.reset_test_dataloader(model) + else: + self.trainer.reset_val_dataloader(model) + + def is_using_eval_results(self): + outputs = self.outputs + using_eval_result = len(outputs) > 0 and len(outputs[0]) > 0 and isinstance(outputs[0][0], EvalResult) + return using_eval_result + + def setup(self, model, max_batches, dataloaders): + # copy properties for forward overrides + self.trainer.model_connector.copy_trainer_model_properties(model) # bookkeeping - eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode) - self.evaluation_loop.predictions.to_disk() + self.outputs = [] + self.predictions = PredictionCollection(self.trainer.global_rank, self.trainer.world_size) - # hook - self.evaluation_loop.on_evaluation_epoch_end() + # convert max_batches to list + if isinstance(max_batches, int): + max_batches = [max_batches] * len(dataloaders) - # enable train mode again - model.train() - torch.set_grad_enabled(True) + self.max_batches = max_batches - # hook - self.evaluation_loop.on_evaluation_end() + def on_evaluation_epoch_start(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_epoch_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_start', *args, **kwargs) - return eval_loop_results, eval_results + def build_args(self, test_mode, batch, batch_idx, dataloader_idx): + # make dataloader_idx arg in validation_step optional + args = [batch, batch_idx] + + multiple_val_loaders = (not test_mode and 
len(self.trainer.val_dataloaders) > 1) + multiple_test_loaders = (test_mode and len(self.trainer.test_dataloaders) > 1) + + if multiple_test_loaders or multiple_val_loaders: + args.append(dataloader_idx) + + return args + + def evaluation_step(self, test_mode, batch, batch_idx, dataloader_idx): + # configure args + args = self.build_args(test_mode, batch, batch_idx, dataloader_idx) + + # run actual test step + if self.testing: + output = self.trainer.accelerator_backend.test_step(args) + else: + output = self.trainer.accelerator_backend.validation_step(args) + + # track batch size for weighted average + is_result_obj = isinstance(output, Result) + if is_result_obj: + output.track_batch_size(len(batch)) + + # allow only EvalResult when using structured results (from val_step) + if is_result_obj and not isinstance(output, EvalResult): + m = 'only EvalResults or dicts are allowed from validation_step' + raise MisconfigurationException(m) + + return output + + def evaluation_step_end(self, *args, **kwargs): + if self.testing: + output = self.trainer.call_hook('test_step_end', *args, **kwargs) + else: + output = self.trainer.call_hook('validation_step_end', *args, **kwargs) + return output + + def evaluation_epoch_end(self, num_dataloaders): + using_eval_result = self.is_using_eval_results() + + # call the model epoch end + eval_results = self.__run_eval_epoch_end(num_dataloaders, using_eval_result) + return eval_results + + def log_epoch_metrics(self, eval_results, test_mode): + using_eval_result = self.is_using_eval_results() + eval_loop_results = self.trainer.logger_connector.on_evaluation_epoch_end( + eval_results, + using_eval_result, + test_mode + ) + return eval_loop_results + + def __run_eval_epoch_end(self, num_dataloaders, using_eval_result): + model = self.trainer.get_model() + + # with a single dataloader don't pass an array + outputs = self.outputs + eval_results = outputs + if num_dataloaders == 1: + eval_results = outputs[0] + + user_reduced = False + + if self.testing: + if is_overridden('test_epoch_end', model=model): + if using_eval_result: + eval_results = self.__gather_epoch_end_eval_results(outputs) + + eval_results = model.test_epoch_end(eval_results) + user_reduced = True + + else: + if is_overridden('validation_epoch_end', model=model): + if using_eval_result: + eval_results = self.__gather_epoch_end_eval_results(outputs) + + eval_results = model.validation_epoch_end(eval_results) + user_reduced = True + + if using_eval_result and not user_reduced: + eval_results = self.__auto_reduce_result_objs(outputs) + + if not isinstance(eval_results, list): + eval_results = [eval_results] + + return eval_results + + def __gather_epoch_end_eval_results(self, outputs): + eval_results = [] + for epoch_output in outputs: + result = epoch_output[0].__class__.gather(epoch_output) + if 'checkpoint_on' in result: + result.checkpoint_on = result.checkpoint_on.mean() + if 'early_stop_on' in result: + result.early_stop_on = result.early_stop_on.mean() + + eval_results.append(result) + + # with 1 dataloader don't pass in a list + if len(eval_results) == 1: + eval_results = eval_results[0] + return eval_results + + def __auto_reduce_result_objs(self, outputs): + # outputs has a list of results per dataloader + eval_results = [] + for dl_output in outputs: + result = dl_output[0] + result = result.__class__.reduce_on_epoch_end(dl_output) + if 'checkpoint_on' in result: + result.checkpoint_on = result.checkpoint_on.mean() + if 'early_stop_on' in result: + result.early_stop_on = 
result.early_stop_on.mean() + eval_results.append(result) + + return eval_results + + def on_evaluation_batch_start(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_batch_start', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_batch_start', *args, **kwargs) + + def on_evaluation_batch_end(self, *args, **kwargs): + if self.testing: + self.trainer.call_hook('on_test_batch_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_batch_end', *args, **kwargs) + + def evaluation_batch_end_cleanup(self, output, batch_idx, dataloader_idx): + # Add step predictions to prediction collection to write later + if output is not None: + do_write_predictions = isinstance(output, Result) and self.testing + if do_write_predictions: + self.predictions.add(output.pop('predictions', None)) + + # track debug metrics + self.trainer.dev_debugger.track_eval_loss_history(self.testing, batch_idx, dataloader_idx, output) + + def on_evaluation_epoch_end(self, *args, **kwargs): + # call the callback hook + if self.testing: + self.trainer.call_hook('on_test_epoch_end', *args, **kwargs) + else: + self.trainer.call_hook('on_validation_epoch_end', *args, **kwargs) + + def log_step_metrics(self, output, batch_idx): + if self.trainer.running_sanity_check: + return + + if isinstance(output, EvalResult): + step_log_metrics = output.batch_log_metrics + step_pbar_metrics = output.batch_pbar_metrics + + if len(step_log_metrics) > 0: + # make the metrics appear as a different line in the same graph + metrics_by_epoch = {} + for k, v in step_log_metrics.items(): + metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v + + self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx) + + if len(step_pbar_metrics) > 0: + self.trainer.logger_connector.add_progress_bar_metrics(step_pbar_metrics) diff --git a/pytorch_lightning/trainer/auto_mix_precision.py b/pytorch_lightning/trainer/initializer.py similarity index 80% rename from pytorch_lightning/trainer/auto_mix_precision.py rename to pytorch_lightning/trainer/initializer.py index d93acff0c4..b2a39056e1 100644 --- a/pytorch_lightning/trainer/auto_mix_precision.py +++ b/pytorch_lightning/trainer/initializer.py @@ -12,22 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
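[Review note] In the hunk that follows, AMP setup moves out of the `TrainerAMPMixin` base class into a standalone `Initializer` that holds a reference to the trainer and writes `trainer.amp_backend` itself; the trainer then calls `self.initializer.init_amp(amp_backend)` instead of `self.init_amp(...)`. A hedged sketch of the resulting wiring, assuming this branch's `pytorch_lightning.trainer.initializer` module is importable (the shim class and its attribute set are illustrative, not the real Trainer):

```python
from pytorch_lightning.trainer.initializer import Initializer  # module added by this PR


class TrainerShim:
    """Minimal stand-in showing how Trainer delegates AMP setup after this change."""

    def __init__(self, precision: int = 32, amp_backend: str = 'native'):
        self.precision = precision   # Initializer reads trainer.precision
        self.amp_backend = None      # Initializer writes trainer.amp_backend
        self.initializer = Initializer(self)
        self.initializer.init_amp(amp_backend)

    @property
    def use_amp(self) -> bool:
        # the use_amp property moves from the mixin onto Trainer itself in this PR
        return self.precision == 16
```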
-from abc import ABC - from pytorch_lightning import _logger as log from pytorch_lightning.utilities import APEX_AVAILABLE, NATIVE_AMP_AVALAIBLE, rank_zero_warn, AMPType -class TrainerAMPMixin(ABC): +class Initializer: - # this is just a summary on variables used in this abstract class, - # the proper values/initialisation should be done in child class - precision: int + def __init__(self, trainer): + self.trainer = trainer + + def init_amp(self, amp_type: str): + assert self.trainer.precision in (16, 32), 'only 32 or 16 bit precision supported' + self.trainer.amp_backend = None + self._setup_amp_backend(amp_type) def _setup_amp_backend(self, amp_type: str): - if self.precision != 16: + if self.trainer.precision != 16: # no AMP requested, so we can leave now return + amp_type = amp_type.lower() assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}' if amp_type == 'native': @@ -38,20 +41,16 @@ class TrainerAMPMixin(ABC): amp_type = 'apex' else: log.info('Using native 16bit precision.') - self.amp_backend = AMPType.NATIVE + self.trainer.amp_backend = AMPType.NATIVE if amp_type == 'apex': if not APEX_AVAILABLE: rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.' ' Install apex first using this guide: https://github.com/NVIDIA/apex#linux') else: log.info('Using APEX 16bit precision.') - self.amp_backend = AMPType.APEX - if not self.amp_backend: + self.trainer.amp_backend = AMPType.APEX + if not self.trainer.amp_backend: raise ModuleNotFoundError( f'You have asked for AMP support {amp_type}, but there is no support on your side yet.' f' Consider installing torch >= 1.6 or NVIDIA Apex.' ) - - @property - def use_amp(self) -> bool: - return self.precision == 16 diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3f2ade1d30..6e5b3da085 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -29,7 +29,6 @@ from pytorch_lightning.core.memory import ModelSummary from pytorch_lightning.core.step_result import EvalResult from pytorch_lightning.loggers import LightningLoggerBase from pytorch_lightning.profiler import BaseProfiler, PassThroughProfiler, SimpleProfiler -from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import ConfigValidator @@ -37,7 +36,6 @@ from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_10 from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin from pytorch_lightning.utilities import device_parser -from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin from pytorch_lightning.trainer.logging import TrainerLoggingMixin from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin @@ -50,15 +48,16 @@ from pytorch_lightning.utilities import parsing, rank_zero_info, rank_zero_only, from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.cloud_io import get_filesystem -from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop +from pytorch_lightning.trainer.evaluation_loop 
import EvaluationLoop +from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.data_connector import DataConnector from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector from pytorch_lightning.trainer.logger_connector import LoggerConnector from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector -from pytorch_lightning.trainer.training_loop import TrainLoop from pytorch_lightning.trainer.model_connector import ModelConnector from pytorch_lightning import _logger as log from pytorch_lightning.tuner.tuning import Tuner +from pytorch_lightning.trainer.initializer import Initializer from pytorch_lightning.utilities.model_utils import is_overridden # warnings to ignore in trainer @@ -93,12 +92,10 @@ class Trainer( TrainerCallbackHookMixin, TrainerModelHooksMixin, TrainerOptimizersMixin, - TrainerAMPMixin, TrainerDDPMixin, TrainerLoggingMixin, TrainerTrainingTricksMixin, TrainerDataLoadingMixin, - TrainerEvaluationLoopMixin, TrainerCallbackConfigMixin, TrainerLRFinderMixin, TrainerDeprecatedAPITillVer0_10, @@ -380,6 +377,7 @@ class Trainer( self.accelerator_connector = AcceleratorConnector(self) self.logger_connector = LoggerConnector(self) self.model_connector = ModelConnector(self) + self.initializer = Initializer(self) self.tuner = Tuner(self) self.accelerator_backend = None @@ -615,13 +613,17 @@ class Trainer( self.scaler = None self.amp_level = amp_level - self.init_amp(amp_backend) + self.initializer.init_amp(amp_backend) self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE') # Callback system self.on_init_end() + @property + def use_amp(self) -> bool: + return self.precision == 16 + @property def callback_metrics(self): return self.logger_connector.callback_metrics @@ -1209,6 +1211,83 @@ class Trainer( # hook self.train_loop.on_train_end() + def run_evaluation(self, test_mode: bool = False, max_batches=None): + # bookkeeping + self.evaluation_loop.testing = test_mode + dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches) + if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches): + return [], [] + + # enable eval mode + no grads + model = self.get_model() + model.zero_grad() + model.eval() + torch.set_grad_enabled(False) + + # hook + self.evaluation_loop.on_evaluation_start() + + # set up the eval loop + self.evaluation_loop.setup(model, max_batches, dataloaders) + + # hook + # TODO: should this be insider the dataloader loop? 
+        self.evaluation_loop.on_evaluation_epoch_start()
+
+        # run validation/testing
+        for dataloader_idx, dataloader in enumerate(dataloaders):
+            # bookkeeping
+            dl_outputs = []
+            dataloader = self.accelerator_backend.process_dataloader(dataloader)
+            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]
+
+            for batch_idx, batch in enumerate(dataloader):
+                if batch is None:
+                    continue
+
+                # stop short when running on limited batches
+                if batch_idx >= dl_max_batches:
+                    break
+
+                # hook
+                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)
+
+                # lightning module methods
+                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
+                output = self.evaluation_loop.evaluation_step_end(output)
+
+                # hook
+                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)
+
+                # clean up
+                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
+                self.evaluation_loop.log_step_metrics(output, batch_idx)
+
+                # track epoch level metrics
+                if output is not None:
+                    dl_outputs.append(output)
+
+            self.evaluation_loop.outputs.append(dl_outputs)
+
+        # lightning module method
+        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))
+
+        # bookkeeping
+        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
+        self.evaluation_loop.predictions.to_disk()
+
+        # hook
+        self.evaluation_loop.on_evaluation_epoch_end()
+
+        # enable train mode again
+        model.train()
+        torch.set_grad_enabled(True)
+
+        # hook
+        self.evaluation_loop.on_evaluation_end()
+
+        return eval_loop_results, eval_results
+
     def run_test(self):
         # only load test dataloader for testing
         # self.reset_test_dataloader(ref_model)
@@ -1420,15 +1499,6 @@ class Trainer(
         return results
 
-    def barrier(self, name):
-        if self.use_ddp or self.use_ddp2:
-            pass
-            # torch_distrib.barrier()
-
-        if self.on_tpu and XLA_AVAILABLE:
-            # wait for all processes to catch up
-            torch_xla.core.xla_model.rendezvous(f'pl.Trainer.{name}')
-
     def call_setup_hook(self, model):
         # call setup after the ddp process has connected
         stage_name = 'test' if self.testing else 'fit'
@@ -1439,11 +1509,6 @@ class Trainer(
             self.setup(stage_name)
         model.setup(stage_name)
 
-    def init_amp(self, amp_type: str):
-        assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
-        self.amp_backend = None
-        self._setup_amp_backend(amp_type)
-
     def call_hook(self, hook_name, *args, **kwargs):
         # always profile hooks
         with self.profiler.profile(hook_name):
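[Review note] `run_evaluation` is now a Trainer method that drives `self.evaluation_loop` (eval mode plus `torch.set_grad_enabled(False)`, the per-dataloader batch loop, epoch-end reduction, and logging) and returns `(eval_loop_results, eval_results)`. It remains an internal entry point; user code keeps triggering it indirectly through `fit()` (validation) and `test()`, as the deleted module docstring described. A small usage sketch under that assumption (`MyLightningModule` and the checkpoint path are placeholders):

```python
import pytorch_lightning as pl

# Validation runs inside fit(); running the test set is a deliberate, separate call.
model = MyLightningModule()
trainer = pl.Trainer(limit_val_batches=0.1, num_sanity_val_steps=2)
trainer.fit(model)
trainer.test()  # evaluates with the trained weights

# Or: evaluate a saved checkpoint with a fresh Trainer.
model = MyLightningModule.load_from_checkpoint('/path/to/checkpoint.ckpt')
pl.Trainer().test(model)
```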