refactored model tests

This commit is contained in:
William Falcon 2019-07-24 13:50:02 -04:00
parent 7d1e1eb7f9
commit 53f1f18442
1 changed files with 2 additions and 0 deletions

View File

@ -460,6 +460,8 @@ class Trainer(TrainerIO):
# run through amp wrapper
if self.use_amp:
model.cuda(self.data_parallel_device_ids[0])
# An example
model, optimizers = amp.initialize(
model, self.optimizers, opt_level=self.amp_level,