diff --git a/pytorch_lightning/accelerators/accelerator_connector.py b/pytorch_lightning/accelerators/accelerator_connector.py index a41ced049c..176b67e5b6 100644 --- a/pytorch_lightning/accelerators/accelerator_connector.py +++ b/pytorch_lightning/accelerators/accelerator_connector.py @@ -43,10 +43,8 @@ class AcceleratorConnector: benchmark, replace_sampler_ddp, deterministic, - cluster_environment ): self.trainer.deterministic = deterministic - self.cluster_environment = cluster_environment torch.backends.cudnn.deterministic = self.trainer.deterministic if self.trainer.deterministic: @@ -131,9 +129,8 @@ class AcceleratorConnector: def _select_environment(self): env = None - # in priority: user environment, torchelastic (which is a generic environment), slurm - if self.cluster_environment is not None: - env = self.cluster_environment + if self.trainer.plugin_connector.cloud_environment: + return self.trainer.plugin_connector.cloud_environment elif self._is_using_torchelastic(): env = TorchElasticEnvironment() elif self.trainer.is_slurm_managing_tasks: diff --git a/pytorch_lightning/plugins/plugin_connector.py b/pytorch_lightning/plugins/plugin_connector.py new file mode 100644 index 0000000000..a016b4ec7a --- /dev/null +++ b/pytorch_lightning/plugins/plugin_connector.py @@ -0,0 +1,31 @@ +from pytorch_lightning.cluster_environments import ClusterEnvironment +from pytorch_lightning.utilities.exceptions import MisconfigurationException + + +class PluginConnector: + + def __init__(self, trainer): + self.trainer = trainer + self.plugins = [] + self.cloud_environment = None + + def on_trainer_init(self, plugins): + self.plugins = plugins + if self.plugins is None: + self.plugins = [] + + self.__attach_cluster() + + def __attach_cluster(self, limit=1): + num_clusters = 0 + for plugin in self.plugins: + if isinstance(plugin, ClusterEnvironment): + + # count the clusters + num_clusters += 1 + if num_clusters > limit: + m = f'you can only use one cluster environment in plugins. You passed in: {num_clusters}' + raise MisconfigurationException(m) + + # set the cluster + self.cloud_environment = plugin diff --git a/pytorch_lightning/trainer/__init__.py b/pytorch_lightning/trainer/__init__.py index aa58515247..ec081a4162 100644 --- a/pytorch_lightning/trainer/__init__.py +++ b/pytorch_lightning/trainer/__init__.py @@ -461,41 +461,6 @@ To disable automatic checkpointing, set this to `False`. See also :ref:`Saving and Loading Weights `. -cluster_environment -^^^^^^^^^^^^^^^^^^^ - -.. raw:: html - - - -| - -Environment to connect arbitrary cluster backends. Lightning automatically handles: - -- SLURM -- TorchElastic - -For any other non-supported cluster environment, define your own class and pass it in. - -.. code-block:: python - - from pytorch_lightning.cluster_environments import cluster_environment - - class MyCluster(ClusterEnvironment): - - def master_address(self): - return your_master_address - - def master_port(self): - return your_master_port - - def world_size(self): - return the_world_size - - trainer = Trainer(cluster_environment=cluster_environment()) - default_root_dir ^^^^^^^^^^^^^^^^ @@ -963,6 +928,43 @@ Example:: --env=XLA_USE_BF16=1 -- python your_trainer_file.py +plugins +^^^^^^^ + +.. raw:: html + + + +| + +Plugins allow you to connect arbitrary backends, precision libraries, SLURM, etc... For example: + +- DDP +- SLURM +- TorchElastic +- Apex + +To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own cluster. + +.. code-block:: python + + from pytorch_lightning.cluster_environments import cluster_environment + + class MyCluster(ClusterEnvironment): + + def master_address(self): + return your_master_address + + def master_port(self): + return your_master_port + + def world_size(self): + return the_world_size + + trainer = Trainer(cluster_environment=cluster_environment()) + prepare_data_per_node ^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index c2926ddd13..47e3286fe7 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -57,7 +57,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.utilities.cloud_io import load as pl_load from pytorch_lightning.utilities.model_utils import is_overridden from pytorch_lightning.trainer.properties import TrainerProperties -from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment +from pytorch_lightning.plugins.plugin_connector import PluginConnector # warnings to ignore in trainer warnings.filterwarnings( @@ -114,7 +114,7 @@ class Trainer( distributed_backend: Optional[str] = None, sync_batchnorm: bool = False, precision: int = 32, - weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT, + weights_summary: Optional[str] = 'top', weights_save_path: Optional[str] = None, num_sanity_val_steps: int = 2, truncated_bptt_steps: Optional[int] = None, @@ -128,7 +128,7 @@ class Trainer( terminate_on_nan: bool = False, auto_scale_batch_size: Union[str, bool] = False, prepare_data_per_node: bool = True, - cluster_environment: ClusterEnvironment = None, + plugins: list = None, amp_backend: str = 'native', amp_level: str = 'O2', ): @@ -167,8 +167,6 @@ class Trainer( check_val_every_n_epoch: Check val every n train epochs. - cluster_environment: Environment config to link up arbitrary clusters - default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed. Default: ``os.getcwd()``. Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' @@ -209,6 +207,8 @@ class Trainer( overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0 + plugins: Plugins allow modification of core behavior like ddp and amp. + precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs. max_epochs: Stop training once this number of epochs is reached. @@ -278,6 +278,7 @@ class Trainer( self.accelerator_backend = None self.evaluation_loop = EvaluationLoop(self) self.train_loop = TrainLoop(self) + self.plugin_connector = PluginConnector(self) # training state self.weights_summary = weights_summary @@ -326,7 +327,6 @@ class Trainer( benchmark, replace_sampler_ddp, deterministic, - cluster_environment ) # init train loop related flags @@ -355,6 +355,9 @@ class Trainer( # set precision self.precision_connector.on_trainer_init(precision, amp_level, amp_backend) + # last thing are the plugins which override whatever the trainer used by default + self.plugin_connector.on_trainer_init(plugins) + # Callback system self.on_init_end()