enables plugins (#4041)

* plugin hardware

* plugin hardware

* plugin hardware
William Falcon 2020-10-09 22:03:46 -04:00 committed by GitHub
parent 05e0b4e5a1
commit 0c42aa03fd
4 changed files with 79 additions and 46 deletions

View File

@@ -43,10 +43,8 @@ class AcceleratorConnector:
             benchmark,
             replace_sampler_ddp,
             deterministic,
-            cluster_environment
     ):
         self.trainer.deterministic = deterministic
-        self.cluster_environment = cluster_environment
 
         torch.backends.cudnn.deterministic = self.trainer.deterministic
         if self.trainer.deterministic:
@@ -131,9 +129,8 @@ class AcceleratorConnector:
     def _select_environment(self):
         env = None
 
-        # in priority: user environment, torchelastic (which is a generic environment), slurm
-        if self.cluster_environment is not None:
-            env = self.cluster_environment
+        if self.trainer.plugin_connector.cloud_environment:
+            return self.trainer.plugin_connector.cloud_environment
         elif self._is_using_torchelastic():
             env = TorchElasticEnvironment()
         elif self.trainer.is_slurm_managing_tasks:
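The deleted comment spells out the selection priority: a user-supplied environment first, then TorchElastic, then SLURM. A minimal standalone sketch of that rule, assuming the classes exported by pytorch_lightning.cluster_environments at this point in the codebase (resolve_environment is a hypothetical helper, not part of this diff):

    from pytorch_lightning.cluster_environments import ClusterEnvironment, TorchElasticEnvironment

    def resolve_environment(user_env: ClusterEnvironment = None,
                            using_torchelastic: bool = False):
        # mirrors the new _select_environment: a plugin-supplied environment
        # short-circuits all auto-detection (the SLURM branch continues below
        # the visible hunk)
        if user_env is not None:
            return user_env
        if using_torchelastic:
            return TorchElasticEnvironment()
        return None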

View File

@@ -0,0 +1,31 @@
+from pytorch_lightning.cluster_environments import ClusterEnvironment
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
+
+
+class PluginConnector:
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+        self.plugins = []
+        self.cloud_environment = None
+
+    def on_trainer_init(self, plugins):
+        self.plugins = plugins
+        if self.plugins is None:
+            self.plugins = []
+
+        self.__attach_cluster()
+
+    def __attach_cluster(self, limit=1):
+        num_clusters = 0
+        for plugin in self.plugins:
+            if isinstance(plugin, ClusterEnvironment):
+
+                # count the clusters
+                num_clusters += 1
+                if num_clusters > limit:
+                    m = f'you can only use one cluster environment in plugins. You passed in: {num_clusters}'
+                    raise MisconfigurationException(m)
+
+                # set the cluster
+                self.cloud_environment = plugin
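A quick usage sketch of the new connector (not part of the diff; MyCluster is a hypothetical subclass along the lines of the docs example below), showing both the happy path and the one-cluster limit:

    from pytorch_lightning import Trainer
    from pytorch_lightning.cluster_environments import ClusterEnvironment
    from pytorch_lightning.utilities.exceptions import MisconfigurationException

    class MyCluster(ClusterEnvironment):
        # hypothetical environment, for illustration only
        def master_address(self):
            return '127.0.0.1'

        def master_port(self):
            return 12910

        def world_size(self):
            return 1

    # a single ClusterEnvironment in `plugins` is attached as cloud_environment
    trainer = Trainer(plugins=[MyCluster()])
    assert isinstance(trainer.plugin_connector.cloud_environment, MyCluster)

    # a second one exceeds limit=1 in __attach_cluster
    try:
        Trainer(plugins=[MyCluster(), MyCluster()])
    except MisconfigurationException as err:
        print(err)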

View File

@@ -461,41 +461,6 @@ To disable automatic checkpointing, set this to `False`.
 See also :ref:`Saving and Loading Weights <weights_loading>`.
 
-cluster_environment
-^^^^^^^^^^^^^^^^^^^
-
-.. raw:: html
-
-    <video width="50%" max-width="400px" controls
-    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/cluster_environment.jpg"
-    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/cluster_environment.mp4"></video>
-
-|
-
-Environment to connect arbitrary cluster backends. Lightning automatically handles:
-
-- SLURM
-- TorchElastic
-
-For any other non-supported cluster environment, define your own class and pass it in.
-
-.. code-block:: python
-
-    from pytorch_lightning.cluster_environments import ClusterEnvironment
-
-    class MyCluster(ClusterEnvironment):
-
-        def master_address(self):
-            return your_master_address
-
-        def master_port(self):
-            return your_master_port
-
-        def world_size(self):
-            return the_world_size
-
-    trainer = Trainer(cluster_environment=MyCluster())
-
 default_root_dir
 ^^^^^^^^^^^^^^^^
@@ -963,6 +928,43 @@ Example::
     --env=XLA_USE_BF16=1
     -- python your_trainer_file.py
 
+plugins
+^^^^^^^
+
+.. raw:: html
+
+    <video width="50%" max-width="400px" controls
+    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/cluster_environment.jpg"
+    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/cluster_environment.mp4"></video>
+
+|
+
+Plugins allow you to connect arbitrary backends, precision libraries, cluster environments, etc. For example:
+
+- DDP
+- SLURM
+- TorchElastic
+- Apex
+
+To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own cluster.
+
+.. code-block:: python
+
+    from pytorch_lightning.cluster_environments import ClusterEnvironment
+
+    class MyCluster(ClusterEnvironment):
+
+        def master_address(self):
+            return your_master_address
+
+        def master_port(self):
+            return your_master_port
+
+        def world_size(self):
+            return the_world_size
+
+    trainer = Trainer(plugins=[MyCluster()])
+
 prepare_data_per_node
 ^^^^^^^^^^^^^^^^^^^^^
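One plausible way to fill in the placeholder return values above is to read your scheduler's environment variables; a hedged sketch (the MY_* variable names are assumptions about your cluster, not Lightning APIs):

    import os

    from pytorch_lightning.cluster_environments import ClusterEnvironment

    class EnvVarCluster(ClusterEnvironment):
        # hypothetical example: pull connection info from the scheduler's env
        def master_address(self):
            return os.environ['MY_MASTER_ADDR']

        def master_port(self):
            return int(os.environ['MY_MASTER_PORT'])

        def world_size(self):
            return int(os.environ['MY_WORLD_SIZE'])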

View File

@@ -57,7 +57,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.model_utils import is_overridden
 from pytorch_lightning.trainer.properties import TrainerProperties
-from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
+from pytorch_lightning.plugins.plugin_connector import PluginConnector
 
 # warnings to ignore in trainer
 warnings.filterwarnings(
@@ -114,7 +114,7 @@ class Trainer(
         distributed_backend: Optional[str] = None,
         sync_batchnorm: bool = False,
         precision: int = 32,
-        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
+        weights_summary: Optional[str] = 'top',
         weights_save_path: Optional[str] = None,
         num_sanity_val_steps: int = 2,
         truncated_bptt_steps: Optional[int] = None,
@@ -128,7 +128,7 @@ class Trainer(
         terminate_on_nan: bool = False,
         auto_scale_batch_size: Union[str, bool] = False,
         prepare_data_per_node: bool = True,
-        cluster_environment: ClusterEnvironment = None,
+        plugins: list = None,
         amp_backend: str = 'native',
         amp_level: str = 'O2',
     ):
@@ -167,8 +167,6 @@ class Trainer(
             check_val_every_n_epoch: Check val every n train epochs.
 
-            cluster_environment: Environment config to link up arbitrary clusters
-
             default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                 Default: ``os.getcwd()``.
                 Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -209,6 +207,8 @@ class Trainer(
             overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0
 
+            plugins: Plugins allow modification of core behavior like ddp and amp.
+
             precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.
 
             max_epochs: Stop training once this number of epochs is reached.
@@ -278,6 +278,7 @@ class Trainer(
         self.accelerator_backend = None
         self.evaluation_loop = EvaluationLoop(self)
         self.train_loop = TrainLoop(self)
+        self.plugin_connector = PluginConnector(self)
 
         # training state
         self.weights_summary = weights_summary
@@ -326,7 +327,6 @@ class Trainer(
             benchmark,
             replace_sampler_ddp,
             deterministic,
-            cluster_environment
         )
 
         # init train loop related flags
@@ -355,6 +355,9 @@ class Trainer(
         # set precision
         self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)
 
+        # last thing are the plugins which override whatever the trainer used by default
+        self.plugin_connector.on_trainer_init(plugins)
+
         # Callback system
         self.on_init_end()
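The comment added above is the key guarantee of this PR: the other connectors configure defaults first, and the plugin connector runs last so user plugins win. Schematically (a sketch of the ordering, not the literal code):

    # order inside Trainer.__init__ after this change:
    #   1. accelerator_connector.on_trainer_init(...)  # flags set defaults; no cluster_environment arg anymore
    #   2. precision_connector.on_trainer_init(precision, amp_level, amp_backend)
    #   3. plugin_connector.on_trainer_init(plugins)   # last: plugins override the defaults
    #   4. on_init_end()                               # callbacks fire
    #
    # so a ClusterEnvironment passed via `plugins` replaces whatever environment
    # SLURM/TorchElastic auto-detection would otherwise pick in _select_environment.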