ref: added data connector (#3285)
This commit is contained in:
parent b4887d7647
commit b0f77a74a1
@@ -1,3 +1,11 @@
+.. testsetup:: *
+
+    from pytorch_lightning.core.lightning import LightningModule
+    from pytorch_lightning.core.datamodule import LightningDataModule
+    from pytorch_lightning.trainer.trainer import Trainer
+
+.. _converting:
+
 **************************************
 How to organize PyTorch into Lightning
 **************************************

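Editor's note: the added testsetup import makes LightningDataModule available to the doctest snippets on the "How to organize PyTorch into Lightning" page. A minimal sketch of the kind of datamodule such a snippet can now define (illustrative only and not part of this commit; the random-tensor dataset and sizes are made up):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.core.datamodule import LightningDataModule


class RandomDataModule(LightningDataModule):
    # hypothetical datamodule built on random tensors so the snippet is self-contained
    def setup(self, stage=None):
        self.train_set = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        self.val_set = TensorDataset(torch.randn(16, 4), torch.randn(16, 1))

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=8)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=8)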
@@ -15,6 +15,7 @@
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.model_utils import is_overridden


 class ConfigValidator(object):
@@ -22,13 +23,6 @@ class ConfigValidator(object):
     def __init__(self, trainer):
         self.trainer = trainer

-    def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
-        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
-        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
-            raise MisconfigurationException(
-                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
-            )
-
     def verify_loop_configurations(self, model: LightningModule):
         r"""
         Checks that the model is configured correctly before training or testing is started.
@@ -48,7 +42,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has a training step
         # -----------------------------------
-        has_training_step = self.trainer.is_overridden('training_step', model)
+        has_training_step = is_overridden('training_step', model)
         if not has_training_step:
             raise MisconfigurationException(
                 'No `training_step()` method defined. Lightning `Trainer` expects as minimum a'
@@ -58,7 +52,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has a train dataloader
         # -----------------------------------
-        has_train_dataloader = self.trainer.is_overridden('train_dataloader', model)
+        has_train_dataloader = is_overridden('train_dataloader', model)
         if not has_train_dataloader:
             raise MisconfigurationException(
                 'No `train_dataloader()` method defined. Lightning `Trainer` expects as minimum a'
@@ -68,7 +62,7 @@ class ConfigValidator(object):
         # -----------------------------------
         # verify model has optimizer
         # -----------------------------------
-        has_optimizers = self.trainer.is_overridden('configure_optimizers', model)
+        has_optimizers = is_overridden('configure_optimizers', model)
         if not has_optimizers:
             raise MisconfigurationException(
                 'No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a'
@@ -83,8 +77,8 @@ class ConfigValidator(object):
         if eval_loop_name == 'validation':
             loader_name = 'val_dataloader'

-        has_loader = self.trainer.is_overridden(loader_name, model)
-        has_step = self.trainer.is_overridden(step_name, model)
+        has_loader = is_overridden(loader_name, model)
+        has_step = is_overridden(step_name, model)

         if has_loader and not has_step:
             rank_zero_warn(

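Editor's note: verify_loop_configurations() now calls the standalone is_overridden() helper instead of trainer.is_overridden(), but the checks themselves are unchanged: a model must define a training step, a train dataloader, and an optimizer. A minimal sketch of a model that passes them (illustrative only; the class name, layer sizes, and random data are made up, not taken from the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.core.lightning import LightningModule


class MinimalModel(LightningModule):
    # the three hooks below are exactly what verify_loop_configurations() requires
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def train_dataloader(self):
        data = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
        return DataLoader(data, batch_size=8)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)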
@@ -0,0 +1,100 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pytorch_lightning.core.datamodule import LightningDataModule
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from typing import List, Union
+from torch.utils.data import DataLoader
+from pytorch_lightning.utilities.model_utils import is_overridden
+
+
+class DataConnector(object):
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
+        # if a datamodule comes in as the second arg, then fix it for the user
+        if isinstance(train_dataloader, LightningDataModule):
+            datamodule = train_dataloader
+            train_dataloader = None
+
+        self.__enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
+
+        # set up the passed in dataloaders (if needed)
+        self.attach_dataloaders(model, train_dataloader, val_dataloaders)
+        self.attach_datamodule(model, datamodule, 'fit')
+
+    def __enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
+        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
+        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
+            raise MisconfigurationException(
+                'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
+            )
+
+    def attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
+        # when dataloader is passed via fit, patch the train_dataloader
+        # functions to overwrite with these implementations
+        if train_dataloader is not None:
+            model.train_dataloader = _PatchDataLoader(train_dataloader)
+
+        if val_dataloaders is not None:
+            model.val_dataloader = _PatchDataLoader(val_dataloaders)
+
+        if test_dataloaders is not None:
+            model.test_dataloader = _PatchDataLoader(test_dataloaders)
+
+    def attach_datamodule(self, model, datamodule, stage):
+
+        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
+        datamodule = datamodule or getattr(model, 'datamodule', None)
+
+        # If we have a datamodule, attach necessary hooks + dataloaders
+        if datamodule:
+
+            # Override loader hooks
+            if is_overridden('train_dataloader', datamodule):
+                model.train_dataloader = datamodule.train_dataloader
+            if is_overridden('val_dataloader', datamodule):
+                model.val_dataloader = datamodule.val_dataloader
+            if is_overridden('test_dataloader', datamodule):
+                model.test_dataloader = datamodule.test_dataloader
+
+            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
+            if is_overridden('transfer_batch_to_device', datamodule):
+                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
+
+            self.trainer.datamodule = datamodule
+
+
+class _PatchDataLoader(object):
+    r"""
+    Callable object for patching dataloaders passed into trainer.fit().
+    Use this class to override model.*_dataloader() and be pickle-compatible.
+
+    Args:
+        dataloader: Dataloader object to return when called.
+
+    """
+
+    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
+        self.dataloader = dataloader
+
+        # cannot pickle __code__ so cannot verify if PatchDataloader
+        # exists which shows dataloader methods have been overwritten.
+        # so, we hack it by using the string representation
+        self.patch_loader_code = str(self.__call__.__code__)
+
+    def __call__(self) -> Union[List[DataLoader], DataLoader]:
+        return self.dataloader

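Editor's note: the new DataConnector centralizes logic that previously lived on Trainer itself. attach_data() resolves fit(model, datamodule) vs. fit(model, train_dataloader=...), rejects supplying both, and attach_dataloaders() wraps explicit loaders in _PatchDataLoader so the patched hooks stay pickle-compatible. A rough usage sketch (the connector is normally driven by trainer.fit(), not called directly; TinyModel is a made-up placeholder):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.data_connector import _PatchDataLoader


class TinyModel(LightningModule):
    # made-up placeholder model; only its dataloader hooks matter here
    pass


loader = DataLoader(TensorDataset(torch.randn(16, 4), torch.randn(16, 1)), batch_size=4)

trainer = Trainer(max_epochs=1)
model = TinyModel()

# roughly what trainer.fit(model, train_dataloader=loader) now does under the hood:
trainer.data_connector.attach_dataloaders(model, train_dataloader=loader)

# the hook was replaced by a picklable _PatchDataLoader that simply returns the loader
assert isinstance(model.train_dataloader, _PatchDataLoader)
assert model.train_dataloader() is loader

Passing a LightningDataModule positionally (trainer.fit(model, dm)) is detected and re-routed by attach_data(), while passing both a datamodule and explicit loaders raises the MisconfigurationException shown above.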
@@ -55,6 +55,7 @@ from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.cloud_io import is_remote_path
 from pytorch_lightning.trainer.evaluate_loop import EvaluationLoop
+from pytorch_lightning.trainer.data_connector import DataConnector

 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -607,6 +608,7 @@ class Trainer(
         # tracks internal state for debugging
         self.dev_debugger = InternalDebugger(self)
         self.config_validator = ConfigValidator(self)
+        self.data_connector = DataConnector(self)
         self.accelerator_backend = None

         # loops
@@ -974,18 +976,8 @@ class Trainer(
         """
         results = None

-        # bind logger and other properties
-        self.copy_trainer_model_properties(model)
-
-        # clean hparams
-        if hasattr(model, 'hparams'):
-            parsing.clean_namespace(model.hparams)
-
-        # links data to the trainer
-        self.attach_data(model, train_dataloader, val_dataloaders, datamodule)
-
-        # check that model is configured correctly
-        self.config_validator.verify_loop_configurations(model)
+        # setup data, etc...
+        self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

         # hook
         self.call_hook('on_fit_start', model)
@@ -1031,6 +1023,20 @@ class Trainer(
         # used for testing or when we need to know that training succeeded
         return results or 1

+    def setup_fit(self, model, train_dataloader, val_dataloaders, datamodule):
+        # bind logger and other properties
+        self.copy_trainer_model_properties(model)
+
+        # clean hparams
+        if hasattr(model, 'hparams'):
+            parsing.clean_namespace(model.hparams)
+
+        # links data to the trainer
+        self.data_connector.attach_data(model, train_dataloader, val_dataloaders, datamodule)
+
+        # check that model is configured correctly
+        self.config_validator.verify_loop_configurations(model)
+
     def prepare_data(self, model):
         # on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
         # or in the case where each node needs to do its own manipulation in which case just local_rank=0
@@ -1040,18 +1046,6 @@
             model.prepare_data()
             self._is_data_prepared = True

-    def attach_data(self, model, train_dataloader, val_dataloaders, datamodule):
-        # if a datamodule comes in as the second arg, then fix it for the user
-        if isinstance(train_dataloader, LightningDataModule):
-            datamodule = train_dataloader
-            train_dataloader = None
-
-        self.config_validator.enforce_datamodule_dataloader_override(train_dataloader, val_dataloaders, datamodule)
-
-        # set up the passed in dataloaders (if needed)
-        self.__attach_dataloaders(model, train_dataloader, val_dataloaders)
-        self.__attach_datamodule(model, datamodule, 'fit')
-
     def select_accelerator(self):
         # SLURM ddp
         use_slurm_ddp = self.use_ddp and self.is_slurm_managing_tasks
@@ -1105,40 +1099,6 @@
         else:
             return self.node_rank == 0 and self.local_rank == 0 and should_call_dm_prepare_data

-    def __attach_dataloaders(self, model, train_dataloader=None, val_dataloaders=None, test_dataloaders=None):
-        # when dataloader is passed via fit, patch the train_dataloader
-        # functions to overwrite with these implementations
-        if train_dataloader is not None:
-            model.train_dataloader = _PatchDataLoader(train_dataloader)
-
-        if val_dataloaders is not None:
-            model.val_dataloader = _PatchDataLoader(val_dataloaders)
-
-        if test_dataloaders is not None:
-            model.test_dataloader = _PatchDataLoader(test_dataloaders)
-
-    def __attach_datamodule(self, model, datamodule, stage):
-
-        # We use datamodule if it's been provided on .fit or .test, otherwise we check model for it
-        datamodule = datamodule or getattr(model, 'datamodule', None)
-
-        # If we have a datamodule, attach necessary hooks + dataloaders
-        if datamodule:
-
-            # Override loader hooks
-            if self.is_overridden('train_dataloader', datamodule):
-                model.train_dataloader = datamodule.train_dataloader
-            if self.is_overridden('val_dataloader', datamodule):
-                model.val_dataloader = datamodule.val_dataloader
-            if self.is_overridden('test_dataloader', datamodule):
-                model.test_dataloader = datamodule.test_dataloader
-
-            # Override transfer_batch_to_device if dataset-specific to_device logic has been defined in datamodule
-            if self.is_overridden('transfer_batch_to_device', datamodule):
-                model.transfer_batch_to_device = datamodule.transfer_batch_to_device
-
-            self.datamodule = datamodule
-
     def run_pretrain_routine(self, model: LightningModule):
         """Sanity check a few things before starting actual training.
@@ -1348,7 +1308,7 @@ class Trainer(
         )

         # Attach datamodule to get setup/prepare_data added to model before the call to it below
-        self.__attach_datamodule(model or self.get_model(), datamodule, 'test')
+        self.data_connector.attach_datamodule(model or self.get_model(), datamodule, 'test')

         if model is not None:
             results = self.__test_given_model(model, test_dataloaders)
@@ -1386,7 +1346,7 @@ class Trainer(

         # attach dataloaders
         if test_dataloaders is not None:
-            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

         # run tests
         self.tested_ckpt_path = ckpt_path
@@ -1408,7 +1368,7 @@ class Trainer(

         # attach data
         if test_dataloaders is not None:
-            self.__attach_dataloaders(model, test_dataloaders=test_dataloaders)
+            self.data_connector.attach_dataloaders(model, test_dataloaders=test_dataloaders)

         # run test
         # sets up testing so we short circuit to eval
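Editor's note: on the test path the same connector is used: trainer.test(datamodule=dm) now goes through data_connector.attach_datamodule(..., 'test'), which copies any dataloader hooks the datamodule overrides onto the model. A small sketch of that hand-off (illustrative only; RandomTestData and PlainModel are made-up names):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import Trainer
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.datamodule import LightningDataModule


class RandomTestData(LightningDataModule):
    def test_dataloader(self):
        data = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
        return DataLoader(data, batch_size=4)


class PlainModel(LightningModule):
    pass


trainer = Trainer()
model, dm = PlainModel(), RandomTestData()

# roughly what trainer.test(model, datamodule=dm) does before running the test loop:
trainer.data_connector.attach_datamodule(model, dm, 'test')

assert model.test_dataloader == dm.test_dataloader  # hook copied from the datamodule
assert trainer.datamodule is dm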
@@ -1472,28 +1432,6 @@ class Trainer(
         return output


-class _PatchDataLoader(object):
-    r"""
-    Callable object for patching dataloaders passed into trainer.fit().
-    Use this class to override model.*_dataloader() and be pickle-compatible.
-
-    Args:
-        dataloader: Dataloader object to return when called.
-
-    """
-
-    def __init__(self, dataloader: Union[List[DataLoader], DataLoader]):
-        self.dataloader = dataloader
-
-        # cannot pickle __code__ so cannot verify if PatchDataloader
-        # exists which shows dataloader methods have been overwritten.
-        # so, we hack it by using the string representation
-        self.patch_loader_code = str(self.__call__.__code__)
-
-    def __call__(self) -> Union[List[DataLoader], DataLoader]:
-        return self.dataloader
-
-
 def _determine_batch_limits(batches: Union[int, float], name: str) -> Union[int, float]:
     if 0 <= batches <= 1:
         return batches

@@ -0,0 +1,29 @@
+from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.core.datamodule import LightningDataModule
+
+
+def is_overridden(method_name: str, model: LightningModule) -> bool:
+    # if you pass DataModule instead of None or a LightningModule, we use LightningDataModule as super
+    # TODO - refector this function to accept model_name, instance, parent so it makes more sense
+    super_object = LightningModule if not isinstance(model, LightningDataModule) else LightningDataModule
+
+    # assert model, 'no model passes'
+
+    if not hasattr(model, method_name):
+        # in case of calling deprecated method
+        return False
+
+    instance_attr = getattr(model, method_name)
+    if not instance_attr:
+        return False
+    super_attr = getattr(super_object, method_name)
+
+    # when code pointers are different, it was implemented
+    if hasattr(instance_attr, 'patch_loader_code'):
+        # cannot pickle __code__ so cannot verify if PatchDataloader
+        # exists which shows dataloader methods have been overwritten.
+        # so, we hack it by using the string representation
+        is_overridden = instance_attr.patch_loader_code != str(super_attr.__code__)
+    else:
+        is_overridden = instance_attr.__code__ is not super_attr.__code__
+    return is_overridden

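Editor's note: the new utilities module houses is_overridden(), which decides whether a user actually implemented a hook by comparing code objects against the LightningModule (or LightningDataModule) baseline, with a string comparison as the fallback for _PatchDataLoader-wrapped hooks. A short sketch of the behaviour (the subclass is made up for illustration):

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities.model_utils import is_overridden


class OnlyTrainingStep(LightningModule):
    # hypothetical subclass that implements a single hook
    def training_step(self, batch, batch_idx):
        return {'loss': None}  # body is irrelevant; only the code object is compared


model = OnlyTrainingStep()

# defined on the subclass -> its __code__ differs from the LightningModule baseline
assert is_overridden('training_step', model)

# never implemented -> resolves to the parent definition (or is missing), so False
assert not is_overridden('validation_step', model)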