diff --git a/docs/source/common/lightning_module.rst b/docs/source/common/lightning_module.rst
index 824fa7a251..08ab0d07da 100644
--- a/docs/source/common/lightning_module.rst
+++ b/docs/source/common/lightning_module.rst
@@ -489,6 +489,14 @@ For research, LightningModules are best structured as systems.
             reconstruction_loss = nn.functional.mse_loss(recons, x)
             self.log('val_reconstruction', reconstruction_loss)
 
+        def predict_step(self, batch, batch_idx, dataloader_idx):
+            x, _ = batch
+
+            # encode
+            # for predictions, we could return the embedding, the reconstruction, or both, depending on our needs
+            x = x.view(x.size(0), -1)
+            return self.encoder(x)
+
         def configure_optimizers(self):
             return torch.optim.Adam(self.parameters(), lr=0.0002)
 
@@ -510,6 +518,7 @@ The methods above are part of the lightning interface:
 - training_step
 - validation_step
 - test_step
+- predict_step
 - configure_optimizers
 
 Note that in this case, the train loop and val loop are exactly the same. We can of course reuse this code.
@@ -554,12 +563,20 @@ Inference in research
 ^^^^^^^^^^^^^^^^^^^^^
 In the case where we want to perform inference with the system we can add a `forward` method to the LightningModule.
 
+.. note:: When using forward, you are responsible for calling :func:`~torch.nn.Module.eval` and using the :func:`~torch.no_grad` context manager.
+
 .. code-block:: python
 
     class Autoencoder(pl.LightningModule):
+
         def forward(self, x):
             return self.decoder(x)
 
+    model = Autoencoder()
+    model.eval()
+    with torch.no_grad():
+        reconstruction = model(embedding)
+
 The advantage of adding a forward is that in complex systems, you can do a much more involved inference procedure,
 such as text generation:
 
@@ -575,6 +592,25 @@ such as text generation:
             ...
             return decoded
 
+When you want to scale inference to multiple devices, use
+:meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step`.
+
+.. code-block:: python
+
+    class Autoencoder(pl.LightningModule):
+
+        def forward(self, x):
+            return self.decoder(x)
+
+        def predict_step(self, batch, batch_idx, dataloader_idx=None):
+            # this calls forward
+            return self(batch)
+
+    data_module = ...
+    model = Autoencoder()
+    trainer = Trainer(gpus=2)
+    trainer.predict(model, data_module)
+
 Inference in production
 ^^^^^^^^^^^^^^^^^^^^^^^
 For cases like production, you might want to iterate different models inside a LightningModule.
@@ -615,6 +651,10 @@ For cases like production, you might want to iterate different models inside a L
             acc = FM.accuracy(y_hat, y)
             return loss, acc
 
+        def predict_step(self, batch, batch_idx, dataloader_idx):
+            x, y = batch
+            return self.model(x)
+
         def configure_optimizers(self):
             return torch.optim.Adam(self.model.parameters(), lr=0.02)
 
diff --git a/pl_examples/basic_examples/autoencoder.py b/pl_examples/basic_examples/autoencoder.py
index 8ea03dabc9..70e46b037e 100644
--- a/pl_examples/basic_examples/autoencoder.py
+++ b/pl_examples/basic_examples/autoencoder.py
@@ -87,6 +87,12 @@ class LitAutoEncoder(pl.LightningModule):
         loss = F.mse_loss(x_hat, x)
         self.log('test_loss', loss, on_step=True)
 
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
+        x, y = batch
+        x = x.view(x.size(0), -1)
+        z = self.encoder(x)
+        return self.decoder(z)
+
     def configure_optimizers(self):
         optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
         return optimizer
@@ -113,10 +119,15 @@ class MyDataModule(pl.LightningDataModule):
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size)
 
+    def predict_dataloader(self):
+        return DataLoader(self.mnist_test, batch_size=self.batch_size)
+
 
 def cli_main():
     cli = LightningCLI(LitAutoEncoder, MyDataModule, seed_everything_default=1234)
     cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    print(predictions[0])
 
 
 if __name__ == '__main__':
diff --git a/pl_examples/basic_examples/backbone_image_classifier.py b/pl_examples/basic_examples/backbone_image_classifier.py
index 57cf97be00..0d43748cec 100644
--- a/pl_examples/basic_examples/backbone_image_classifier.py
+++ b/pl_examples/basic_examples/backbone_image_classifier.py
@@ -100,6 +100,10 @@ class LitClassifier(pl.LightningModule):
         loss = F.cross_entropy(y_hat, y)
         self.log('test_loss', loss)
 
+    def predict_step(self, batch, batch_idx, dataloader_idx=None):
+        x, y = batch
+        return self.backbone(x)
+
     def configure_optimizers(self):
         # self.hparams available because we called self.save_hyperparameters()
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -126,10 +130,15 @@ class MyDataModule(pl.LightningDataModule):
     def test_dataloader(self):
         return DataLoader(self.mnist_test, batch_size=self.batch_size)
 
+    def predict_dataloader(self):
+        return DataLoader(self.mnist_test, batch_size=self.batch_size)
+
 
 def cli_main():
     cli = LightningCLI(LitClassifier, MyDataModule, seed_everything_default=1234)
     cli.trainer.test(cli.model, datamodule=cli.datamodule)
+    predictions = cli.trainer.predict(cli.model, datamodule=cli.datamodule)
+    print(predictions[0])
 
 
 if __name__ == '__main__':
diff --git a/pl_examples/basic_examples/profiler_example.py b/pl_examples/basic_examples/profiler_example.py
index c79214af93..688eb15ef9 100644
--- a/pl_examples/basic_examples/profiler_example.py
+++ b/pl_examples/basic_examples/profiler_example.py
@@ -62,6 +62,10 @@ class ModelToProfile(LightningModule):
         loss = self.criterion(outputs, labels)
         self.log("val_loss", loss)
 
+    def predict_step(self, batch, batch_idx, dataloader_idx: int = None):
+        inputs = batch[0]
+        return self.model(inputs)
+
     def configure_optimizers(self):
         return torch.optim.SGD(self.parameters(), lr=0.001, momentum=0.9)
 
diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py
index bc070b25e7..a1b2ce3a5e 100644
--- a/pytorch_lightning/core/lightning.py
+++ b/pytorch_lightning/core/lightning.py
@@ -1115,6 +1115,29 @@ class LightningModule(
         By default, it calls :meth:`~pytorch_lightning.core.lightning.LightningModule.forward`.
         Override to add any processing logic.
 
+        The :meth:`~pytorch_lightning.core.lightning.LightningModule.predict_step` method is used
+        to scale inference on multiple devices.
+
+        To prevent an OOM error, you can use the :class:`~pytorch_lightning.callbacks.BasePredictionWriter`
+        callback to write the predictions to disk or a database after each batch or at the end of the epoch.
+
+        The :class:`~pytorch_lightning.callbacks.BasePredictionWriter` should be used with spawn-based
+        accelerators, for example ``Trainer(accelerator="ddp_spawn")``
+        or training on 8 TPU cores with ``Trainer(tpu_cores=8)``, as predictions won't be returned.
+
+        Example::
+
+            class MyModel(LightningModule):
+
+                def predict_step(self, batch, batch_idx, dataloader_idx):
+                    return self(batch)
+
+            dm = ...
+            model = MyModel()
+            trainer = Trainer(gpus=2)
+            predictions = trainer.predict(model, dm)
+
+
         Args:
             batch: Current batch
             batch_idx: Index of current batch
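Taken together, the changes above can be exercised end to end with a small script. The following is a minimal sketch rather than part of the diff: ``PredictionWriter``, ``TinyModel``, the random ``TensorDataset`` and the ``predictions/`` output directory are illustrative stand-ins, while ``BasePredictionWriter`` and its ``write_on_batch_end`` hook are the callback referenced in the docstring above (their exact signatures may differ between releases).

.. code-block:: python

    import os

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import BasePredictionWriter


    class PredictionWriter(BasePredictionWriter):
        """Illustrative callback that persists predictions instead of returning them."""

        def __init__(self, output_dir, write_interval="batch"):
            super().__init__(write_interval)
            self.output_dir = output_dir
            os.makedirs(output_dir, exist_ok=True)

        def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
            # called after every predict_step when write_interval="batch"
            torch.save(prediction, os.path.join(self.output_dir, f"batch_{batch_idx}.pt"))


    class TinyModel(pl.LightningModule):
        """Stand-in for the LightningModules above: predict_step simply calls forward."""

        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 2)

        def forward(self, x):
            return self.layer(x)

        def predict_step(self, batch, batch_idx, dataloader_idx=None):
            x, _ = batch
            return self(x)


    if __name__ == "__main__":
        # predictions are written to disk by the callback, which is the recommended
        # pattern for spawn-based accelerators where trainer.predict returns nothing
        dataset = TensorDataset(torch.randn(64, 32), torch.zeros(64))
        loader = DataLoader(dataset, batch_size=16)
        trainer = pl.Trainer(callbacks=[PredictionWriter("predictions/")])
        trainer.predict(TinyModel(), dataloaders=loader)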