ref: added data connector (#3285)

William Falcon committed via GitHub on 2020-08-31 11:08:22 -04:00
commit b0f77a74a1 (parent b4887d7647)
5 changed files with 164 additions and 95 deletions
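
For orientation, a hedged sketch (not verbatim source) of the call flow this refactor produces, pieced together from the hunks below:

# trainer.fit(model, train_dataloader, val_dataloaders, datamodule)
#   -> trainer.setup_fit(...)
#        -> trainer.data_connector.attach_data(...)
#             -> attach_dataloaders(...)   # wraps passed loaders in _PatchDataLoader
#             -> attach_datamodule(...)    # patches model.*_dataloader hooks
#        -> trainer.config_validator.verify_loop_configurations(model)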

View File

@@ -1,3 +1,11 @@
+.. testsetup:: *
+
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.core.datamodule import LightningDataModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
+.. _converting:
+
 **************************************
 How to organize PyTorch into Lightning
 **************************************

View File

@@ -15,6 +15,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_utils import is_overridden


 class ConfigValidator(object):
@@ -22,13 +23,6 @@ class ConfigValidator(object):
     def __init__(self, trainer):
         self.trainer = trainer

-    def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
-        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
-        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
-            raise MisconfigurationException(
-                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
-            )
-
     def verify_loop_configurations(self, model: LightningModule):
         r"""
         Checks that the model is configured correctly before training or testing is started.
@@ -48,7 +42,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has a training step
         # -----------------------------------
-        has_training_step = self.trainer.is_overridden('training_step', model)
+        has_training_step = is_overridden('training_step', model)
         if not has_training_step:
             raise MisconfigurationException(
                 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
@@ -58,7 +52,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has a train dataloader
         # -----------------------------------
-        has_train_dataloader = self.trainer.is_overridden('train_dataloader', model)
+        has_train_dataloader = is_overridden('train_dataloader', model)
         if not has_train_dataloader:
             raise MisconfigurationException(
                 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
@@ -68,7 +62,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has optimizer
         # -----------------------------------
-        has_optimizers = self.trainer.is_overridden('configure_optimizers', model)
+        has_optimizers = is_overridden('configure_optimizers', model)
         if not has_optimizers:
             raise MisconfigurationException(
                 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
@@ -83,8 +77,8 @@ class ConfigValidator(object):
         if eval_loop_name == 'validation':
             loader_name = 'val_dataloader'

-        has_loader = self.trainer.is_overridden(loader_name, model)
-        has_step = self.trainer.is_overridden(step_name, model)
+        has_loader = is_overridden(loader_name, model)
+        has_step = is_overridden(step_name, model)

         if has_loader and not has_step:
             rank_zero_warn(
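
Note: `ConfigValidator` now calls the module-level `is_overridden` helper (see pytorch_lightning/utilities/model_utils.py below) instead of the removed `Trainer.is_overridden` method. A minimal, hypothetical sketch of what `verify_loop_configurations` enforces — the model below is invented for illustration and omits `training_step`, so fitting it should fail fast:

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class NoTrainingStep(LightningModule):
    # hypothetical model: defines an optimizer but no training_step
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


loader = DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)))
try:
    Trainer(max_epochs=1).fit(NoTrainingStep(), loader)
except MisconfigurationException as err:
    print(err)  # complains that no `training_step()` method is defined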

View File: pytorch_lightning/trainer/data_connector.py (new file)

@@ -0,0 +1,100 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from typing import List, Union
+from torch.utils.data import DataLoader
+from pytorch_lightning.utilities.model_utils import is_overridden
+
+
+class DataConnector(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloader, LightningDataModule):
+            datamodule = train_dataloader
+            train_dataloader = None
+
+        self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
+
+        # set up the passed in dataloaders (if needed)
+        self.attach_dataloaders(model, train_dataloader, val_dataloaders)
+        self.attach_datamodule(model, datamodule, 'fit')
+
+    def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
+        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
+            raise MisconfigurationException(
+                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
+            )
+
+    def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
+        # when dataloader is passed via fit, patch the train_dataloader
+        # functions to overwrite with these implementations
+        if train_dataloader is not None:
+            model.train_dataloader = _PatchDataLoader(train_dataloader)
+
+        if val_dataloaders is not None:
+            model.val_dataloader = _PatchDataLoader(val_dataloaders)
+
+        if test_dataloaders is not None:
+            model.test_dataloader = _PatchDataLoader(test_dataloaders)
+
+    def attach_datamodule(self, model, datamodule, stage):
+        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
+        datamodule = datamodule or getattr(model, 'datamodule', None)
+
+        # If we have a datamodule, attach necessary hooks + dataloaders
+        if datamodule:
+            # Override loader hooks
+            if is_overridden('train_dataloader', datamodule):
+                model.train_dataloader = datamodule.train_dataloader
+            if is_overridden('val_dataloader', datamodule):
+                model.val_dataloader = datamodule.val_dataloader
+            if is_overridden('test_dataloader', datamodule):
+                model.test_dataloader = datamodule.test_dataloader
+
+            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
+            if is_overridden('transfer_batch_to_device', datamodule):
+                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
+
+            self.trainer.datamodule = datamodule
+
+
+class _PatchDataLoader(object):
+    r"""
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+    """
+
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+        # __code__ objects cannot be pickled, so we store their string
+        # representation instead; is_overridden compares this string to
+        # detect that a dataloader method has been patched.
+        self.patch_loader_code = str(self.__call__.__code__)
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader
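
The net effect of DataConnector.attach_data is that a LightningDataModule may be passed where a dataloader would normally go, and its loader hooks get patched onto the model. A hypothetical end-to-end sketch (the toy model and datamodule below are invented for illustration):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule


class ToyDataModule(LightningDataModule):
    def train_dataloader(self):
        return DataLoader(TensorDataset(torch.randn(8, 4), torch.randn(8, 1)))


class ToyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {'loss': torch.nn.functional.mse_loss(self.layer(x), y)}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())


# The datamodule arrives as the second positional argument; attach_data
# detects this, moves it to `datamodule`, and attach_datamodule patches
# model.train_dataloader with the datamodule's hook.
Trainer(max_epochs=1, limit_train_batches=1).fit(ToyModel(), ToyDataModule())

Passing both an explicit train_dataloader and a datamodule to the same fit call trips __enforce_datamodule_dataloader_override and raises MisconfigurationException.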

View File: pytorch_lightning/trainer/trainer.py

@@ -55,6 +55,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
+from pytorch_lightning.trainer.data_connector import DataConnector

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -607,6 +608,7 @@ class Trainer(
         # tracks internal state for debugging
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
+        self.data_connector = DataConnector(self)
         self.accelerator_backend = None

         # loops
@@ -974,18 +976,8 @@ class Trainer(
         """
         results = None

-        # bind logger and other properties
-        self.copy_trainer_model_properties(model)
-
-        # clean hparams
-        if hasattr(model, 'hparams'):
-            parsing.clean_namespace(model.hparams)
-
-        # links data to the trainer
-        self.attach_data(model, train_dataloader, val_dataloaders, datamodule)
-
-        # check that model is configured correctly
-        self.config_validator.verify_loop_configurations(model)
+        # setup data, etc...
+        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

         # hook
         self.call_hook('on_fit_start', model)
@@ -1031,6 +1023,20 @@ class Trainer(
         # used for testing or when we need to know that training succeeded
         return results or 1

+    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+        # bind logger and other properties
+        self.copy_trainer_model_properties(model)
+
+        # clean hparams
+        if hasattr(model, 'hparams'):
+            parsing.clean_namespace(model.hparams)
+
+        # links data to the trainer
+        self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)
+
+        # check that model is configured correctly
+        self.config_validator.verify_loop_configurations(model)
+
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
@@ -1040,18 +1046,6 @@ class Trainer(
             model.prepare_data()
             self._is_data_prepared = True

-    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
-        # if a datamodule comes in as the second arg, then fix it for the user
-        if isinstance(train_dataloader, LightningDataModule):
-            datamodule = train_dataloader
-            train_dataloader = None
-
-        self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
-
-        # set up the passed in dataloaders (if needed)
-        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.__attach_datamodule(model, datamodule, 'fit')
-
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1105,40 +1099,6 @@ class Trainer(
         else:
             return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data

-    def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
-        # when dataloader is passed via fit, patch the train_dataloader
-        # functions to overwrite with these implementations
-        if train_dataloader is not None:
-            model.train_dataloader = _PatchDataLoader(train_dataloader)
-
-        if val_dataloaders is not None:
-            model.val_dataloader = _PatchDataLoader(val_dataloaders)
-
-        if test_dataloaders is not None:
-            model.test_dataloader = _PatchDataLoader(test_dataloaders)
-
-    def __attach_datamodule(self, model, datamodule, stage):
-        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
-        datamodule = datamodule or getattr(model, 'datamodule', None)
-
-        # If we have a datamodule, attach necessary hooks + dataloaders
-        if datamodule:
-            # Override loader hooks
-            if self.is_overridden('train_dataloader', datamodule):
-                model.train_dataloader = datamodule.train_dataloader
-            if self.is_overridden('val_dataloader', datamodule):
-                model.val_dataloader = datamodule.val_dataloader
-            if self.is_overridden('test_dataloader', datamodule):
-                model.test_dataloader = datamodule.test_dataloader
-
-            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
-            if self.is_overridden('transfer_batch_to_device', datamodule):
-                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
-
-            self.datamodule = datamodule
-
     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
@@ -1348,7 +1308,7 @@ class Trainer(
         )

         # Attach datamodule to get setup/prepare_data added to model before the call to it below
-        self.__attach_datamodule(model or self.get_model(), datamodule, 'test')
+        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

         if model is not None:
             results = self.__test_given_model(model, test_dataloaders)
@@ -1386,7 +1346,7 @@ class Trainer(

         # attach dataloaders
         if test_dataloaders is not None:
-            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

         # run tests
         self.tested_ckpt_path = ckpt_path
@@ -1408,7 +1368,7 @@ class Trainer(

         # attach data
         if test_dataloaders is not None:
-            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

         # run test
         # sets up testing so we short circuit to eval
@@ -1472,28 +1432,6 @@ class Trainer(
         return output


-class _PatchDataLoader(object):
-    r"""
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
-
-    Args:
-        dataloader: Dataloader object to return when called.
-    """
-
-    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
-        self.dataloader = dataloader
-
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        self.patch_loader_code = str(self.__call__.__code__)
-
-    def __call__(self) -> Union[List[DataLoader], DataLoader]:
-        return self.dataloader
-
-
 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
         return batches
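
Aside on the patch_loader_code trick used by _PatchDataLoader (moved above into data_connector.py): code objects themselves cannot be pickled, but their string form can, so storing str(self.__call__.__code__) keeps patched models pickle-compatible while still letting is_overridden detect the patch. A small standalone sketch:

import pickle


def f():
    return 42


try:
    pickle.dumps(f.__code__)
except TypeError as err:
    print(err)  # code objects cannot be pickled

blob = pickle.dumps(str(f.__code__))  # the string representation pickles fine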

View File: pytorch_lightning/utilities/model_utils.py (new file)

@@ -0,0 +1,29 @@
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.core.datamodule import LightningDataModule
+
+
+def is_overridden(method_name: str, model: LightningModule) -> bool:
+    # if you pass a DataModule instead of None or a LightningModule, we use LightningDataModule as super
+    # TODO - refactor this function to accept model_name, instance, parent so it makes more sense
+    super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
+
+    # assert model, 'no model passed'
+
+    if not hasattr(model, method_name):
+        # in case of calling deprecated method
+        return False
+
+    instance_attr = getattr(model, method_name)
+    if not instance_attr:
+        return False
+    super_attr = getattr(super_object, method_name)
+
+    # when the code objects differ, the method was overridden
+    if hasattr(instance_attr, 'patch_loader_code'):
+        # cannot pickle __code__, so we cannot check for _PatchDataLoader directly;
+        # instead compare against the stored string representation of the code object
+        is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
+    else:
+        is_overridden = instance_attr.__code__ is not super_attr.__code__
+
+    return is_overridden
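
A hedged usage sketch of is_overridden: it compares the instance's code object against the one inherited from LightningModule (or LightningDataModule), so only subclasses that actually redefine the method return True. The two classes below are invented for illustration:

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.model_utils import is_overridden


class Plain(LightningModule):
    pass


class WithStep(LightningModule):
    def training_step(self, batch, batch_idx):
        pass


print(is_overridden('training_step', Plain()))     # False: inherited code object
print(is_overridden('training_step', WithStep()))  # True: code objects differ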