From b0f77a74a1a5cfa8d67fab5730ecd37d1915a184 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 31 Aug 2020 11:08:22 -0400 Subject: [PATCH] ref: added data connector (#3285) * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector * ref: added data connector --- docs/source/converting.rst | 8 ++ .../trainer/configuration_validator.py | 18 +-- pytorch_lightning/trainer/data_connector.py | 100 +++++++++++++++++ pytorch_lightning/trainer/trainer.py | 104 ++++-------------- pytorch_lightning/utilities/model_utils.py | 29 +++++ 5 files changed, 164 insertions(+), 95 deletions(-) create mode 100644 pytorch_lightning/trainer/data_connector.py create mode 100644 pytorch_lightning/utilities/model_utils.py diff --git a/docs/source/converting.rst b/docs/source/converting.rst index 3dcf261f2c..ccdef25755 100644 --- a/docs/source/converting.rst +++ b/docs/source/converting.rst @@ -1,3 +1,11 @@ +.. testsetup:: * + + from pytorch_lightning.core.lightning import LightningModule + from pytorch_lightning.core.datamodule import LightningDataModule + from pytorch_lightning.trainer.trainer import Trainer + +.. _converting: + ************************************** How to organize PyTorch into Lightning ************************************** diff --git a/pytorch_lightning/trainer/configuration_validator.py b/pytorch_lightning/trainer/configuration_validator.py index adcb5fffe4..01c0119e85 100644 --- a/pytorch_lightning/trainer/configuration_validator.py +++ b/pytorch_lightning/trainer/configuration_validator.py @@ -15,6 +15,7 @@ from pytorch_lightning.core.lightning import LightningModule from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.model_utils import is_overridden class ConfigValidator(object): @@ -22,13 +23,6 @@ class ConfigValidator(object): def __init__(self, trainer): self.trainer = trainer - def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): - # If you supply a datamodule you can't supply train_dataloader or val_dataloaders - if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: - raise MisconfigurationException( - 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' - ) - def verify_loop_configurations(self, model: LightningModule): r""" Checks that the model is configured correctly before training or testing is started. @@ -48,7 +42,7 @@ class ConfigValidator(object): # ----------------------------------- # verify model has a training step # ----------------------------------- - has_training_step = self.trainer.is_overridden('training_step', model) + has_training_step = is_overridden('training_step', model) if not has_training_step: raise MisconfigurationException( 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a' @@ -58,7 +52,7 @@ class ConfigValidator(object): # ----------------------------------- # verify model has a train dataloader # ----------------------------------- - has_train_dataloader = self.trainer.is_overridden('train_dataloader', model) + has_train_dataloader = is_overridden('train_dataloader', model) if not has_train_dataloader: raise MisconfigurationException( 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a' @@ -68,7 +62,7 @@ class ConfigValidator(object): # ----------------------------------- # verify model has optimizer # ----------------------------------- - has_optimizers = self.trainer.is_overridden('configure_optimizers', model) + has_optimizers = is_overridden('configure_optimizers', model) if not has_optimizers: raise MisconfigurationException( 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a' @@ -83,8 +77,8 @@ class ConfigValidator(object): if eval_loop_name == 'validation': loader_name = 'val_dataloader' - has_loader = self.trainer.is_overridden(loader_name, model) - has_step = self.trainer.is_overridden(step_name, model) + has_loader = is_overridden(loader_name, model) + has_step = is_overridden(step_name, model) if has_loader and not has_step: rank_zero_warn( diff --git a/pytorch_lightning/trainer/data_connector.py b/pytorch_lightning/trainer/data_connector.py new file mode 100644 index 0000000000..3e03a91cff --- /dev/null +++ b/pytorch_lightning/trainer/data_connector.py @@ -0,0 +1,100 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pytorch_lightning.core.datamodule import LightningDataModule +from pytorch_lightning.utilities.exceptions import MisconfigurationException +from typing import List, Union +from torch.utils.data import DataLoader +from pytorch_lightning.utilities.model_utils import is_overridden + + +class DataConnector(object): + + def __init__(self, trainer): + self.trainer = trainer + + def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): + # if a datamodule comes in as the second arg, then fix it for the user + if isinstance(train_dataloader, LightningDataModule): + datamodule = train_dataloader + train_dataloader = None + + self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) + + # set up the passed in dataloaders (if needed) + self.attach_dataloaders(model, train_dataloader, val_dataloaders) + self.attach_datamodule(model, datamodule, 'fit') + + def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule): + # If you supply a datamodule you can't supply train_dataloader or val_dataloaders + if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None: + raise MisconfigurationException( + 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule' + ) + + def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): + # when dataloader is passed via fit, patch the train_dataloader + # functions to overwrite with these implementations + if train_dataloader is not None: + model.train_dataloader = _PatchDataLoader(train_dataloader) + + if val_dataloaders is not None: + model.val_dataloader = _PatchDataLoader(val_dataloaders) + + if test_dataloaders is not None: + model.test_dataloader = _PatchDataLoader(test_dataloaders) + + def attach_datamodule(self, model, datamodule, stage): + + # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it + datamodule = datamodule or getattr(model, 'datamodule', None) + + # If we have a datamodule, attach necessary hooks + dataloaders + if datamodule: + + # Override loader hooks + if is_overridden('train_dataloader', datamodule): + model.train_dataloader = datamodule.train_dataloader + if is_overridden('val_dataloader', datamodule): + model.val_dataloader = datamodule.val_dataloader + if is_overridden('test_dataloader', datamodule): + model.test_dataloader = datamodule.test_dataloader + + # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule + if is_overridden('transfer_batch_to_device', datamodule): + model.transfer_batch_to_device = datamodule.transfer_batch_to_device + + self.trainer.datamodule = datamodule + + +class _PatchDataLoader(object): + r""" + Callable object for patching dataloaders passed into trainer.fit(). + Use this class to override model.*_dataloader() and be pickle-compatible. + + Args: + dataloader: Dataloader object to return when called. + + """ + + def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): + self.dataloader = dataloader + + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + self.patch_loader_code = str(self.__call__.__code__) + + def __call__(self) -> Union[List[DataLoader], DataLoader]: + return self.dataloader diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index cf1940105a..78398dbcb0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -55,6 +55,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.cloud_io import is_remote_path from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop +from pytorch_lightning.trainer.data_connector import DataConnector # warnings to ignore in trainer warnings.filterwarnings( @@ -607,6 +608,7 @@ class Trainer( # tracks internal state for debugging self.dev_debugger = InternalDebugger(self) self.config_validator = ConfigValidator(self) + self.data_connector = DataConnector(self) self.accelerator_backend = None # loops @@ -974,18 +976,8 @@ class Trainer( """ results = None - # bind logger and other properties - self.copy_trainer_model_properties(model) - - # clean hparams - if hasattr(model, 'hparams'): - parsing.clean_namespace(model.hparams) - - # links data to the trainer - self.attach_data(model, train_dataloader, val_dataloaders, datamodule) - - # check that model is configured correctly - self.config_validator.verify_loop_configurations(model) + # setup data, etc... + self.setup_fit(model, train_dataloader, val_dataloaders, datamodule) # hook self.call_hook('on_fit_start', model) @@ -1031,6 +1023,20 @@ class Trainer( # used for testing or when we need to know that training succeeded return results or 1 + def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule): + # bind logger and other properties + self.copy_trainer_model_properties(model) + + # clean hparams + if hasattr(model, 'hparams'): + parsing.clean_namespace(model.hparams) + + # links data to the trainer + self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule) + + # check that model is configured correctly + self.config_validator.verify_loop_configurations(model) + def prepare_data(self, model): # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0 # or in the case where each node needs to do its own manipulation in which case just local_rank=0 @@ -1040,18 +1046,6 @@ class Trainer( model.prepare_data() self._is_data_prepared = True - def attach_data(self, model, train_dataloader, val_dataloaders, datamodule): - # if a datamodule comes in as the second arg, then fix it for the user - if isinstance(train_dataloader, LightningDataModule): - datamodule = train_dataloader - train_dataloader = None - - self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule) - - # set up the passed in dataloaders (if needed) - self.__attach_dataloaders(model, train_dataloader, val_dataloaders) - self.__attach_datamodule(model, datamodule, 'fit') - def select_accelerator(self): # SLURM ddp use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks @@ -1105,40 +1099,6 @@ class Trainer( else: return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data - def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None): - # when dataloader is passed via fit, patch the train_dataloader - # functions to overwrite with these implementations - if train_dataloader is not None: - model.train_dataloader = _PatchDataLoader(train_dataloader) - - if val_dataloaders is not None: - model.val_dataloader = _PatchDataLoader(val_dataloaders) - - if test_dataloaders is not None: - model.test_dataloader = _PatchDataLoader(test_dataloaders) - - def __attach_datamodule(self, model, datamodule, stage): - - # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it - datamodule = datamodule or getattr(model, 'datamodule', None) - - # If we have a datamodule, attach necessary hooks + dataloaders - if datamodule: - - # Override loader hooks - if self.is_overridden('train_dataloader', datamodule): - model.train_dataloader = datamodule.train_dataloader - if self.is_overridden('val_dataloader', datamodule): - model.val_dataloader = datamodule.val_dataloader - if self.is_overridden('test_dataloader', datamodule): - model.test_dataloader = datamodule.test_dataloader - - # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule - if self.is_overridden('transfer_batch_to_device', datamodule): - model.transfer_batch_to_device = datamodule.transfer_batch_to_device - - self.datamodule = datamodule - def run_pretrain_routine(self, model: LightningModule): """Sanity check a few things before starting actual training. @@ -1348,7 +1308,7 @@ class Trainer( ) # Attach datamodule to get setup/prepare_data added to model before the call to it below - self.__attach_datamodule(model or self.get_model(), datamodule, 'test') + self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test') if model is not None: results = self.__test_given_model(model, test_dataloaders) @@ -1386,7 +1346,7 @@ class Trainer( # attach dataloaders if test_dataloaders is not None: - self.__attach_dataloaders(model, test_dataloaders=test_dataloaders) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) # run tests self.tested_ckpt_path = ckpt_path @@ -1408,7 +1368,7 @@ class Trainer( # attach data if test_dataloaders is not None: - self.__attach_dataloaders(model, test_dataloaders=test_dataloaders) + self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders) # run test # sets up testing so we short circuit to eval @@ -1472,28 +1432,6 @@ class Trainer( return output -class _PatchDataLoader(object): - r""" - Callable object for patching dataloaders passed into trainer.fit(). - Use this class to override model.*_dataloader() and be pickle-compatible. - - Args: - dataloader: Dataloader object to return when called. - - """ - - def __init__(self, dataloader: Union[List[DataLoader], DataLoader]): - self.dataloader = dataloader - - # cannot pickle __code__ so cannot verify if PatchDataloader - # exists which shows dataloader methods have been overwritten. - # so, we hack it by using the string representation - self.patch_loader_code = str(self.__call__.__code__) - - def __call__(self) -> Union[List[DataLoader], DataLoader]: - return self.dataloader - - def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]: if 0 <= batches <= 1: return batches diff --git a/pytorch_lightning/utilities/model_utils.py b/pytorch_lightning/utilities/model_utils.py new file mode 100644 index 0000000000..d71e1fd5c0 --- /dev/null +++ b/pytorch_lightning/utilities/model_utils.py @@ -0,0 +1,29 @@ +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.core.datamodule import LightningDataModule + + +def is_overridden(method_name: str, model: LightningModule) -> bool: + # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super + # TODO - refector this function to accept model_name, instance, parent so it makes more sense + super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule + + # assert model, 'no model passes' + + if not hasattr(model, method_name): + # in case of calling deprecated method + return False + + instance_attr = getattr(model, method_name) + if not instance_attr: + return False + super_attr = getattr(super_object, method_name) + + # when code pointers are different, it was implemented + if hasattr(instance_attr, 'patch_loader_code'): + # cannot pickle __code__ so cannot verify if PatchDataloader + # exists which shows dataloader methods have been overwritten. + # so, we hack it by using the string representation + is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__) + else: + is_overridden = instance_attr.__code__ is not super_attr.__code__ + return is_overridden