fix accumulated grad norm fixes #87 (#88)

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py

* Update trainer.py
William Falcon 2019-08-10 08:32:45 -04:00 committed by GitHub
parent 09d4475cc7
commit 73d08557ba
1 changed file with 5 additions and 7 deletions

trainer.py

@@ -899,6 +899,9 @@ We recommend you switch to ddp if you want to use amp
 
         self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
+        # accumulate loss (if accumulate_grad_batches = 1 no effect)
+        loss = loss / self.accumulate_grad_batches
+
         # backward pass
         if self.use_amp:
             # scale loss when using amp
@@ -918,12 +921,11 @@ We recommend you switch to ddp if you want to use amp
             for param in model.parameters():
                 print(param.grad.float().sum())
 
-        # avoid memory leaks
+        # track total loss for logging (avoid mem leaks)
         self.batch_loss_value += loss.item()
 
         # gradient update with accumulated gradients
         if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
-
             # clip gradients
             if self.gradient_clip > 0:
                 model = self.__get_model()
@@ -941,11 +943,7 @@ We recommend you switch to ddp if you want to use amp
                 # clear gradients
                 optimizer.zero_grad()
 
-                # queuing loss across batches blows it up proportionally...
-                # divide out the number accumulated
-                self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches
-
-                # track loss
+                # calculate running loss for display
                 self.running_loss.append(self.batch_loss_value)
                 self.batch_loss_value = 0
                 self.avg_loss = np.mean(self.running_loss[-100:])
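
As the diff shows, the fix moves the normalization to the loss itself, before the backward pass: previously only the logged batch_loss_value was divided by accumulate_grad_batches, so the gradients summed over the accumulation window were N times too large. Below is a minimal standalone sketch of the pattern this commit implements, using plain PyTorch with a toy model, optimizer, and loader; all names here are illustrative stand-ins, not pytorch-lightning's API.

import torch

# Illustrative stand-ins; not part of pytorch-lightning's API.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
accumulate_grad_batches = 4
loader = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(8)]

for batch_nb, (x, y) in enumerate(loader):
    loss = loss_fn(model(x), y)

    # Scale the loss *before* backward so the gradients summed across
    # the accumulation window match a single step over the effective
    # (larger) batch.
    loss = loss / accumulate_grad_batches
    loss.backward()

    # Step and reset only every `accumulate_grad_batches` batches.
    if (batch_nb + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()

Because backward() adds into .grad, dividing each loss by a constant before backward is equivalent to averaging the gradients over the accumulation window, which is what makes accumulation behave like training with a proportionally larger batch.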