Make parallel devices optional across all plugins (#6051)

* Make parallel devices optional across all plugins so that they can be instantiated without specifying devices up front

* Add Any to the **kwargs type hints to capture the variables passed in
Sean Naren, 2021-02-18 12:09:53 +00:00, committed by GitHub
parent bcc0004955
commit ffdcb62e8f
9 changed files with 24 additions and 19 deletions
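
A minimal usage sketch of what this enables, assuming the 1.2-era plugin API; the find_unused_parameters kwarg and the gpus/accelerator flags are illustrative choices, not part of this diff:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# Before this commit the plugin constructors required parallel_devices up front;
# now it defaults to None and is filled in later by the accelerator connector.
ddp = DDPPlugin(find_unused_parameters=False)  # extra kwargs are forwarded to DistributedDataParallel

# Illustrative Trainer call (flags assume a 2-GPU machine):
trainer = Trainer(gpus=2, accelerator="ddp", plugins=[ddp])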

pytorch_lightning/plugins/training_type/ddp.py

@@ -15,7 +15,7 @@ import os
 import subprocess
 import sys
 from time import sleep
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import numpy as np
 import torch
@@ -58,11 +58,11 @@ class DDPPlugin(ParallelPlugin):
     def __init__(
         self,
-        parallel_devices,
-        num_nodes=1,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm=False,
-        **kwargs: Dict[str, Any],
+        sync_batchnorm: bool = False,
+        **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.interactive_ddp_procs = []

pytorch_lightning/plugins/training_type/ddp_spawn.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import re
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 import torch
 import torch.distributed as torch_distrib
@@ -46,11 +46,11 @@ class DDPSpawnPlugin(ParallelPlugin):
     def __init__(
         self,
-        parallel_devices,
-        num_nodes=1,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
         sync_batchnorm: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Union[Any, Dict[str, Any]],
     ):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.num_nodes = num_nodes

pytorch_lightning/plugins/training_type/dp.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import List, Optional

 import torch
 from torch.nn import DataParallel
@@ -23,7 +23,7 @@ from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin

 class DataParallelPlugin(ParallelPlugin):

-    def __init__(self, parallel_devices: List[torch.device]):
+    def __init__(self, parallel_devices: Optional[List[torch.device]]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

     def setup(self, model):

pytorch_lightning/plugins/training_type/horovod.py

@@ -28,7 +28,7 @@ if _HOROVOD_AVAILABLE:

 class HorovodPlugin(ParallelPlugin):

-    def __init__(self, parallel_devices: List[torch.device]):
+    def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)

     @property

pytorch_lightning/plugins/training_type/parallel.py

@@ -32,7 +32,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
     def __init__(
         self,
-        parallel_devices: List[torch.device],
+        parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
     ):
         super().__init__()

pytorch_lightning/plugins/training_type/rpc.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence

 import torch
@@ -42,7 +42,7 @@ class RPCPlugin(DDPPlugin):
     def __init__(
         self,
         rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC,
-        parallel_devices: Sequence[int] = (),
+        parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         sync_batchnorm: Optional[bool] = None,

pytorch_lightning/plugins/training_type/single_device.py

@@ -8,7 +8,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin

 class SingleDevicePlugin(TrainingTypePlugin):

-    def __init__(self, device: torch.device) -> bool:
+    def __init__(self, device: torch.device):
         super().__init__()
         self.device: torch.device = device

pytorch_lightning/plugins/training_type/tpu_spawn.py

@@ -1,7 +1,7 @@
 import io
 import os
 import re
-from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

 import torch
 import torch.multiprocessing as mp
@@ -26,7 +26,12 @@ else:

 class TPUSpawnPlugin(DDPSpawnPlugin):

-    def __init__(self, parallel_devices: Sequence[int], num_nodes: int = 1, **kwargs: Dict[str, Any]) -> None:
+    def __init__(
+        self,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
+        **kwargs: Dict[str, Any]
+    ) -> None:
         super().__init__(
             parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
         )

pytorch_lightning/trainer/connectors/accelerator_connector.py

@@ -405,7 +405,7 @@ class AcceleratorConnector(object):
         return plugin

     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
-        # necessary for RPC, when user has to provide balance
+        # necessary for when the user has passed in a plugin
         if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
             training_type.parallel_devices = self.parallel_devices
             if hasattr(training_type, 'num_processes'):
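
The last hunk is where a plugin built without devices gets them filled in. A self-contained toy version of that resolution step is sketched below; ToyParallelPlugin and resolve_plugin are hypothetical names, not Lightning internals.

from typing import List, Optional

import torch


class ToyParallelPlugin:
    def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
        # Devices are optional: the user may construct the plugin without them.
        self.parallel_devices = parallel_devices


def resolve_plugin(plugin: ToyParallelPlugin, trainer_devices: List[torch.device]) -> ToyParallelPlugin:
    # Mirrors the hunk above: only inject devices if the plugin has none.
    if hasattr(plugin, "parallel_devices") and not getattr(plugin, "parallel_devices"):
        plugin.parallel_devices = trainer_devices
    return plugin


plugin = ToyParallelPlugin()  # no devices supplied up front
plugin = resolve_plugin(plugin, [torch.device("cpu")])
print(plugin.parallel_devices)  # [device(type='cpu')]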