DP device fix (#3196)
parent 4d98419bb8
commit 0aee137ba7
@@ -47,6 +47,9 @@ class DataParallelBackend(Accelerator):
         self.trainer.lr_schedulers = lr_schedulers
         self.trainer.optimizer_frequencies = optimizer_frequencies
 
+        # init torch data parallel
+        model = self.__init_torch_data_parallel(model)
+
         # hack forward to do autocast for the user
         self.model_autocast_original_forward = model.forward
 
@@ -54,9 +57,6 @@ class DataParallelBackend(Accelerator):
         if self.trainer.amp_backend:
             model = self.__init_half_precision(model)
 
-        # init torch data parallel
-        model = self.__init_torch_data_parallel(model)
-
         self.trainer.model = model
 
     def __init_torch_data_parallel(self, model):
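For context, a minimal sketch of the setup ordering this commit establishes, assuming the private __init_torch_data_parallel helper wraps the model in a torch.nn.DataParallel (or subclass) instance; the standalone setup_dp function, its arguments, and the model.half() placeholder below are illustrative only and are not part of the change. The point of the reorder is that the model is wrapped for data parallelism first, so the forward reference stored for the autocast hack is the DataParallel forward, and half-precision setup runs after that.

from torch import nn


def setup_dp(model: nn.Module, device_ids=None, use_amp: bool = False):
    # init torch data parallel -- done *before* the forward hack, as in this fix
    model = nn.DataParallel(model, device_ids=device_ids)

    # hack forward to do autocast for the user: the reference captured here is
    # now the DataParallel forward, not the bare module's forward
    model_autocast_original_forward = model.forward

    # init half precision (placeholder for the backend's actual AMP setup)
    if use_amp:
        model = model.half()

    return model, model_autocast_original_forward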