fix 16 bit for TPU (#1020)

* tpu 16 bit

* tpu 16 bit

* tpu 16 bit
This commit is contained in:
William Falcon 2020-03-03 00:26:59 -05:00 committed by GitHub
parent 705e576417
commit 67dc9bc135
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 1 deletions

View File

@ -803,7 +803,10 @@ class Trainer(TrainerIOMixin,
# 16 bit mixed precision training using apex
self.amp_level = amp_level
self.precision = precision
if self.precision == 16:
assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
if self.precision == 16 and num_tpu_cores is None:
use_amp = True
self.init_amp(use_amp)