diff --git a/docs/Trainer/Distributed training.md b/docs/Trainer/Distributed training.md index f40f5cbf61..807bfc2c2b 100644 --- a/docs/Trainer/Distributed training.md +++ b/docs/Trainer/Distributed training.md @@ -58,14 +58,21 @@ Below are the possible configurations we support. You also have the option of specifying which GPUs to use by passing a list: ```python -# DEFAULT (int) +# DEFAULT (int) specifies how many GPUs to use. Trainer(gpus=k) +# Above is equivalent to +Trainer(gpus=list(range(k))) + # You specify which GPUs (don't use if running on cluster) Trainer(gpus=[0, 1]) # can also be a string Trainer(gpus='0, 1') + +# can also be -1 or '-1', this uses all available GPUs +# this is equivalent to list(range(torch.cuda.available_devices())) +Trainer(gpus=-1) ``` --- diff --git a/pytorch_lightning/trainer/dp_mixin.py b/pytorch_lightning/trainer/dp_mixin.py index 96aa5eb31f..0bde8a7b13 100644 --- a/pytorch_lightning/trainer/dp_mixin.py +++ b/pytorch_lightning/trainer/dp_mixin.py @@ -104,3 +104,116 @@ class TrainerDPMixin(object): model = LightningDataParallel(model, device_ids=device_ids) self.run_pretrain_routine(model) + + +def normalize_parse_gpu_string_input(s): + if type(s) is str: + if s == '-1': + return -1 + else: + return [int(x.strip()) for x in s.split(',')] + else: + return s + + +def get_all_available_gpus(): + """ + :return: a list of all available gpus + """ + return list(range(torch.cuda.device_count())) + + +def check_gpus_data_type(gpus): + """ + :param gpus: gpus parameter as passed to the Trainer + Function checks that it is one of: None, Int, String or List + Throws otherwise + :return: return unmodified gpus variable + """ + + if (gpus is not None and + type(gpus) is not int and + type(gpus) is not str and + type(gpus) is not list): # noqa E129 + raise MisconfigurationException("GPUs must be int, string or list of ints or None.") + + +def normalize_parse_gpu_input_to_list(gpus): + assert gpus is not None + if isinstance(gpus, list): + return gpus + else: # must be an int + if not gpus: # gpus==0 + return None + elif gpus == -1: + return get_all_available_gpus() + else: + return list(range(gpus)) + + +def sanitize_gpu_ids(gpus): + """ + :param gpus: list of ints corresponding to GPU indices + Checks that each of the GPUs in the list is actually available. + Throws if any of the GPUs is not available. + :return: unmodified gpus variable + """ + all_available_gpus = get_all_available_gpus() + for gpu in gpus: + if gpu not in all_available_gpus: + message = f""" + Non-available gpu index {gpu} specified: + Available gpu indices are: {all_available_gpus} + """ + raise MisconfigurationException(message) + return gpus + + +def parse_gpu_ids(gpus): + """ + :param gpus: Int, string or list + An int -1 or string '-1' indicate that all available GPUs should be used. + A list of ints or a string containing list of comma separated integers + indicates specific GPUs to use + An int 0 means that no GPUs should be used + Any int N > 0 indicates that GPUs [0..N) should be used. + :return: List of gpus to be used + + If no GPUs are available but the value of gpus variable indicates request for GPUs + then a misconfiguration exception is raised. + """ + + # Check that gpus param is None, Int, String or List + check_gpus_data_type(gpus) + + # Handle the case when no gpus are requested + if gpus is None or type(gpus) is int and gpus == 0: + return None + + # We know user requested GPUs therefore if some of the + # requested GPUs are not available an exception is thrown. + + gpus = normalize_parse_gpu_string_input(gpus) + gpus = normalize_parse_gpu_input_to_list(gpus) + gpus = sanitize_gpu_ids(gpus) + + if not gpus: + raise MisconfigurationException("GPUs requested but non are available.") + return gpus + + +def determine_root_gpu_device(gpus): + """ + :param gpus: non empty list of ints representing which gpus to use + :return: designated root GPU device + """ + if gpus is None: + return None + + assert isinstance(gpus, list), "gpus should be a list" + assert len(gpus), "gpus should be a non empty list" + + # set root gpu + root_gpu = gpus[0] + + return root_gpu diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index a9ae29b7ca..067ffe9db3 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -16,6 +16,10 @@ from pytorch_lightning.trainer.callback_config_mixin import TrainerCallbackConfi from pytorch_lightning.trainer.data_loading_mixin import TrainerDataLoadingMixin from pytorch_lightning.trainer.ddp_mixin import TrainerDDPMixin from pytorch_lightning.trainer.dp_mixin import TrainerDPMixin +from pytorch_lightning.trainer.dp_mixin import ( + parse_gpu_ids, + determine_root_gpu_device +) from pytorch_lightning.trainer.evaluation_loop_mixin import TrainerEvaluationLoopMixin from pytorch_lightning.trainer.logging_mixin import TrainerLoggingMixin from pytorch_lightning.trainer.model_hooks_mixin import TrainerModelHooksMixin @@ -87,7 +91,8 @@ class Trainer(TrainerIOMixin, :param gradient_clip: int. 0 means don't clip. Deprecated. :param process_position: shown in the tqdm bar :param nb_gpu_nodes: number of GPU nodes - :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] or '0,1' + :param gpus: int. (ie: 2 gpus) OR list to specify which GPUs [0, 1] OR '0,1' + OR '-1' / -1 to use all available gpus :param log_gpu_memory: str. None, 'min_max', 'all' :param show_progress_bar: Bool. If true shows tqdm bar :param overfit_pct: float. uses this much of all datasets @@ -183,8 +188,8 @@ class Trainer(TrainerIOMixin, self.configure_accumulated_gradients(accumulate_grad_batches) # allow int, string and gpu list - self.data_parallel_device_ids = self.__parse_gpu_ids(gpus) - self.root_gpu = self.__set_root_gpu(self.data_parallel_device_ids) + self.data_parallel_device_ids = parse_gpu_ids(gpus) + self.root_gpu = determine_root_gpu_device(self.data_parallel_device_ids) # distributed backend choice self.use_ddp = False @@ -272,14 +277,8 @@ class Trainer(TrainerIOMixin, gpus = self.data_parallel_device_ids if gpus is None: return 0 - - if type(gpus) is list: + else: return len(gpus) - if type(gpus) is int: - return gpus - - m = 'gpus must be int, none or list of ints' - raise MisconfigurationException(m) @property def data_parallel(self): diff --git a/tests/test_models.py b/tests/test_models.py index 612668131e..577683d935 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,6 +25,10 @@ from pytorch_lightning.testing import ( LightningTestMultipleDataloadersMixin, ) from pytorch_lightning.trainer import trainer_io +from pytorch_lightning.trainer.dp_mixin import ( + parse_gpu_ids, + determine_root_gpu_device, +) from pytorch_lightning.trainer.logging_mixin import TrainerLoggingMixin from pytorch_lightning.utilities.debugging import MisconfigurationException @@ -35,11 +39,14 @@ ROOT_SEED = 1234 torch.manual_seed(ROOT_SEED) np.random.seed(ROOT_SEED) RANDOM_SEEDS = list(np.random.randint(0, 10000, 1000)) +PRETEND_N_OF_GPUS = 16 # ------------------------------------------------------------------------ # TESTS # ------------------------------------------------------------------------ + + def test_multi_gpu_model_ddp2(): """ Make sure DDP2 works @@ -1470,6 +1477,151 @@ def test_multiple_test_dataloader(): trainer.test() +@pytest.fixture +def mocked_device_count(monkeypatch): + def device_count(): + return PRETEND_N_OF_GPUS + + monkeypatch.setattr(torch.cuda, 'device_count', device_count) + + +@pytest.fixture +def mocked_device_count_0(monkeypatch): + def device_count(): + return 0 + + monkeypatch.setattr(torch.cuda, 'device_count', device_count) + + +test_num_gpus_data = [ + pytest.param(None, 0, None, id="None - expect 0 gpu to use."), + pytest.param(0, 0, None, id="Oth gpu, expect 1 gpu to use."), + pytest.param(1, 1, None, id="1st gpu, expect 1 gpu to use."), + pytest.param(-1, PRETEND_N_OF_GPUS, "ddp", id="-1 - use all gpus"), + pytest.param('-1', PRETEND_N_OF_GPUS, "ddp", id="'-1' - use all gpus"), + pytest.param(3, 3, "ddp", id="3rd gpu - 1 gpu to use (backend:ddp)") +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data) +def test_trainer_gpu_parse(mocked_device_count, gpus, expected_num_gpus, distributed_backend): + assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus + + +test_num_gpus_data_0 = [ + pytest.param(None, 0, None, id="None - expect 0 gpu to use."), + pytest.param(None, 0, "ddp", id="None - expect 0 gpu to use."), +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(["gpus", "expected_num_gpus", "distributed_backend"], test_num_gpus_data_0) +def test_trainer_num_gpu_0(mocked_device_count_0, gpus, expected_num_gpus, distributed_backend): + assert Trainer(gpus=gpus, distributed_backend=distributed_backend).num_gpus == expected_num_gpus + + +test_root_gpu_data = [ + pytest.param(None, None, "ddp", id="None is None"), + pytest.param(0, None, "ddp", id="O gpus, expect gpu root device to be None."), + pytest.param(1, 0, "ddp", id="1 gpu, expect gpu root device to be 0."), + pytest.param(-1, 0, "ddp", id="-1 - use all gpus, expect gpu root device to be 0."), + pytest.param('-1', 0, "ddp", id="'-1' - use all gpus, expect gpu root device to be 0."), + pytest.param(3, 0, "ddp", id="3 gpus, expect gpu root device to be 0.(backend:ddp)")] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data) +def test_root_gpu_property(mocked_device_count, gpus, expected_root_gpu, distributed_backend): + assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu + + +test_root_gpu_data_for_0_devices_passing = [ + pytest.param(None, None, None, id="None is None"), + pytest.param(None, None, "ddp", id="None is None"), + pytest.param(0, None, "ddp", id="None is None"), +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize([ + 'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_passing) +def test_root_gpu_property_0_passing( + mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): + assert Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu == expected_root_gpu + + +# Asking for a gpu when non are available will result in a MisconfigurationException +test_root_gpu_data_for_0_devices_raising = [ + pytest.param(1, None, "ddp"), + pytest.param(3, None, "ddp"), + pytest.param(3, None, "ddp"), + pytest.param([1, 2], None, "ddp"), + pytest.param([0, 1], None, "ddp"), + pytest.param(-1, None, "ddp"), + pytest.param('-1', None, "ddp") +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize([ + 'gpus', 'expected_root_gpu', "distributed_backend"], test_root_gpu_data_for_0_devices_raising) +def test_root_gpu_property_0_raising( + mocked_device_count_0, gpus, expected_root_gpu, distributed_backend): + with pytest.raises(MisconfigurationException): + Trainer(gpus=gpus, distributed_backend=distributed_backend).root_gpu + + +test_determine_root_gpu_device_data = [ + pytest.param(None, None, id="No gpus, expect gpu root device to be None"), + pytest.param([0], 0, id="Oth gpu, expect gpu root device to be 0."), + pytest.param([1], 1, id="1st gpu, expect gpu root device to be 1."), + pytest.param([3], 3, id="3rd gpu, expect gpu root device to be 3."), + pytest.param([1, 2], 1, id="[1, 2] gpus, expect gpu root device to be 1."), +] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_root_gpu'], test_determine_root_gpu_device_data) +def test_determine_root_gpu_device(gpus, expected_root_gpu): + assert determine_root_gpu_device(gpus) == expected_root_gpu + + +test_parse_gpu_ids_data = [ + pytest.param(None, None), + pytest.param(0, None), + pytest.param(1, [0]), + pytest.param(-1, list(range(PRETEND_N_OF_GPUS)), id="-1 - use all gpus"), + pytest.param('-1', list(range(PRETEND_N_OF_GPUS)), id="'-1' - use all gpus"), + pytest.param(3, [0, 1, 2])] + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize(['gpus', 'expected_gpu_ids'], test_parse_gpu_ids_data) +def test_parse_gpu_ids(mocked_device_count, gpus, expected_gpu_ids): + assert parse_gpu_ids(gpus) == expected_gpu_ids + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize("gpus", [[1, 2, 19], -1, '-1']) +def test_parse_gpu_fail_on_non_existant_id(mocked_device_count_0, gpus): + with pytest.raises(MisconfigurationException): + parse_gpu_ids(gpus) + + +@pytest.mark.gpus_param_tests +def test_parse_gpu_fail_on_non_existant_id_2(mocked_device_count): + with pytest.raises(MisconfigurationException): + parse_gpu_ids([1, 2, 19]) + + +@pytest.mark.gpus_param_tests +@pytest.mark.parametrize("gpus", [-1, '-1']) +def test_parse_gpu_returns_None_when_no_devices_are_available(mocked_device_count_0, gpus): + with pytest.raises(MisconfigurationException): + parse_gpu_ids(gpus) + + # ------------------------------------------------------------------------ # UTILS # ------------------------------------------------------------------------