refactored model tests
This commit is contained in:
parent
ef843d5f96
commit
ecb68b52f8
|
@ -458,9 +458,10 @@ class Trainer(TrainerIO):
|
||||||
# filter out the weights that were done on gpu so we can load on good old cpus
|
# filter out the weights that were done on gpu so we can load on good old cpus
|
||||||
self.optimizers = model.configure_optimizers()
|
self.optimizers = model.configure_optimizers()
|
||||||
|
|
||||||
|
model.cuda(self.data_parallel_device_ids[0])
|
||||||
|
|
||||||
# run through amp wrapper
|
# run through amp wrapper
|
||||||
if self.use_amp:
|
if self.use_amp:
|
||||||
model.cuda(self.data_parallel_device_ids[0])
|
|
||||||
|
|
||||||
# An example
|
# An example
|
||||||
model, optimizers = amp.initialize(
|
model, optimizers = amp.initialize(
|
||||||
|
|
Loading…
Reference in New Issue