diff --git a/tests/base/datamodules.py b/tests/base/datamodules.py
index d1f7fabf8d..234d27b721 100644
--- a/tests/base/datamodules.py
+++ b/tests/base/datamodules.py
@@ -25,7 +25,7 @@ class TrialMNISTDataModule(LightningDataModule):
             self.dims = self.mnist_train[0][0].shape

         if stage == 'test' or stage is None:
-            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=32, download=True)
+            self.mnist_test = TrialMNIST(root=self.data_dir, train=False, num_samples=64, download=True)
             self.dims = getattr(self, 'dims', self.mnist_test[0][0].shape)

         self.non_picklable = lambda x: x**2
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index ca78ba35bf..1bd1cf15de 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -15,7 +15,9 @@ from tests.base import EvalModelTemplate


 @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
 def test_multi_gpu_wandb_ddp_spawn(tmpdir):
-    """Make sure DP/DDP + AMP work."""
+    """
+    Make sure DDP spawn + WandbLogger + AMP work together.
+    """
     from pytorch_lightning.loggers import WandbLogger
     tutils.set_random_master_port()