From 943c4b20af232bd197c7a94972dd57b77b2a090d Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 3 Jun 2020 04:18:05 +0530 Subject: [PATCH] slow tpu train (#2033) * use parallel loader * Revert "use parallel loader" This reverts commit ed6e7583 * select tpu id for pl * condition if tpu_id is None * added info to changelog * Revert "condition if tpu_id is None" This reverts commit 1fb6e586 * Apply suggestions from code review Co-authored-by: Jirka Borovec --- CHANGELOG.md | 2 ++ pytorch_lightning/trainer/evaluation_loop.py | 4 ++-- pytorch_lightning/trainer/training_loop.py | 4 ++-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 47d460d1b0..0ee273b87f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799)) +- Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033)) + ### Changed - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729)) diff --git a/pytorch_lightning/trainer/evaluation_loop.py b/pytorch_lightning/trainer/evaluation_loop.py index 58d16632b3..a8c866f990 100644 --- a/pytorch_lightning/trainer/evaluation_loop.py +++ b/pytorch_lightning/trainer/evaluation_loop.py @@ -249,8 +249,8 @@ class TrainerEvaluationLoopMixin(ABC): dl_outputs = [] # on TPU we have to wrap it under the ParallelLoader - if self.use_tpu and self.tpu_id is None: - device = xm.xla_device() + if self.use_tpu: + device = xm.xla_device(self.tpu_id) dataloader = xla_pl.ParallelLoader(dataloader, [device]) dataloader = dataloader.per_device_loader(device) diff --git a/pytorch_lightning/trainer/training_loop.py b/pytorch_lightning/trainer/training_loop.py index a1a3e35a6e..c1bb08fb7a 100644 --- a/pytorch_lightning/trainer/training_loop.py +++ b/pytorch_lightning/trainer/training_loop.py @@ -415,8 +415,8 @@ class TrainerTrainLoopMixin(ABC): train_dataloader = self.train_dataloader # on TPU we have to wrap it under the ParallelLoader - if self.use_tpu and self.tpu_id is None: - device = xm.xla_device() + if self.use_tpu: + device = xm.xla_device(self.tpu_id) train_dataloader = xla_pl.ParallelLoader(train_dataloader, [device]) train_dataloader = train_dataloader.per_device_loader(device)