ref: decouple apex second attemp part 1/n (#4052)
This commit is contained in:
parent
5b261a230e
commit
e854d3744c
|
@ -76,12 +76,10 @@ class Accelerator(object):
|
|||
|
||||
# scale loss for 16 bit
|
||||
if self.trainer.precision == 16:
|
||||
closure_loss = model_ref.amp_scale_loss(
|
||||
closure_loss,
|
||||
optimizer,
|
||||
opt_idx,
|
||||
amp_backend=self.trainer.amp_backend
|
||||
)
|
||||
if self.trainer.amp_backend == AMPType.NATIVE:
|
||||
closure_loss = self.trainer.scaler.scale(closure_loss)
|
||||
else:
|
||||
closure_loss = amp.scale_loss(closure_loss, optimizer)
|
||||
|
||||
# enter amp context
|
||||
if self.trainer.amp_backend == AMPType.APEX:
|
||||
|
|
|
@ -318,14 +318,6 @@ class ModelHooks:
|
|||
"""
|
||||
loss.backward()
|
||||
|
||||
def amp_scale_loss(self, unscaled_loss: Tensor, optimizer: Optimizer, optimizer_idx: int, amp_backend: AMPType):
    """Scale a 16-bit-precision loss with whichever AMP backend is active.

    For the native backend the trainer's ``GradScaler`` does the scaling;
    for any other backend (apex) the ``amp.scale_loss`` helper is invoked
    with the optimizer that owns the gradients.
    """
    if amp_backend == AMPType.NATIVE:
        # Native torch AMP: scale directly through the trainer-owned scaler.
        return self.trainer.scaler.scale(unscaled_loss)
    # Apex path: amp.scale_loss yields the scaled loss for this optimizer.
    return amp.scale_loss(unscaled_loss, optimizer)
|
||||
|
||||
|
||||
class DataHooks:
|
||||
def prepare_data(self) -> None:
|
||||
|
|
Loading…
Reference in New Issue