Fix typo in `BasePredictionWriter` documentation (#18381)
parent 9496d9aef1
commit 9d7a2848b7
@@ -94,7 +94,7 @@ By using the predict step in Lightning you get free distributed inference using
             torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


-    # or you can set `writer_interval="batch"` and override `write_on_batch_end` to save
+    # or you can set `write_interval="batch"` and override `write_on_batch_end` to save
     # predictions at batch level
     pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
     trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
@@ -93,7 +93,7 @@ class BasePredictionWriter(Callback):
                 torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))


-        # or you can set `writer_interval="batch"` and override `write_on_batch_end` to save
+        # or you can set `write_interval="batch"` and override `write_on_batch_end` to save
         # predictions at batch level
         pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
         trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
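Both hunks correct the same comment: the callback is configured through `write_interval`, not `writer_interval`. For context, a minimal sketch of the batch-level mode the corrected comment points to, assuming the standard `write_on_batch_end` hook signature from the Lightning docs; `BatchWriter`, the output file names, `model`, and `predict_dataloader` are illustrative placeholders, not part of this commit:

    import os

    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import BasePredictionWriter


    class BatchWriter(BasePredictionWriter):
        def __init__(self, output_dir):
            # write_interval="batch" makes Lightning call write_on_batch_end after every predict batch
            super().__init__(write_interval="batch")
            self.output_dir = output_dir

        def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx):
            # one file per rank and batch so DDP ranks never overwrite each other
            torch.save(prediction, os.path.join(self.output_dir, f"pred_rank{trainer.global_rank}_batch{batch_idx}.pt"))


    # `model` and `predict_dataloader` stand in for an existing LightningModule and DataLoader
    pred_writer = BatchWriter(output_dir="pred_path")
    trainer = Trainer(accelerator="gpu", strategy="ddp", devices=8, callbacks=[pred_writer])
    trainer.predict(model, dataloaders=predict_dataloader, return_predictions=False)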