[doc] Add more reference around predict_step (#7997)

* add predict examples

* update on comments
thomas chaton 2021-06-16 12:23:27 +01:00 committed by GitHub
parent d2983c7c51
commit 917cf83638
5 changed files with 87 additions and 0 deletions

View File

@@ -489,6 +489,14 @@ For research, LightningModules are best structured as systems.
        reconstruction_loss = nn.functional.mse_loss(recons, x)
        self.log('val_reconstruction', reconstruction_loss)

    def predict_step(self, batch, batch_idx, dataloader_idx):
        x, _ = batch
        # encode; for predictions, we could return the embedding,
        # the reconstruction, or both, depending on our needs
        x = x.view(x.size(0), -1)
        return self.encoder(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.0002)
@@ -510,6 +518,7 @@ The methods above are part of the lightning interface:

- training_step
- validation_step
- test_step
- predict_step
- configure_optimizers

Note that in this case, the train loop and val loop are exactly the same; we can, of course, reuse this code, as in the sketch below.
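Since the two loops share the same reconstruction logic, one option is to factor it into a helper. A minimal sketch, assuming the ``encoder``/``decoder`` modules from the example above (the ``_shared_step`` name is ours, not part of the Lightning API):

.. code-block:: python

    class Autoencoder(pl.LightningModule):

        def _shared_step(self, batch):
            # flatten the input and measure how well we reconstruct it
            x, _ = batch
            x = x.view(x.size(0), -1)
            recons = self.decoder(self.encoder(x))
            return nn.functional.mse_loss(recons, x)

        def training_step(self, batch, batch_idx):
            loss = self._shared_step(batch)
            self.log('train_reconstruction', loss)
            return loss

        def validation_step(self, batch, batch_idx):
            self.log('val_reconstruction', self._shared_step(batch))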
@@ -554,12 +563,20 @@ Inference in research
^^^^^^^^^^^^^^^^^^^^^

In the case where we want to perform inference with the system, we can add a `forward` method to the LightningModule.

.. note:: When using forward, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.
.. code-block:: python
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)


model = Autoencoder()
model.eval()
with torch.no_grad():
    reconstruction = model(embedding)
The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
such as text generation:
@@ -575,6 +592,25 @@ such as text generation:
...
return decoded
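As a rough illustration, such a forward might implement a greedy decoding loop. The sketch below is ours, not the Lightning API; the ``encoder``/``decoder`` interfaces, ``bos_token``, ``eos_token``, and ``max_len`` are all assumptions for illustration:

.. code-block:: python

    class Seq2Seq(pl.LightningModule):

        def forward(self, x):
            # encode the source once, then decode greedily token by token
            hidden = self.encoder(x)
            tokens = [self.bos_token]
            for _ in range(self.max_len):
                logits = self.decoder(torch.tensor([tokens]), hidden)
                next_token = logits[0, -1].argmax().item()
                if next_token == self.eos_token:
                    break
                tokens.append(next_token)
            return tokens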
In the case where you want to scale your inference, you should use
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`.
.. code-block:: python
class Autoencoder(pl.LightningModule):
    def forward(self, x):
        return self.decoder(x)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        # this calls forward
        return self(batch)


data_module = ...
model = Autoencoder()
trainer = Trainer(gpus=2)
trainer.predict(model, data_module)
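``trainer.predict`` also returns the collected outputs of ``predict_step``, one entry per batch of the predict dataloader:

.. code-block:: python

    predictions = trainer.predict(model, data_module)
    # e.g. the embeddings for the first batch
    print(predictions[0])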
Inference in production
^^^^^^^^^^^^^^^^^^^^^^^
For production use cases, you might want to iterate over different models inside a LightningModule.
@@ -615,6 +651,10 @@ For production use cases, you might want to iterate over different models inside a LightningModule.
        acc = FM.accuracy(y_hat, y)
        return loss, acc

    def predict_step(self, batch, batch_idx, dataloader_idx):
        x, y = batch
        y_hat = self.model(x)
        return y_hat

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=0.02)
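One way to set this up is to inject the network through the constructor, so the same task wraps interchangeable models. A minimal sketch (the ``ClassificationTask`` wiring and the torchvision backbone are illustrative, not taken from this diff):

.. code-block:: python

    import torchvision.models

    class ClassificationTask(pl.LightningModule):

        def __init__(self, model):
            super().__init__()
            self.model = model

    dm = ...
    # swap in any backbone without touching the task logic
    task = ClassificationTask(torchvision.models.resnet18(pretrained=True))
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(task, dm)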

View File

@@ -87,6 +87,12 @@ class LitAutoEncoder(pl.LightningModule):
        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss, on_step=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        return self.decoder(z)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
@@ -113,10 +119,15 @@ class MyDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
    print(predictions[0])


if __name__ == '__main__':

View File

@@ -100,6 +100,10 @@ class LitClassifier(pl.LightningModule):
        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        return self.backbone(x)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -126,10 +130,15 @@ class MyDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
    print(predictions[0])


if __name__ == '__main__':

View File

@@ -62,6 +62,10 @@ class ModelToProfile(LightningModule):
        loss = self.criterion(outputs, labels)
        self.log("val_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx: int = None):
        inputs = batch[0]
        return self.model(inputs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)

View File

@@ -1115,6 +1115,29 @@ class LightningModule(
By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`.
Override to add any processing logic.

The :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` is used
to scale inference across multiple devices.

To prevent an OOM error, it is possible to use the :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
callback to write the predictions to disk or a database after each batch or at the end of the epoch.

The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used while using a spawn-based
accelerator, e.g. ``Trainer(accelerator="ddp_spawn")`` or training on 8 TPU cores with
``Trainer(tpu_cores=8)``, as predictions won't be returned in those cases.
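For instance, a minimal batch-level writer might look like the sketch below (hook signatures per
:class:`~pytorch_lightning.callbacks.BasePredictionWriter`; the output path handling is illustrative)::

    import os
    import torch
    from pytorch_lightning.callbacks import BasePredictionWriter

    class OnBatchPredictionWriter(BasePredictionWriter):

        def __init__(self, output_dir):
            super().__init__(write_interval="batch")
            self.output_dir = output_dir

        def write_on_batch_end(
            self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
        ):
            # persist each batch instead of accumulating predictions in memory
            torch.save(prediction, os.path.join(self.output_dir, f"{dataloader_idx}_{batch_idx}.pt"))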
Example::

    class MyModel(LightningModule):

        def predict_step(self, batch, batch_idx, dataloader_idx):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(model, dm)
Args:
batch: Current batch
batch_idx: Index of current batch