refactored model tests
This commit is contained in:
parent
ef843d5f96
commit
ecb68b52f8
|
@ -458,9 +458,10 @@ class Trainer(TrainerIO):
|
|||
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||
self.optimizers = model.configure_optimizers()
|
||||
|
||||
model.cuda(self.data_parallel_device_ids[0])
|
||||
|
||||
# run through amp wrapper
|
||||
if self.use_amp:
|
||||
model.cuda(self.data_parallel_device_ids[0])
|
||||
|
||||
# An example
|
||||
model, optimizers = amp.initialize(
|
||||
|
|
Loading…
Reference in New Issue