diff --git a/docs/source-pytorch/deploy/production_basic.rst b/docs/source-pytorch/deploy/production_basic.rst index e03a2b5a80..4dacb34233 100644 --- a/docs/source-pytorch/deploy/production_basic.rst +++ b/docs/source-pytorch/deploy/production_basic.rst @@ -94,7 +94,7 @@ By using the predict step in Lightning you get free distributed inference using torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt")) - # or you can set `writer_interval="batch"` and override `write_on_batch_end` to save + # or you can set `write_interval="batch"` and override `write_on_batch_end` to save # predictions at batch level pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch") trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer]) diff --git a/src/lightning/pytorch/callbacks/prediction_writer.py b/src/lightning/pytorch/callbacks/prediction_writer.py index 74ee0b85a7..4279663a86 100644 --- a/src/lightning/pytorch/callbacks/prediction_writer.py +++ b/src/lightning/pytorch/callbacks/prediction_writer.py @@ -93,7 +93,7 @@ class BasePredictionWriter(Callback): torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt")) - # or you can set `writer_interval="batch"` and override `write_on_batch_end` to save + # or you can set `write_interval="batch"` and override `write_on_batch_end` to save # predictions at batch level pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch") trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])