Fix typo in `BasePredictionWriter` documentation (#18381)

This commit is contained in:
Maxim Borodin 2023-08-24 13:34:50 +03:00 committed by GitHub
parent 9496d9aef1
commit 9d7a2848b7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@ -94,7 +94,7 @@ By using the predict step in Lightning you get free distributed inference using
torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
# or you can set `writer_interval="batch"` and override `write_on_batch_end` to save
# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])

View File

@ -93,7 +93,7 @@ class BasePredictionWriter(Callback):
torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
# or you can set `writer_interval="batch"` and override `write_on_batch_end` to save
# or you can set `write_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])