diff --git a/tests/tests_pytorch/callbacks/test_prediction_writer.py b/tests/tests_pytorch/callbacks/test_prediction_writer.py index ca6548a397..3a3ea69fa9 100644 --- a/tests/tests_pytorch/callbacks/test_prediction_writer.py +++ b/tests/tests_pytorch/callbacks/test_prediction_writer.py @@ -83,7 +83,9 @@ def test_prediction_writer_batch_indices(num_workers, tmp_path): DummyPredictionWriter.write_on_batch_end = Mock() DummyPredictionWriter.write_on_epoch_end = Mock() - dataloader = DataLoader(RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=num_workers > 0) + dataloader = DataLoader( + RandomDataset(32, 64), batch_size=4, num_workers=num_workers, persistent_workers=num_workers > 0 + ) model = BoringModel() writer = DummyPredictionWriter("batch_and_epoch") trainer = Trainer(default_root_dir=tmp_path, logger=False, limit_predict_batches=4, callbacks=writer)