diff --git a/docs/source-pytorch/fabric/guide/lightning_module.rst b/docs/source-pytorch/fabric/guide/lightning_module.rst index e3fb72b399..2b2f659b7e 100644 --- a/docs/source-pytorch/fabric/guide/lightning_module.rst +++ b/docs/source-pytorch/fabric/guide/lightning_module.rst @@ -100,10 +100,10 @@ It is up to you to call everything at the right place. model.train() for epoch in range(num_epochs): for i, batch in enumerate(dataloader): + optimizer.zero_grad() loss = model.training_step(batch, i) fabric.backward(loss) optimizer.step() - optimizer.zero_grad() # Control when hooks are called if condition: diff --git a/examples/fabric/image_classifier/train_fabric.py b/examples/fabric/image_classifier/train_fabric.py index 157972aaea..bac892d0ba 100644 --- a/examples/fabric/image_classifier/train_fabric.py +++ b/examples/fabric/image_classifier/train_fabric.py @@ -77,7 +77,6 @@ def run(hparams): # by the command line. See all options: `lightning run model --help` fabric = Fabric() - fabric.hparams = hparams seed_everything(hparams.seed) # instead of torch.manual_seed(...) transform = T.Compose([T.ToTensor(), T.Normalize((0.1307,), (0.3081,))])