added amp level option

William Falcon 2019-05-16 15:55:21 -04:00
parent e052a3bc92
commit 60d4b80322
1 changed file with 3 additions and 2 deletions


@@ -372,8 +372,9 @@ class Trainer(TrainerIO):
 if self.use_amp:
     for optimizer in self.optimizers:
         with amp.scale_loss(loss, optimizer) as scaled_loss:
-            optimizer.backward(scaled_loss)
-            # scaled_loss.backward()
+            scaled_loss.backward()
+            for param in self.model.parameters():
+                print(param.grad.float().sum())
 else:
     loss.backward()
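
For context, below is a minimal standalone sketch of the NVIDIA Apex AMP pattern this hunk relies on. The model, optimizer, and input are hypothetical placeholders invented for illustration; only amp.initialize, its opt_level argument, and the amp.scale_loss context manager are Apex's actual API.

# Sketch of the Apex AMP usage the diff above moves to.
# model/optimizer/input below are placeholders, not Lightning code.
import torch
from apex import amp

model = torch.nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# opt_level selects the mixed-precision mode ("O0"-"O3"); this is the
# "amp level option" the commit message refers to.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

loss = model(torch.randn(4, 10).cuda()).mean()

# scale_loss yields a loss scaled to avoid fp16 gradient underflow;
# backward() is called on that scaled loss, as the diff now does.
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
optimizer.step()

The deleted optimizer.backward(scaled_loss) call resembles Apex's older FP16_Optimizer interface; plain torch optimizers have no backward() method, so calling .backward() on the scaled loss tensor is the correct form for the amp API.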