parent
705e576417
commit
67dc9bc135
|
@ -803,7 +803,10 @@ class Trainer(TrainerIOMixin,
|
||||||
# 16 bit mixed precision training using apex
|
# 16 bit mixed precision training using apex
|
||||||
self.amp_level = amp_level
|
self.amp_level = amp_level
|
||||||
self.precision = precision
|
self.precision = precision
|
||||||
if self.precision == 16:
|
|
||||||
|
assert self.precision == 32 or self.precision == 16, 'only 32 or 16 bit precision supported'
|
||||||
|
|
||||||
|
if self.precision == 16 and num_tpu_cores is None:
|
||||||
use_amp = True
|
use_amp = True
|
||||||
self.init_amp(use_amp)
|
self.init_amp(use_amp)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue