lightning/pytorch_lightning/trainer/trainer.py

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import warnings
from typing import Dict, Iterable, List, Optional, Union
import torch
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import Callback, EarlyStopping, ModelCheckpoint
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.trainer.evaluation_loop import EvaluationLoop
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.connectors.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.connectors.profiler_connector import ProfilerConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
from pytorch_lightning.trainer.properties import TrainerProperties
# warnings to ignore in trainer
warnings.filterwarnings(
    'ignore', message='torch.distributed.reduce_op is deprecated, ' 'please use torch.distributed.ReduceOp instead'
)
os.environ['PYTHONWARNINGS'] = 'ignore:semaphore_tracker:UserWarning'
try:
    from apex import amp
except ImportError:
    amp = None
class Trainer(
    TrainerProperties,
    TrainerCallbackHookMixin,
    TrainerModelHooksMixin,
    TrainerOptimizersMixin,
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
):
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: Union[ModelCheckpoint, bool] = True,
        early_stop_callback: Optional[Union[EarlyStopping, bool]] = False,
        callbacks: Optional[List[Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0,
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: int = 1,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: bool = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: int = 1000,
        min_epochs: int = 1,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        log_save_interval: int = 100,
        row_log_interval: int = 50,
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[str] = None,
        profiler: Optional[Union[BaseProfiler, bool]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        amp_backend: str = 'native',
        amp_level: str = 'O2',  # backward compatible, todo: remove in v1.0.0
        overfit_pct: Optional[float] = None,  # backward compatible, todo: remove in v1.0.0
    ):
        super().__init__()
        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self)
        self.optimizer_connector = OptimizerConnector(self)
        self.accelerator_connector = AcceleratorConnector(self)
        self.logger_connector = LoggerConnector(self)
        self.model_connector = ModelConnector(self)
        self.precision_connector = PrecisionConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.profile_connector = ProfilerConnector(self)
        self.checkpoint_connector = CheckpointConnector(self)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)
        # training state
        self.weights_summary = weights_summary
