From 67dc9bc1350c3aa8f07e35d4c0a6b7bcd2fa630c Mon Sep 17 00:00:00 2001
From: William Falcon
Date: Tue, 3 Mar 2020 00:26:59 -0500
Subject: [PATCH] fix 16 bit for TPU (#1020)

* tpu 16 bit

* tpu 16 bit

* tpu 16 bit
---
 pytorch_lightning/trainer/trainer.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index a791d307cf..891ae9764f 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -803,7 +803,10 @@ class Trainer(TrainerIOMixin,
         # 16 bit mixed precision training using apex
         self.amp_level = amp_level
         self.precision = precision
-        if self.precision == 16:
+
+        assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
+
+        if self.precision == 16 and num_tpu_cores is None:
             use_amp = True
         self.init_amp(use_amp)
 
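Note (not part of the commit): the change gates apex AMP initialization on num_tpu_cores being unset, so 16-bit on TPU no longer routes through apex. A minimal usage sketch of the two precision=16 paths after this fix, assuming the 0.7-era Trainer arguments gpus, amp_level, and num_tpu_cores; the XLA_USE_BF16 remark describes torch_xla's bfloat16 switch and is an assumption about the intended TPU path, not something stated in the patch.

# sketch.py -- illustrative only, assumes pytorch-lightning ~0.7 API
from pytorch_lightning import Trainer

# GPU path: precision=16 with num_tpu_cores unset, so the new condition
# `self.precision == 16 and num_tpu_cores is None` holds and the Trainer
# enables apex AMP via self.init_amp(True).
gpu_trainer = Trainer(gpus=1, precision=16, amp_level='O1')

# TPU path: num_tpu_cores is set, so apex AMP is skipped entirely;
# 16-bit on TPU is handled by XLA instead (e.g. the XLA_USE_BF16
# environment variable), which is what apex was previously breaking.
tpu_trainer = Trainer(num_tpu_cores=8, precision=16)

# The new assert rejects any precision other than 16 or 32:
# Trainer(precision=64)  # AssertionError: only 32 or 16 bit precision supported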