diff --git a/README.md b/README.md index c3e2e8d671..665fcf3618 100644 --- a/README.md +++ b/README.md @@ -80,9 +80,9 @@ class LitClassifier(pl.LightningModule): # train! train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32) -mnist_model = LitClassifier() +model = LitClassifier() trainer = pl.Trainer(gpus=8, precision=16) -trainer.fit(mnist_model, train_loader) +trainer.fit(model, train_loader) ``` Other examples: