refactored model tests

This commit is contained in:
William Falcon 2019-07-24 13:56:49 -04:00
parent ef843d5f96
commit ecb68b52f8
1 changed files with 2 additions and 1 deletions

View File

@ -458,9 +458,10 @@ class Trainer(TrainerIO):
# filter out the weights that were done on gpu so we can load on good old cpus
self.optimizers = model.configure_optimizers()
model.cuda(self.data_parallel_device_ids[0])
# run through amp wrapper
if self.use_amp:
model.cuda(self.data_parallel_device_ids[0])
# An example
model, optimizers = amp.initialize(