Added optimizer_idx to backward call (#733)
parent a804755e6e
commit 946aef6216
@@ -124,12 +124,13 @@ class ModelHooks(torch.nn.Module):
         """
         pass

-    def backward(self, use_amp, loss, optimizer):
+    def backward(self, use_amp, loss, optimizer, optimizer_idx):
         """Override backward with your own implementation if you need to

         :param use_amp: Whether amp was requested or not
         :param loss: Loss is already scaled by accumulated grads
         :param optimizer: Current optimizer being used
+        :param optimizer_idx: Index of the current optimizer being used
         :return:

         Called to perform backward step.
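With the extra argument in place, an override of this hook can branch on
which optimizer the current backward pass belongs to. A minimal sketch,
assuming a hypothetical two-optimizer (GAN-style) LightningModule; the class
name and the retain_graph choice are illustrative, not part of the commit:

import pytorch_lightning as pl

class LitGAN(pl.LightningModule):
    # Hypothetical module with two optimizers (generator at index 0,
    # discriminator at index 1), shown only to illustrate the new signature.
    def backward(self, use_amp, loss, optimizer, optimizer_idx):
        if optimizer_idx == 0:
            # generator step: keep the graph alive in case a later step
            # reuses intermediate results (illustrative assumption)
            loss.backward(retain_graph=True)
        else:
            # discriminator step: plain backward pass
            loss.backward()

Note that the trainer now always passes the index, so an override written
against the old signature (use_amp, loss, optimizer) must add the extra
parameter or the call site below will raise a TypeError.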

@@ -490,13 +490,14 @@ class TrainerTrainLoopMixin(ABC):

         # backward pass
         model_ref = self.get_model()
-        model_ref.backward(self.use_amp, closure_loss, optimizer)
+        model_ref.backward(self.use_amp, closure_loss, optimizer, opt_idx)

         # track metrics for callbacks
         all_callback_metrics.append(callback_metrics)

         # track progress bar metrics
         self.add_tqdm_metrics(progress_bar_metrics)
         all_log_metrics.append(log_metrics)

         # insert after step hook
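For reference, the body of the hook is untouched by this commit; only the
signature and the trainer-side call change. A minimal sketch of a compatible
default implementation, assuming apex-style mixed precision backs use_amp
(the apex import and scale_loss usage are assumptions, not shown in the
diff):

from apex import amp  # assumption: NVIDIA apex provides the amp module

def backward(self, use_amp, loss, optimizer, optimizer_idx):
    # When amp is active, scale the loss before backpropagating so fp16
    # gradients do not underflow; otherwise run a plain backward pass.
    if use_amp:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()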