Jeremy Jordan <jtjordan@ncsu.edu> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
2020-06-29 01:36:46 +00:00
self.model = None
self.shown_warnings = set()
Continue Jeremy's early stopping PR #1504 (#2391) * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * cannot pass an int as default_save_path * refactor log message * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model 
checkpoint change * typo fix * fix tests * update train loop * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * fix test with new epoch indexing * fix progress bar totals * fix off by one error (see #2289) epoch starts at 0 now * added missing imports * fix hpc_save folderpath * fix formatting * fix tests * small fixes from a rebase * fix * tmpdir * tmpdir * tmpdir * wandb * fix merge conflict * add back evaluation after training * test_resume_early_stopping_from_checkpoint TODO * undo the horovod check * update changelog * remove a duplicate test from merge error * try fix dp_resume test * add the logger fix from master * try remove default_root_dir * try mocking numpy * try import numpy in docs test * fix wandb test * pep 8 fix * skip if no amp * dont mock when doctesting * install extra * fix the resume ES test * undo conf.py changes * revert remove comet pickle from test * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update weights_loading.rst * Update weights_loading.rst * Update weights_loading.rst * renamed flag * renamed flag * revert the None check in logger experiment name/version * add the old comments * _experiment * test chckpointing on DDP * skip the ddp test on windows * cloudpickle * renamed flag * renamed flag * parentheses for clarity * apply suggestion max epochs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: 
Jeremy Jordan <jtjordan@ncsu.edu> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
2020-06-29 01:36:46 +00:00
# init callbacks
self.callback_connector.on_trainer_init(
callbacks,
early_stop_callback,
checkpoint_callback,
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint
)
Continue Jeremy's early stopping PR #1504 (#2391) * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model checkpoint change * typo fix * fix tests * update train loop * cannot pass an int as default_save_path * refactor log message * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * add state_dict for early stopping * move best attr after monitor_op defined * improve early stopping and model checkpoint callbacks * fix formatting * fix attr init order * clean up setting of default_root_dir attr * logger needs default root dir set first * reorg trainer init * remove direct references to checkpoint callback * more fixes * more bugfixes * run callbacks at epoch end * update tests to use on epoch end * PR cleanup * address failing tests * refactor for homogeneity * fix merge conflict * separate tests * tests for early stopping bug regressions * small fixes * revert model 
checkpoint change * typo fix * fix tests * update train loop * fix test case * appease the linter * fix some doctests * move config to callback * fixes from rebase * fixes from rebase * chlog * docs * reformat * formatting * fix * fix * fixes from rebase * add new test for patience * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update pytorch_lightning/callbacks/model_checkpoint.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update tests/callbacks/test_early_stopping.py Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * fix formatting * remove enable_early_stop attribute * fix test with new epoch indexing * fix progress bar totals * fix off by one error (see #2289) epoch starts at 0 now * added missing imports * fix hpc_save folderpath * fix formatting * fix tests * small fixes from a rebase * fix * tmpdir * tmpdir * tmpdir * wandb * fix merge conflict * add back evaluation after training * test_resume_early_stopping_from_checkpoint TODO * undo the horovod check * update changelog * remove a duplicate test from merge error * try fix dp_resume test * add the logger fix from master * try remove default_root_dir * try mocking numpy * try import numpy in docs test * fix wandb test * pep 8 fix * skip if no amp * dont mock when doctesting * install extra * fix the resume ES test * undo conf.py changes * revert remove comet pickle from test * Update CHANGELOG.md Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> * Update weights_loading.rst * Update weights_loading.rst * Update weights_loading.rst * renamed flag * renamed flag * revert the None check in logger experiment name/version * add the old comments * _experiment * test chckpointing on DDP * skip the ddp test on windows * cloudpickle * renamed flag * renamed flag * parentheses for clarity * apply suggestion max epochs Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: 
Jeremy Jordan <jtjordan@ncsu.edu> Co-authored-by: Jirka <jirka@pytorchlightning.ai> Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: William Falcon <waf2107@columbia.edu>
2020-06-29 01:36:46 +00:00
# hook
self.on_init_start()
# init optimizer + lr scheduler related flags
self.optimizer_connector.on_trainer_init()
# init data flags
self.data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_epoch,
prepare_data_per_node
)
# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan
)
# init accelerator related flags
self.accelerator_connector.on_trainer_init(
num_processes,
tpu_cores,
distributed_backend,
auto_select_gpus,
gpus,
num_nodes,
log_gpu_memory,
sync_batchnorm,
benchmark,
replace_sampler_ddp,
deterministic
)
# init train loop related flags
self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.evaluation_loop.on_trainer_init()
# configure tuner
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)
2019-07-08 13:42:13 +00:00
# configure profiler
self.profile_connector.on_trainer_init(profiler)
# init logger flags
self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)
2019-09-06 04:29:38 +00:00
# init debugging flags
self.debugging_connector.on_init_start(
overfit_pct,
limit_train_batches,
limit_val_batches,
limit_test_batches,
val_check_interval,
overfit_batches,
fast_dev_run
)
2019-09-06 04:29:38 +00:00
# set precision
self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
# Callback system
self.on_init_end()

    def tune(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        # TODO: temporary, need to decide if tune or separate object

        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.data_connector.prepare_data(model)

        # Run auto batch size scaling
        if self.auto_scale_batch_size:
            if isinstance(self.auto_scale_batch_size, bool):
                self.auto_scale_batch_size = 'power'
            self.tuner.scale_batch_size(
                model,
                mode=self.auto_scale_batch_size,
                train_dataloader=train_dataloader,
                val_dataloaders=val_dataloaders,
                datamodule=datamodule,
            )
            model.logger = self.logger  # reset logger binding

        # Run learning rate finder:
        if self.auto_lr_find:
            self.tuner.internal_find_lr(self, model)
            model.logger = self.logger  # reset logger binding

    # -----------------------------
    # MODEL TRAINING
    # -----------------------------
    def fit(
        self,
        model: LightningModule,
        train_dataloader: Optional[DataLoader] = None,
        val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        datamodule: Optional[LightningDataModule] = None,
    ):
        self._state = TrainerState.RUNNING

        # setup data, etc...
        self.train_loop.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

        # hook
        self.call_hook('on_fit_start', model)

        # hook
        self.data_connector.prepare_data(model)

        # set testing if set in environ
        self.testing = os.environ.get('PL_TESTING_MODE', self.testing)

        # -------------------------
        # TRAIN
        # -------------------------
        self.accelerator_backend = self.accelerator_connector.select_accelerator()
        self.accelerator_backend.setup(model)
        results = self.accelerator_backend.train()
        self.accelerator_backend.teardown()

        # -------------------------
        # POST-Training
        # -------------------------
        # hook
        self.call_hook('on_fit_end')

        # hook
        self.teardown('fit')
        if self.is_function_implemented('teardown'):
            model.teardown('fit')

        # return 1 when finished
        # used for testing or when we need to know that training succeeded
        if self._state != TrainerState.INTERRUPTED:
            self._state = TrainerState.FINISHED
        return results or 1

    def train(self):
        self.run_sanity_check(self.get_model())

        # enable train mode
        model = self.get_model()
        model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        self.train_loop.reset_train_val_dataloaders(model)

        # hook
        self.train_loop.on_train_start()

        try:
            # run all epochs
            for epoch in range(self.current_epoch, self.max_epochs):

                # reset train dataloader
                if self.reload_dataloaders_every_epoch:
                    self.reset_train_dataloader(model)

                # hook
                self.train_loop.on_train_epoch_start(epoch)

                # run train epoch
                self.train_loop.run_training_epoch()

                if self.max_steps and self.max_steps <= self.global_step:

                    # hook
                    self.train_loop.on_train_end()
                    return

                # update LR schedulers
                self.optimizer_connector.update_learning_rates(interval='epoch')

                # early stopping
                met_min_epochs = epoch >= self.min_epochs - 1
                met_min_steps = self.global_step >= self.min_steps if self.min_steps else True

                if self.should_stop:
                    if (met_min_epochs and met_min_steps):
                        self.train_loop.on_train_end()
                        return
                    else:
                        log.info('Trainer was signaled to stop but required minimum epochs'
                                 f' ({self.min_epochs}) or minimum steps ({self.min_steps}) has'
                                 ' not been met. Training will continue...')

            # hook
            self.train_loop.on_train_end()

        except KeyboardInterrupt:
            rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')

            # user could press ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.interrupted = True
                self._state = TrainerState.INTERRUPTED
                self.on_keyboard_interrupt()

                # hook
                self.train_loop.on_train_end()

    def run_evaluation(self, test_mode: bool = False, max_batches=None):
        # bookkeeping
        self.evaluation_loop.testing = test_mode
        dataloaders, max_batches = self.evaluation_loop.get_evaluation_dataloaders(max_batches)
        if self.evaluation_loop.should_skip_evaluation(dataloaders, max_batches):
            return [], []

        # enable eval mode + no grads
        model = self.get_model()
        model.zero_grad()
        model.eval()
        torch.set_grad_enabled(False)

        # hook
        self.evaluation_loop.on_evaluation_start()

        # set up the eval loop
        self.evaluation_loop.setup(model, max_batches, dataloaders)

        # hook
        # TODO: should this be inside the dataloader loop?
        self.evaluation_loop.on_evaluation_epoch_start()

        # run validation/testing
        for dataloader_idx, dataloader in enumerate(dataloaders):
            # bookkeeping
            dl_outputs = []
            dataloader = self.accelerator_backend.process_dataloader(dataloader)
            dl_max_batches = self.evaluation_loop.max_batches[dataloader_idx]

            for batch_idx, batch in enumerate(dataloader):
                if batch is None:
                    continue

                # stop short when running on limited batches
                if batch_idx >= dl_max_batches:
                    break

                # hook
                self.evaluation_loop.on_evaluation_batch_start(batch, batch_idx, dataloader_idx)

                # lightning module methods
                output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
                output = self.evaluation_loop.evaluation_step_end(output)

                # hook
                self.evaluation_loop.on_evaluation_batch_end(batch, batch_idx, dataloader_idx)

                # clean up
                self.evaluation_loop.evaluation_batch_end_cleanup(output, batch_idx, dataloader_idx)
                self.evaluation_loop.log_step_metrics(output, batch_idx)

                # track epoch level metrics
                if output is not None:
                    dl_outputs.append(output)

            self.evaluation_loop.outputs.append(dl_outputs)

        # lightning module method
        eval_results = self.evaluation_loop.evaluation_epoch_end(num_dataloaders=len(dataloaders))

        # bookkeeping
        eval_loop_results = self.evaluation_loop.log_epoch_metrics(eval_results, test_mode)
        self.evaluation_loop.predictions.to_disk()

        # hook
        self.evaluation_loop.on_evaluation_epoch_end()

        # enable train mode again
        model.train()
        torch.set_grad_enabled(True)

        # hook
        self.evaluation_loop.on_evaluation_end()

        return eval_loop_results, eval_results

    def run_test(self):
        # only load test dataloader for testing
        # self.reset_test_dataloader(ref_model)
        eval_loop_results, _ = self.run_evaluation(test_mode=True)

        if len(eval_loop_results) == 0:
            return 1

        # remove the tensors from the eval results
        for i, result in enumerate(eval_loop_results):
            if isinstance(result, dict):
                for k, v in result.items():
                    if isinstance(v, torch.Tensor):
                        result[k] = v.cpu().item()

        return eval_loop_results
def run_sanity_check(self, ref_model):
using_val_step = ref_model.val_dataloader is not None and is_overridden('validation_step', ref_model)
should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0
        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            self.reset_val_dataloader(ref_model)
            self.num_sanity_val_batches = [
                min(self.num_sanity_val_steps, val_batches) for val_batches in self.num_val_batches
            ]
            # hook and callback
            self.running_sanity_check = True
            self.on_sanity_check_start()

            # run eval step
            _, eval_results = self.run_evaluation(test_mode=False, max_batches=self.num_sanity_val_batches)

            # allow no returns from eval
            if eval_results is not None and len(eval_results) > 0:
                # when we get a list back, use only the last item
                if isinstance(eval_results, list):
                    eval_results = eval_results[-1]

                if isinstance(eval_results, EvalResult):
                    callback_metrics = eval_results.callback_metrics
                else:
                    _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
                self.logger_connector.callback_metrics = callback_metrics
            self.on_sanity_check_end()
            self.running_sanity_check = False

    @trainer_state(entering=TrainerState.RUNNING, exiting=TrainerState.FINISHED)
    def test(
        self,
        model: Optional[LightningModule] = None,
        test_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
        ckpt_path: Optional[str] = 'best',
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
    ):
        # --------------------
        # SETUP HOOK
        # --------------------
        self.verbose_test = verbose
        if self.global_rank != 0:
            return

        # If you supply a datamodule you can't supply test_dataloaders
        if test_dataloaders and datamodule:
            raise MisconfigurationException(
                'You cannot pass test_dataloaders to trainer.test if you supply a datamodule'
            )

        # Attach datamodule to get setup/prepare_data added to model before the call to it below
        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')
        if model is not None:
            results = self.__test_given_model(model, test_dataloaders)
        else:
            results = self.__test_using_best_weights(ckpt_path, test_dataloaders)

        self.teardown('test')

        return results

    def __test_using_best_weights(self, ckpt_path, test_dataloaders):
        model = self.get_model()
        # if user requests the best checkpoint but we don't have it, error
        if ckpt_path == 'best' and self.checkpoint_callback.save_top_k <= 0:
            raise MisconfigurationException(
                'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
            )
        # load best weights
        if ckpt_path is not None:
            # ckpt_path is 'best' so load the best model
            if ckpt_path == 'best':
                ckpt_path = self.checkpoint_callback.best_model_path
            if len(ckpt_path) == 0:
                rank_zero_warn(
                    f'.test() found no path for the best weights, {ckpt_path}. Please '
                    f'specify a path for a checkpoint .test(ckpt_path=PATH)'
                )
                return {}
            ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
            model.load_state_dict(ckpt['state_dict'])
        # attach dataloaders
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)
        # run tests
        self.tested_ckpt_path = ckpt_path
        self.testing = True
        os.environ['PL_TESTING_MODE'] = '1'
        self.model = model
        results = self.fit(model)
        self.testing = False
        del os.environ['PL_TESTING_MODE']
        # teardown
        if self.is_function_implemented('teardown'):
            model_ref = self.get_model()
            model_ref.teardown('test')
        return results

    def __test_given_model(self, model, test_dataloaders):
        # attach data
        if test_dataloaders is not None:
            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

        # run test
        # sets up testing so we short circuit to eval
        self.testing = True
        self.model = model
        results = self.fit(model)
        self.testing = False

        # teardown
        if self.is_function_implemented('teardown'):
            model.teardown('test')

        return results

    def call_setup_hook(self, model):
        # call setup after the ddp process has connected
        stage_name = 'test' if self.testing else 'fit'
        if self.datamodule is not None:
            called = self.datamodule.has_setup_test if self.testing else self.datamodule.has_setup_fit
            if not called:
                self.datamodule.setup(stage_name)
        self.setup(stage_name)
        model.setup(stage_name)

    def call_hook(self, hook_name, *args, **kwargs):
        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call the trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call the hook in the LightningModule
            output = None
            model_ref = self.get_model()
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # if the LightningModule doesn't override the hook, call the accelerator's;
            # used to auto-reduce things for the user with the Results obj
            elif hasattr(self.accelerator_backend, hook_name):
                accelerator_hook = getattr(self.accelerator_backend, hook_name)
                output = accelerator_hook(*args, **kwargs)

        return output
# add docstrings
Trainer.__init__.__doc__ = docstrings.trainer.init
Trainer.fit.__doc__ = docstrings.trainer.fit
Trainer.test.__doc__ = docstrings.trainer.test