DP device fix (#3196)
This commit is contained in:
parent
4d98419bb8
commit
0aee137ba7
|
@ -47,6 +47,9 @@ class DataParallelBackend(Accelerator):
|
|||
self.trainer.lr_schedulers = lr_schedulers
|
||||
self.trainer.optimizer_frequencies = optimizer_frequencies
|
||||
|
||||
# init torch data parallel
|
||||
model = self.__init_torch_data_parallel(model)
|
||||
|
||||
# hack forward to do autocast for the user
|
||||
self.model_autocast_original_forward = model.forward
|
||||
|
||||
|
@ -54,9 +57,6 @@ class DataParallelBackend(Accelerator):
|
|||
if self.trainer.amp_backend:
|
||||
model = self.__init_half_precision(model)
|
||||
|
||||
# init torch data parallel
|
||||
model = self.__init_torch_data_parallel(model)
|
||||
|
||||
self.trainer.model = model
|
||||
|
||||
def __init_torch_data_parallel(self, model):
|
||||
|
|
Loading…
Reference in New Issue