slow tpu train (#2033)

* use parallel loader

* Revert "use parallel loader"

This reverts commit ed6e7583

* select tpu id for pl

* condition if tpu_id is None

* added info to changelog

* Revert "condition if tpu_id is None"

This reverts commit 1fb6e586

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Author: Lezwon Castelino, 2020-06-03 04:18:05 +05:30 (committed by GitHub)
Commit: 943c4b20af, parent: fa696ce512
3 changed files with 6 additions and 4 deletions

CHANGELOG.md

@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
+
 ### Changed
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
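For context on the two TPU entries above, here is a minimal sketch of the user-facing side: selecting all cores versus one specific core on the Trainer. Illustrative only; it assumes the `tpu_cores` argument (releases contemporary with this commit may call it `num_tpu_cores`) and `MyLightningModule` is a placeholder.

import pytorch_lightning as pl

# placeholder LightningModule standing in for any real model
model = MyLightningModule()

# train on all eight cores of the TPU
trainer = pl.Trainer(tpu_cores=8)

# train on one specific core (core 1) -- the single-core path this commit speeds up
trainer = pl.Trainer(tpu_cores=[1])

trainer.fit(model)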

pytorch_lightning/trainer/evaluation_loop.py

@@ -249,8 +249,8 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu and self.tpu_id is None:
-                device = xm.xla_device()
+            if self.use_tpu:
+                device = xm.xla_device(self.tpu_id)
                 dataloader = xla_pl.ParallelLoader(dataloader, [device])
                 dataloader = dataloader.per_device_loader(device)
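Outside of Lightning, the pattern added above looks roughly like the sketch below: wrap an ordinary DataLoader in torch_xla's ParallelLoader so batches are preloaded onto the XLA device in the background instead of being copied synchronously each step. The dataset, batch size, and tpu_id value are made-up placeholders.

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

# toy dataset and loader standing in for a real validation loader
dataset = TensorDataset(torch.randn(64, 3), torch.randn(64, 1))
dataloader = DataLoader(dataset, batch_size=8)

tpu_id = 1                                    # illustrative core index
device = xm.xla_device(tpu_id)                # same call as the patched line
loader = xla_pl.ParallelLoader(dataloader, [device])
for x, y in loader.per_device_loader(device):
    # batches arrive already on the XLA device
    pass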

pytorch_lightning/trainer/training_loop.py

@@ -415,8 +415,8 @@ class TrainerTrainLoopMixin(ABC):
         train_dataloader = self.train_dataloader
         # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu and self.tpu_id is None:
-            device = xm.xla_device()
+        if self.use_tpu:
+            device = xm.xla_device(self.tpu_id)
             train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
             train_dataloader = train_dataloader.per_device_loader(device)
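The simplified condition works because xm.xla_device(None) resolves to the default XLA device, so the single xm.xla_device(self.tpu_id) call covers both the multi-core case (tpu_id is None) and the single-core case (tpu_id is an int), and ParallelLoader is now used on both paths rather than only when tpu_id was None. A tiny sketch of that fallback, with the exact device string left unspecified since it varies by torch_xla version:

import torch_xla.core.xla_model as xm

def resolve_device(tpu_id=None):
    # mirrors the patched lines: tpu_id may be None (use the default XLA
    # device) or an int selecting a specific core by ordinal
    return xm.xla_device(tpu_id)

print(resolve_device())       # default XLA device
print(resolve_device(2))      # device for the explicitly selected core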