diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a791d307cf..891ae9764f 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -803,7 +803,10 @@ class Trainer(TrainerIOMixin, # 16 bit mixed precision training using apex self.amp_level = amp_level self.precision = precision - if self.precision == 16: + + assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported' + + if self.precision == 16 and num_tpu_cores is None: use_amp = True self.init_amp(use_amp)