diff --git a/docs/source/extensions/datamodules.rst b/docs/source/extensions/datamodules.rst index bc79d7dc3d..443cd5be42 100644 --- a/docs/source/extensions/datamodules.rst +++ b/docs/source/extensions/datamodules.rst @@ -61,8 +61,8 @@ Here's a simple PyTorch example: .. code-block:: python # regular PyTorch - test_data = MNIST(PATH, train=False, download=True) - train_data = MNIST(PATH, train=True, download=True) + test_data = MNIST(my_path, train=False, download=True) + train_data = MNIST(my_path, train=True, download=True) train_data, val_data = random_split(train_data, [55000, 5000]) train_loader = DataLoader(train_data, batch_size=32) @@ -75,8 +75,9 @@ The equivalent DataModule just organizes the same exact code, but makes it reusa class MNISTDataModule(pl.LightningDataModule): - def __init__(self, data_dir: str = PATH, batch_size): + def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32): super().__init__() + self.data_dir = data_dir self.batch_size = batch_size def setup(self, stage=None): @@ -99,7 +100,7 @@ colleagues or use in different projects. .. code-block:: python - mnist = MNISTDataModule(PATH) + mnist = MNISTDataModule(my_path) model = LitClassifier() trainer = Trainer()