slow tpu train (#2033)
* use parallel loader
* Revert "use parallel loader"
This reverts commit ed6e7583
* select tpu id for pl
* condition if tpu_id is None
* added info to changelog
* Revert "condition if tpu_id is None"
This reverts commit 1fb6e586
* Apply suggestions from code review
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
parent fa696ce512
commit 943c4b20af
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
 
+- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
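For context on the [#1729] entry above: selecting an individual TPU core is exposed through the Trainer. The sketch below is illustrative only; the `num_tpu_cores` argument and the one-element-list form are assumptions about the Trainer API at this point in development, and `TinyModule` is a hypothetical minimal model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class TinyModule(pl.LightningModule):
    """Minimal model used only to illustrate single-core TPU selection."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return {"loss": torch.nn.functional.cross_entropy(self(x), y)}

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.02)

    def train_dataloader(self):
        x, y = torch.randn(256, 32), torch.randint(0, 2, (256,))
        return DataLoader(TensorDataset(x, y), batch_size=32)


# Assumption: at this point in development the Trainer takes `num_tpu_cores`,
# and a one-element list such as [1] pins training to that specific core
# instead of spawning one process per core.
trainer = pl.Trainer(num_tpu_cores=[1], max_epochs=1)
trainer.fit(TinyModule())
```

Under the same assumption, passing a plain integer (e.g. `num_tpu_cores=8`) keeps the multi-core spawn behaviour.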
@@ -249,8 +249,8 @@ class TrainerEvaluationLoopMixin(ABC):
             dl_outputs = []
 
             # on TPU we have to wrap it under the ParallelLoader
-            if self.use_tpu and self.tpu_id is None:
-                device = xm.xla_device()
+            if self.use_tpu:
+                device = xm.xla_device(self.tpu_id)
                 dataloader = xla_pl.ParallelLoader(dataloader, [device])
                 dataloader = dataloader.per_device_loader(device)
 
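The merged condition works because `xm.xla_device(n=None)` returns the default XLA device when no ordinal is given, so passing `self.tpu_id` directly covers both the "no specific core" and "pinned core" cases. The snippet below is a minimal standalone sketch of the same wrapping using torch_xla's public API; `tpu_id` stands in for the Trainer attribute, and the rest is illustrative rather than Lightning's internals.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

tpu_id = None  # None -> default core; an int (e.g. 1) -> that specific core

# xm.xla_device(None) falls back to the default device, which is why the
# old `tpu_id is None` branch could be folded away.
device = xm.xla_device(tpu_id)

dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
loader = DataLoader(dataset, batch_size=16)

# ParallelLoader preloads batches onto the XLA device in background threads;
# per_device_loader() then yields batches already resident on `device`.
loader = xla_pl.ParallelLoader(loader, [device])
loader = loader.per_device_loader(device)

for x, y in loader:
    assert x.device.type == "xla"  # batches arrive on the TPU without manual .to()
```

The speed-up described in [#2033] comes from `ParallelLoader` moving batches to the device ahead of time instead of transferring each batch synchronously inside the loop.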
@@ -415,8 +415,8 @@ class TrainerTrainLoopMixin(ABC):
         train_dataloader = self.train_dataloader
 
         # on TPU we have to wrap it under the ParallelLoader
-        if self.use_tpu and self.tpu_id is None:
-            device = xm.xla_device()
+        if self.use_tpu:
+            device = xm.xla_device(self.tpu_id)
             train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
             train_dataloader = train_dataloader.per_device_loader(device)
 
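On the training side the wrapped loader is consumed like any other iterable. The sketch below shows the generic single-core torch_xla training-loop pattern (not Lightning's actual training loop) so the effect of the wrapped `train_dataloader` is visible end to end.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as xla_pl

device = xm.xla_device()  # default core; pass an index to pin a specific one
model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.02)
loss_fn = torch.nn.CrossEntropyLoss()

dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
train_dataloader = DataLoader(dataset, batch_size=16)

# Same wrapping as the diff above: preload batches onto the device.
train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device])
train_dataloader = train_dataloader.per_device_loader(device)

for x, y in train_dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    # xm.optimizer_step() applies the update on the XLA device; barrier=True
    # is the single-process form that also flushes the lazy graph.
    xm.optimizer_step(optimizer, barrier=True)
```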