ref: decouple apex second attemp part 1/n (#4052)

This commit is contained in:
William Falcon 2020-10-10 09:53:02 -04:00 committed by GitHub
parent 5b261a230e
commit e854d3744c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 14 deletions

View File

@ -76,12 +76,10 @@ class Accelerator(object):
# scale loss for 16 bit
if self.trainer.precision == 16:
closure_loss = model_ref.amp_scale_loss(
closure_loss,
optimizer,
opt_idx,
amp_backend=self.trainer.amp_backend
)
if self.trainer.amp_backend == AMPType.NATIVE:
closure_loss = self.trainer.scaler.scale(closure_loss)
else:
closure_loss = amp.scale_loss(closure_loss, optimizer)
# enter amp context
if self.trainer.amp_backend == AMPType.APEX:

View File

@ -318,14 +318,6 @@ class ModelHooks:
"""
loss.backward()
def amp_scale_loss(self, unscaled_loss: Tensor, optimizer: Optimizer, optimizer_idx: int, amp_backend: AMPType):
if amp_backend == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
return scaled_loss
class DataHooks:
def prepare_data(self) -> None: