[doc] Add more reference around predict_step (#7997)
* add predict examples
* update on comments
commit 917cf83638 (parent d2983c7c51)
@@ -489,6 +489,14 @@ For research, LightningModules are best structured as systems.

            reconstruction_loss = nn.functional.mse_loss(recons, x)
            self.log('val_reconstruction', reconstruction_loss)

        def predict_step(self, batch, batch_idx, dataloader_idx):
            x, _ = batch

            # encode
            # for predictions, we could return the embedding or the reconstruction or both based on our need.
            x = x.view(x.size(0), -1)
            return self.encoder(x)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=0.0002)
@@ -510,6 +518,7 @@ The methods above are part of the lightning interface:

- training_step
- validation_step
- test_step
- predict_step
- configure_optimizers

Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code.
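For orientation, here is a minimal sketch (not part of the original example; placeholder bodies only) of how these hooks sit side by side in a single LightningModule:

.. code-block:: python

    class LitModel(pl.LightningModule):

        def training_step(self, batch, batch_idx):
            ...

        def validation_step(self, batch, batch_idx):
            ...

        def test_step(self, batch, batch_idx):
            ...

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            ...

        def configure_optimizers(self):
            ...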
@@ -554,12 +563,20 @@ Inference in research

^^^^^^^^^^^^^^^^^^^^^

In the case where we want to perform inference with the system, we can add a `forward` method to the LightningModule.

.. note:: When using forward, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.

.. code-block:: python

    class Autoencoder(pl.LightningModule):

        def forward(self, x):
            return self.decoder(x)

    model = Autoencoder()
    model.eval()
    with torch.no_grad():
        reconstruction = model(embedding)

The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
such as text generation:
@@ -575,6 +592,25 @@ such as text generation:

            ...
            return decoded

In the case where you want to scale your inference, you should be using
:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`.

.. code-block:: python

    class Autoencoder(pl.LightningModule):

        def forward(self, x):
            return self.decoder(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            # this calls forward
            return self(batch)

    data_module = ...
    model = Autoencoder()
    trainer = Trainer(gpus=2)
    trainer.predict(model, data_module)
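``trainer.predict`` collects whatever ``predict_step`` returns, one entry per batch (and one list per dataloader when several are used). As a small follow-up sketch, assuming a single predict dataloader and tensor outputs, the per-batch results can be combined afterwards:

.. code-block:: python

    predictions = trainer.predict(model, data_module)

    # one tensor built from the per-batch outputs of predict_step
    all_outputs = torch.cat(predictions, dim=0)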
Inference in production
^^^^^^^^^^^^^^^^^^^^^^^

For cases like production, you might want to iterate over different models inside a LightningModule.
@@ -615,6 +651,10 @@ For cases like production, you might want to iterate over different models inside a LightningModule.

            acc = FM.accuracy(y_hat, y)
            return loss, acc

        def predict_step(self, batch, batch_idx, dataloader_idx):
            x, y = batch
            return self.model(x)

        def configure_optimizers(self):
            return torch.optim.Adam(self.model.parameters(), lr=0.02)
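As a hypothetical usage sketch (the enclosing task class and its constructor are assumed here, since the diff only shows individual methods), running prediction with such a production-style module could look like:

.. code-block:: python

    model = ...  # any torch.nn.Module backbone
    task = ClassificationTask(model)  # hypothetical name for the enclosing LightningModule

    dm = ...  # a LightningDataModule with a predict_dataloader
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(task, datamodule=dm)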
@@ -87,6 +87,12 @@ class LitAutoEncoder(pl.LightningModule):

        loss = F.mse_loss(x_hat, x)
        self.log('test_loss', loss, on_step=True)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        return self.decoder(z)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
@@ -113,10 +119,15 @@ class MyDataModule(pl.LightningDataModule):

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)
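    # Hypothetical variation, not part of this example: predict_dataloader may also return
    # a list of dataloaders, in which case predict_step receives `dataloader_idx` so the
    # batches from each loader can be told apart, e.g.:
    #
    #     def predict_dataloader(self):
    #         return [
    #             DataLoader(self.mnist_test, batch_size=self.batch_size),
    #             DataLoader(self.mnist_val, batch_size=self.batch_size),
    #         ]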


def cli_main():
    cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
    print(predictions[0])
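    # each element of `predictions` is the output of `predict_step` for one batch of the
    # predict dataloader (here, the decoded reconstructions for that batch)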


if __name__ == '__main__':
@@ -100,6 +100,10 @@ class LitClassifier(pl.LightningModule):

        loss = F.cross_entropy(y_hat, y)
        self.log('test_loss', loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        x, y = batch
        return self.backbone(x)

    def configure_optimizers(self):
        # self.hparams available because we called self.save_hyperparameters()
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -126,10 +130,15 @@ class MyDataModule(pl.LightningDataModule):

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)


def cli_main():
    cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
    cli.trainer.test(cli.model, datamodule=cli.datamodule)
    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
    print(predictions[0])


if __name__ == '__main__':
@@ -62,6 +62,10 @@ class ModelToProfile(LightningModule):

        loss = self.criterion(outputs, labels)
        self.log("val_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx: int = None):
        inputs = batch[0]
        return self.model(inputs)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
@@ -1115,6 +1115,29 @@ class LightningModule(

By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`.
Override to add any processing logic.

The :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` is used
to scale inference on multiple devices.

To prevent an OOM error, it is possible to use the :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
callback to write the predictions to disk or to a database after each batch or at the end of the epoch.

The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used while using a spawn-based
accelerator, e.g. ``Trainer(accelerator="ddp_spawn")`` or training on 8 TPU cores with
``Trainer(tpu_cores=8)``, as predictions won't be returned to the main process.

Example::

    class MyModel(LightningModule):

        def predict_step(self, batch, batch_idx, dataloader_idx):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(gpus=2)
    predictions = trainer.predict(model, dm)
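A sketch of such a writer callback (hypothetical class name and file paths; the hook
signatures follow :class:`~pytorch_lightning.callbacks.BasePredictionWriter`)::

    import os

    import torch
    from pytorch_lightning.callbacks import BasePredictionWriter

    class PredictionWriter(BasePredictionWriter):

        def __init__(self, output_dir, write_interval="batch"):
            super().__init__(write_interval)
            self.output_dir = output_dir
            os.makedirs(self.output_dir, exist_ok=True)

        def write_on_batch_end(
            self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx
        ):
            # persist each batch of predictions as soon as it is produced
            torch.save(prediction, os.path.join(self.output_dir, f"{dataloader_idx}_{batch_idx}.pt"))

        def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
            # only called when ``write_interval`` includes "epoch"
            torch.save(predictions, os.path.join(self.output_dir, "predictions.pt"))

    trainer = Trainer(callbacks=[PredictionWriter("./predictions")])
    trainer.predict(model, dm)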

Args:
    batch: Current batch
    batch_idx: Index of current batch