refactored model tests
This commit is contained in:
parent
7d1e1eb7f9
commit
53f1f18442
|
@ -460,6 +460,8 @@ class Trainer(TrainerIO):
|
|||
|
||||
# run through amp wrapper
|
||||
if self.use_amp:
|
||||
model.cuda(self.data_parallel_device_ids[0])
|
||||
|
||||
# An example
|
||||
model, optimizers = amp.initialize(
|
||||
model, self.optimizers, opt_level=self.amp_level,
|
||||
|
|
Loading…
Reference in New Issue