Fix enforce_datamodule_dataloader_override() for iterable datasets (#2957)

This function contains the check `if (train_dataloader or val_dataloaders) and datamodule:`.

The issue is similar to the one fixed in https://github.com/PyTorchLightning/pytorch-lightning/pull/1560. The problem is that `if dl` evaluates as `if bool(dl)`, but `DataLoader` defines no `__bool__`, so `bool()` falls back to `__len__() > 0`. `DataLoader.__len__` in turn calls `len()` on the underlying dataset, and `IterableDataset` does not define `__len__`, so the truthiness check raises a `TypeError`.

The fix is also the same: replace `if dl` with `if dl is not None`.
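
A minimal sketch of the failure mode and of the fix; `StreamDataset` here is a hypothetical stand-in for a user's iterable dataset, not part of this patch:

```python
from torch.utils.data import DataLoader, IterableDataset

class StreamDataset(IterableDataset):
    # A typical iterable dataset: it can be iterated, but defines no __len__.
    def __iter__(self):
        return iter(range(3))

dl = DataLoader(StreamDataset())

# The truthiness check: bool(dl) falls back to DataLoader.__len__, which
# calls len() on the underlying dataset and raises for an IterableDataset.
try:
    if dl:
        pass
except TypeError as err:
    print(err)  # e.g. "object of type 'StreamDataset' has no len()"

# The explicit check never touches __len__ and works for any dataloader.
if dl is not None:
    print("dataloader was supplied")
```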

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
SiddhantRanade 2020-08-13 15:06:17 -06:00 committed by GitHub
parent 53f855cdbf
commit 88bfed371e
1 changed file with 1 addition and 1 deletion

@@ -10,7 +10,7 @@ class ConfigValidator(object):
     def enforce_datamodule_dataloader_override(self, train_dataloader, val_dataloaders, datamodule):
         # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
-        if (train_dataloader or val_dataloaders) and datamodule:
+        if (train_dataloader is not None or val_dataloaders is not None) and datamodule is not None:
             raise MisconfigurationException(
                 'You cannot pass train_dataloader or val_dataloaders to trainer.fit if you supply a datamodule'
             )