set dp as default backend
This commit is contained in:
parent
2096a0aa84
commit
c163caf8cb
|
@ -43,6 +43,7 @@ class LightningDataParallel(DataParallel):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def parallel_apply(self, replicas, inputs, kwargs):
|
def parallel_apply(self, replicas, inputs, kwargs):
|
||||||
|
print('LDP')
|
||||||
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,6 +56,7 @@ class LightningDistributedDataParallel(DistributedDataParallel):
|
||||||
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
|
||||||
|
|
||||||
def forward(self, *inputs, **kwargs):
|
def forward(self, *inputs, **kwargs):
|
||||||
|
print('LDDP')
|
||||||
self._sync_params()
|
self._sync_params()
|
||||||
if self.device_ids:
|
if self.device_ids:
|
||||||
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
|
||||||
|
|
Loading…
Reference in New Issue