Make parallel devices optional across all plugins (#6051)
* Make parallel devices optional across all plugins so that they can be instantiated without passing devices
* Add Any to types to capture vars passed in
parent bcc0004955
commit ffdcb62e8f
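With parallel_devices now optional, a training-type plugin can be constructed before the trainer has resolved any devices; the accelerator connector back-fills them afterwards (see the last hunk of this diff). A minimal sketch of the resulting usage, assuming a PyTorch Lightning build that includes this commit:

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# parallel_devices defaults to None, so no devices are needed here.
plugin = DDPPlugin()
assert plugin.parallel_devices is None

# The trainer resolves the devices and assigns them to the plugin via
# resolve_training_type_plugin (shown in the last hunk below).
trainer = Trainer(gpus=2, accelerator="ddp", plugins=[plugin])
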
@@ -15,7 +15,7 @@ import os
 import subprocess
 import sys
 from time import sleep
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import numpy as np
 import torch

@@ -58,11 +58,11 @@ class DDPPlugin(ParallelPlugin):
 
     def __init__(
         self,
-        parallel_devices,
-        num_nodes=1,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
-        sync_batchnorm=False,
-        **kwargs: Dict[str, Any],
+        sync_batchnorm: bool = False,
+        **kwargs: Union[Any, Dict[str, Any]],
     ) -> None:
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.interactive_ddp_procs = []

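Because the DDPPlugin signature above no longer requires parallel_devices as its first argument, the plugin can be created purely to forward keyword arguments. A hedged example, assuming (as in this plugin's implementation) that the extra kwargs are stored and later passed to torch.nn.parallel.DistributedDataParallel:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DDPPlugin

# No devices at construction time; find_unused_parameters is forwarded
# to DistributedDataParallel when the model is wrapped.
trainer = Trainer(
    gpus=2,
    accelerator="ddp",
    plugins=[DDPPlugin(find_unused_parameters=False)],
)
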
@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 import re
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 import torch.distributed as torch_distrib

@@ -46,11 +46,11 @@ class DDPSpawnPlugin(ParallelPlugin):
 
     def __init__(
         self,
-        parallel_devices,
-        num_nodes=1,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
         cluster_environment: ClusterEnvironment = None,
         sync_batchnorm: bool = False,
-        **kwargs: Dict[str, Any],
+        **kwargs: Union[Any, Dict[str, Any]],
     ):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=cluster_environment)
         self.num_nodes = num_nodes

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List
+from typing import List, Optional
 
 import torch
 from torch.nn import DataParallel

@@ -23,7 +23,7 @@ from pytorch_lightning.plugins.training_type.parallel import ParallelPlugin
 
 class DataParallelPlugin(ParallelPlugin):
 
-    def __init__(self, parallel_devices: List[torch.device]):
+    def __init__(self, parallel_devices: Optional[List[torch.device]]):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
 
     def setup(self, model):

@@ -28,7 +28,7 @@ if _HOROVOD_AVAILABLE:
 
 class HorovodPlugin(ParallelPlugin):
 
-    def __init__(self, parallel_devices: List[torch.device]):
+    def __init__(self, parallel_devices: Optional[List[torch.device]] = None):
         super().__init__(parallel_devices=parallel_devices, cluster_environment=None)
 
     @property

@@ -32,7 +32,7 @@ class ParallelPlugin(TrainingTypePlugin, ABC):
 
     def __init__(
         self,
-        parallel_devices: List[torch.device],
+        parallel_devices: Optional[List[torch.device]] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
     ):
         super().__init__()

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from contextlib import suppress
-from typing import Optional, Sequence
+from typing import List, Optional, Sequence
 
 import torch
 

@@ -42,7 +42,7 @@ class RPCPlugin(DDPPlugin):
     def __init__(
         self,
         rpc_timeout_sec: float = DEFAULT_RPC_TIMEOUT_SEC,
-        parallel_devices: Sequence[int] = (),
+        parallel_devices: Optional[List[torch.device]] = None,
         num_nodes: Optional[int] = None,
         cluster_environment: Optional[ClusterEnvironment] = None,
         sync_batchnorm: Optional[bool] = None,

@@ -8,7 +8,7 @@ from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin
 
 class SingleDevicePlugin(TrainingTypePlugin):
 
-    def __init__(self, device: torch.device) -> bool:
+    def __init__(self, device: torch.device):
         super().__init__()
         self.device: torch.device = device
 

@@ -1,7 +1,7 @@
 import io
 import os
 import re
-from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.multiprocessing as mp

@@ -26,7 +26,12 @@ else:
 
 class TPUSpawnPlugin(DDPSpawnPlugin):
 
-    def __init__(self, parallel_devices: Sequence[int], num_nodes: int = 1, **kwargs: Dict[str, Any]) -> None:
+    def __init__(
+        self,
+        parallel_devices: Optional[List[torch.device]] = None,
+        num_nodes: int = 1,
+        **kwargs: Dict[str, Any]
+    ) -> None:
         super().__init__(
             parallel_devices, num_nodes=num_nodes, cluster_environment=None, sync_batchnorm=False, **kwargs
         )

@@ -405,7 +405,7 @@ class AcceleratorConnector(object):
         return plugin
 
     def resolve_training_type_plugin(self, training_type: TrainingTypePlugin) -> TrainingTypePlugin:
-        # necessary for RPC, when user has to provide balance
+        # necessary for when the user has passed in a plugin
         if hasattr(training_type, 'parallel_devices') and not getattr(training_type, 'parallel_devices'):
            training_type.parallel_devices = self.parallel_devices
         if hasattr(training_type, 'num_processes'):

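The connector hunk above is what makes the None default workable: a user-supplied plugin that arrives without devices gets them filled in from the trainer's own resolution. A self-contained sketch of that back-fill step, using a stand-in class rather than the real AcceleratorConnector:

from typing import List, Optional

import torch

class StubPlugin:
    # mirrors the now-optional attribute on ParallelPlugin
    parallel_devices: Optional[List[torch.device]] = None

def resolve_parallel_devices(plugin, devices: List[torch.device]):
    # same guard as resolve_training_type_plugin above: only back-fill
    # when the plugin has the attribute and it is currently empty/None
    if hasattr(plugin, 'parallel_devices') and not getattr(plugin, 'parallel_devices'):
        plugin.parallel_devices = devices
    return plugin

plugin = resolve_parallel_devices(StubPlugin(), [torch.device('cpu')])
assert plugin.parallel_devices == [torch.device('cpu')]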