parent
705e576417
commit
67dc9bc135
|
@ -803,7 +803,10 @@ class Trainer(TrainerIOMixin,
|
|||
# 16 bit mixed precision training using apex
|
||||
self.amp_level = amp_level
|
||||
self.precision = precision
|
||||
if self.precision == 16:
|
||||
|
||||
assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
|
||||
|
||||
if self.precision == 16 and num_tpu_cores is None:
|
||||
use_amp = True
|
||||
self.init_amp(use_amp)
|
||||
|
||||
|
|
Loading…
Reference in New Issue