diff --git a/CHANGELOG.md b/CHANGELOG.md index 602f020e00..9281443203 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `DDPHPCAccelerator` hangs in DDP construction by calling `init_device` ([#5157](https://github.com/PyTorchLightning/pytorch-lightning/pull/5157)) +- Fixed `num_workers` for Windows example ([#5375](https://github.com/PyTorchLightning/pytorch-lightning/pull/5375)) + + ## [1.1.3rc] - 2020-12-29 ### Added diff --git a/pl_examples/basic_examples/mnist_datamodule.py b/pl_examples/basic_examples/mnist_datamodule.py index 17bf84fb29..8a607c2d1a 100644 --- a/pl_examples/basic_examples/mnist_datamodule.py +++ b/pl_examples/basic_examples/mnist_datamodule.py @@ -13,6 +13,7 @@ # limitations under the License. import platform from typing import Optional +from warnings import warn from torch.utils.data import DataLoader, random_split @@ -55,8 +56,10 @@ class MNISTDataModule(LightningDataModule): normalize: If true applies image normalize """ super().__init__(*args, **kwargs) - if platform.system() == "Windows": - # see: https://stackoverflow.com/a/59680818/4521646 + if num_workers and platform.system() == "Windows": + # see: https://stackoverflow.com/a/59680818 + warn(f"You have requested num_workers={num_workers} on Windows," + " but currently recommended is 0, so we set it for you") num_workers = 0 self.dims = (1, 28, 28)