enables plugins (#4041)
* plugin hardware
parent 05e0b4e5a1
commit 0c42aa03fd
@@ -43,10 +43,8 @@ class AcceleratorConnector:
        benchmark,
        replace_sampler_ddp,
        deterministic,
        cluster_environment
    ):
        self.trainer.deterministic = deterministic
        self.cluster_environment = cluster_environment

        torch.backends.cudnn.deterministic = self.trainer.deterministic
        if self.trainer.deterministic:
@@ -131,9 +129,8 @@ class AcceleratorConnector:
    def _select_environment(self):
        env = None

        # in priority: user environment, torchelastic (which is a generic environment), slurm
        if self.cluster_environment is not None:
            env = self.cluster_environment
        if self.trainer.plugin_connector.cloud_environment:
            return self.trainer.plugin_connector.cloud_environment
        elif self._is_using_torchelastic():
            env = TorchElasticEnvironment()
        elif self.trainer.is_slurm_managing_tasks:
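For orientation, a minimal plain-Python sketch of the selection order implemented above (not part of the diff; the arguments and factory callables are hypothetical stand-ins for the trainer state and environment constructors used in the method):

    def select_environment(user_env, using_torchelastic, slurm_managing_tasks,
                           torchelastic_env_factory, slurm_env_factory):
        # 1. an environment supplied by the user (now via ``plugins``) always wins
        if user_env is not None:
            return user_env
        # 2. otherwise prefer a detected TorchElastic launch
        if using_torchelastic:
            return torchelastic_env_factory()
        # 3. finally fall back to SLURM when SLURM is managing the tasks
        if slurm_managing_tasks:
            return slurm_env_factory()
        return None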
@@ -0,0 +1,31 @@
from pytorch_lightning.cluster_environments import ClusterEnvironment
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class PluginConnector:

    def __init__(self, trainer):
        self.trainer = trainer
        self.plugins = []
        self.cloud_environment = None

    def on_trainer_init(self, plugins):
        self.plugins = plugins
        if self.plugins is None:
            self.plugins = []

        self.__attach_cluster()

    def __attach_cluster(self, limit=1):
        num_clusters = 0
        for plugin in self.plugins:
            if isinstance(plugin, ClusterEnvironment):

                # count the clusters
                num_clusters += 1
                if num_clusters > limit:
                    m = f'you can only use one cluster environment in plugins. You passed in: {num_clusters}'
                    raise MisconfigurationException(m)

                # set the cluster
                self.cloud_environment = plugin
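A hedged usage sketch of the new connector in isolation (`_TrainerStub` and `MyCluster` are illustrative stand-ins, not part of this commit; the connector only stores the trainer object it is given):

    from pytorch_lightning.cluster_environments import ClusterEnvironment
    from pytorch_lightning.plugins.plugin_connector import PluginConnector
    from pytorch_lightning.utilities.exceptions import MisconfigurationException


    class _TrainerStub:
        pass


    class MyCluster(ClusterEnvironment):

        def master_address(self):
            return '127.0.0.1'

        def master_port(self):
            return 12910

        def world_size(self):
            return 1


    connector = PluginConnector(_TrainerStub())
    connector.on_trainer_init([MyCluster()])
    # the single ClusterEnvironment plugin becomes the cloud environment
    assert isinstance(connector.cloud_environment, MyCluster)

    # passing more than one cluster environment is rejected
    try:
        PluginConnector(_TrainerStub()).on_trainer_init([MyCluster(), MyCluster()])
    except MisconfigurationException as err:
        print(err)  # you can only use one cluster environment in plugins. You passed in: 2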
@@ -461,41 +461,6 @@ To disable automatic checkpointing, set this to `False`.

See also :ref:`Saving and Loading Weights <weights_loading>`.

cluster_environment
^^^^^^^^^^^^^^^^^^^

.. raw:: html

    <video width="50%" max-width="400px" controls
    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/cluster_environment.jpg"
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/cluster_environment.mp4"></video>

|

Environment to connect arbitrary cluster backends. Lightning automatically handles:

- SLURM
- TorchElastic

For any other unsupported cluster environment, define your own class and pass it in.

.. code-block:: python

    from pytorch_lightning.cluster_environments import ClusterEnvironment

    class MyCluster(ClusterEnvironment):

        def master_address(self):
            return your_master_address

        def master_port(self):
            return your_master_port

        def world_size(self):
            return the_world_size

    trainer = Trainer(cluster_environment=MyCluster())

default_root_dir
^^^^^^^^^^^^^^^^
@@ -963,6 +928,43 @@ Example::

    --env=XLA_USE_BF16=1
    -- python your_trainer_file.py

plugins
^^^^^^^

.. raw:: html

    <video width="50%" max-width="400px" controls
    poster="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/cluster_environment.jpg"
    src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/cluster_environment.mp4"></video>

|

Plugins allow you to connect arbitrary backends, precision libraries, clusters, etc. For example:

- DDP
- SLURM
- TorchElastic
- Apex

To define your own behavior, subclass the relevant class and pass it in. Here's an example linking up your own cluster.

.. code-block:: python

    from pytorch_lightning.cluster_environments import ClusterEnvironment

    class MyCluster(ClusterEnvironment):

        def master_address(self):
            return your_master_address

        def master_port(self):
            return your_master_port

        def world_size(self):
            return the_world_size

    trainer = Trainer(plugins=[MyCluster()])

prepare_data_per_node
^^^^^^^^^^^^^^^^^^^^^
@@ -57,7 +57,7 @@ from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.cluster_environments.cluster_environment import ClusterEnvironment
from pytorch_lightning.plugins.plugin_connector import PluginConnector

# warnings to ignore in trainer
warnings.filterwarnings(
@@ -114,7 +114,7 @@ class Trainer(
        distributed_backend: Optional[str] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = ModelSummary.MODE_DEFAULT,
        weights_summary: Optional[str] = 'top',
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
@@ -128,7 +128,7 @@ class Trainer(
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        cluster_environment: ClusterEnvironment = None,
        plugins: list = None,
        amp_backend: str = 'native',
        amp_level: str = 'O2',
    ):
@@ -167,8 +167,6 @@ class Trainer(

        check_val_every_n_epoch: Check val every n train epochs.

        cluster_environment: Environment config to link up arbitrary clusters

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Default: ``os.getcwd()``.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
@@ -209,6 +207,8 @@ class Trainer(

        overfit_batches: Overfit a percent of training data (float) or a set number of batches (int). Default: 0.0

        plugins: Plugins allow modification of core behavior like ddp and amp.

        precision: Full precision (32), half precision (16). Can be used on CPU, GPU or TPUs.

        max_epochs: Stop training once this number of epochs is reached.
@@ -278,6 +278,7 @@ class Trainer(
        self.accelerator_backend = None
        self.evaluation_loop = EvaluationLoop(self)
        self.train_loop = TrainLoop(self)
        self.plugin_connector = PluginConnector(self)

        # training state
        self.weights_summary = weights_summary
@@ -326,7 +327,6 @@ class Trainer(
            benchmark,
            replace_sampler_ddp,
            deterministic,
            cluster_environment
        )

        # init train loop related flags
@@ -355,6 +355,9 @@ class Trainer(
        # set precision
        self.precision_connector.on_trainer_init(precision, amp_level, amp_backend)

        # last thing are the plugins which override whatever the trainer used by default
        self.plugin_connector.on_trainer_init(plugins)

        # Callback system
        self.on_init_end()
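Putting it together, a hedged end-to-end sketch of the new `plugins` flag (the `MyCluster` class mirrors the documentation example above; the address and port are placeholder values, and whether this exact call runs cleanly depends on the surrounding environment at this commit):

    from pytorch_lightning import Trainer
    from pytorch_lightning.cluster_environments import ClusterEnvironment


    class MyCluster(ClusterEnvironment):

        def master_address(self):
            return '127.0.0.1'

        def master_port(self):
            return 12910

        def world_size(self):
            return 1


    # plugins are attached last in Trainer.__init__, so this environment overrides
    # the one the trainer would otherwise auto-select (SLURM / TorchElastic detection)
    trainer = Trainer(plugins=[MyCluster()])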