* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
* Update trainer.py
parent 09d4475cc7
commit 73d08557ba
@@ -899,6 +899,9 @@ We recommend you switch to ddp if you want to use amp
             self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
+            # accumulate loss (if accumulate_grad_batches = 1 no effect)
+            loss = loss / self.accumulate_grad_batches
+
             # backward pass
             if self.use_amp:
                 # scale loss when using amp
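The added lines divide each batch loss by accumulate_grad_batches before the backward pass, so gradients summed over several small batches match one large batch instead of growing proportionally. A minimal standalone sketch of that pattern in plain PyTorch (hypothetical toy model, optimizer and setting; not the trainer's actual loop):

import torch
from torch import nn

model = nn.Linear(10, 1)                                  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accumulate_grad_batches = 4                               # assumed setting for this sketch

for batch_nb in range(16):
    x, y = torch.randn(8, 10), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)

    # accumulate loss (if accumulate_grad_batches = 1 no effect)
    loss = loss / accumulate_grad_batches

    # backward pass; param.grad keeps summing until the optimizer step
    loss.backward()

    # gradient update with accumulated gradients
    if (batch_nb + 1) % accumulate_grad_batches == 0:
        optimizer.step()
        optimizer.zero_grad()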
@@ -918,12 +921,11 @@ We recommend you switch to ddp if you want to use amp
                     for param in model.parameters():
                         print(param.grad.float().sum())
 
-            # avoid memory leaks
+            # track total loss for logging (avoid mem leaks)
             self.batch_loss_value += loss.item()
 
             # gradient update with accumulated gradients
             if (self.batch_nb + 1) % self.accumulate_grad_batches == 0:
 
                 # clip gradients
                 if self.gradient_clip > 0:
                     model = self.__get_model()
@@ -941,11 +943,7 @@ We recommend you switch to ddp if you want to use amp
                 # clear gradients
                 optimizer.zero_grad()
 
-                # queuing loss across batches blows it up proportionally...
-                # divide out the number accumulated
-                self.batch_loss_value = self.batch_loss_value / self.accumulate_grad_batches
-
-                # track loss
+                # calculate running loss for display
                 self.running_loss.append(self.batch_loss_value)
                 self.batch_loss_value = 0
                 self.avg_loss = np.mean(self.running_loss[-100:])
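Because the loss is now scaled per batch, self.batch_loss_value already sums to the mean loss over the accumulation window, so the second division that this hunk deletes would have divided by accumulate_grad_batches twice and under-reported the displayed loss. A sketch of the logging side, reusing the names from the diff in a hypothetical helper (not code from trainer.py):

import numpy as np

class RunningLossLogger:
    # hypothetical helper mirroring the display logic in the hunk above
    def __init__(self, accumulate_grad_batches=4):
        self.accumulate_grad_batches = accumulate_grad_batches
        self.batch_loss_value = 0.0
        self.running_loss = []
        self.avg_loss = 0.0

    def update(self, batch_nb, raw_loss_item):
        # scale as in the first hunk, so the window's contributions sum to its mean
        self.batch_loss_value += raw_loss_item / self.accumulate_grad_batches

        if (batch_nb + 1) % self.accumulate_grad_batches == 0:
            # calculate running loss for display -- no second division needed
            self.running_loss.append(self.batch_loss_value)
            self.batch_loss_value = 0.0
            self.avg_loss = float(np.mean(self.running_loss[-100:]))
        return self.avg_loss

For example, with accumulate_grad_batches = 4 and four raw losses of about 1.0, update() appends roughly 1.0 to running_loss; the removed division would have logged about 0.25 instead.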