* sampler

* sampler

* sampler

* check for dataloader type

* check for dataloader type
William Falcon 2020-03-31 18:22:45 -04:00 committed by GitHub
parent aca8c7e6f3
commit 7de51f78ac
3 changed files with 15 additions and 2 deletions


@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Changed
+- Does not interfere with a default sampler ([#1318](https://github.com/PyTorchLightning/pytorch-lightning/pull/1318))
 - Enhanced load_from_checkpoint to also forward params to the model ([#1307](https://github.com/PyTorchLightning/pytorch-lightning/pull/1307))
 - Made `evaluate` method private >> `Trainer._evaluate(...)`. ([#1260](https://github.com/PyTorchLightning/pytorch-lightning/pull/1260))


@@ -74,6 +74,15 @@ class TrainerDataLoadingMixin(ABC):
             raise ValueError(msg)
 
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
+
+        # don't do anything if it's not a dataloader
+        if not isinstance(dataloader, DataLoader):
+            return dataloader
+
+        # don't add sampler when user gives one
+        if dataloader.sampler is not None:
+            return dataloader
+
         if self.use_ddp or self.use_ddp2 or self.use_tpu:
             dl_args = {
                 'dataset': dataloader.dataset,

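For illustration only (not part of the commit), here is a minimal sketch of the user-side case the new check is aimed at: a DataLoader that already carries a user-chosen sampler should now pass through `auto_add_sampler` untouched. The toy dataset and the `WeightedRandomSampler` choice below are assumptions made for the example.

    import torch
    from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

    # hypothetical toy dataset, only for illustration
    dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))

    # the user picks a sampler explicitly ...
    weights = torch.ones(len(dataset))
    sampler = WeightedRandomSampler(weights, num_samples=len(dataset))

    # ... and builds the DataLoader around it; with the change above, Lightning
    # (in use_ddp / use_ddp2 / use_tpu mode) returns this loader unchanged
    # instead of replacing the user's sampler with a DistributedSampler
    train_loader = DataLoader(dataset, batch_size=8, sampler=sampler)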

@@ -250,8 +250,11 @@ You must configure your job submission script correctly for the trainer to work.
 .. note:: When running in DDP mode, any errors in your code will show up as an NCCL issue.
     Set the `NCCL_DEBUG=INFO` flag to see the ACTUAL error.
 
-Finally, make sure to add a distributed sampler to your dataset. The distributed sampler copies a
-portion of your dataset onto each GPU. (World_size = gpus_per_node * nb_nodes).
+Normally you would need to add a distributed sampler to your dataset yourself; however,
+Lightning automates this for you. If you still want to set your own sampler, Lightning will
+neither interfere with it nor automate it.
+
+Here's an example of how to add your own sampler (again, not needed with Lightning).
 
 .. code-block:: python
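The RST code block itself is not reproduced in this excerpt; below is a minimal sketch of what such a manual DistributedSampler setup typically looks like. The `LitModel` class, the MNIST dataset, the batch size, and the `self.use_ddp` flag (assumed to be set on the module by the trainer) are assumptions, not taken from the commit.

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler
    from torchvision import datasets, transforms


    class LitModel(pl.LightningModule):
        # ... forward / training_step / configure_optimizers omitted for brevity ...

        def train_dataloader(self):
            dataset = datasets.MNIST('./data', train=True, download=True,
                                     transform=transforms.ToTensor())

            train_sampler = None
            if self.use_ddp:
                # shard the dataset across processes
                # (world_size = gpus_per_node * nb_nodes)
                train_sampler = DistributedSampler(dataset)

            # shuffle only when no sampler is set; DataLoader forbids using both
            return DataLoader(dataset, batch_size=32,
                              shuffle=(train_sampler is None),
                              sampler=train_sampler)

With Lightning this hook can stay a plain DataLoader: the trainer adds the distributed sampler itself, and after this change it leaves any sampler you set yourself alone.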