Address code review for deepspeed (#6042)
This commit is contained in:
parent b7c2e0a80e
commit 8d7ac8f0f8

@@ -158,13 +158,12 @@ class DeepSpeedPlugin(DDPPlugin):
         rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
         config = os.environ[self.DEEPSPEED_ENV_VAR]
         if isinstance(config, str) or isinstance(config, Path):
-            if os.path.exists(config):
-                with open(config) as f:
-                    config = json.load(f)
-            else:
+            if not os.path.isfile(config):
                 raise MisconfigurationException(
                     f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
                 )
+            with open(config) as f:
+                config = json.load(f)
         return config

     def pre_dispatch(self):
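
For reference, the reworked branch validates the path up front and only then opens it, rather than burying the JSON load inside an os.path.exists check. A minimal standalone sketch of the same flow; load_deepspeed_config and the local exception class are illustrative stand-ins, not the plugin's actual API:

import json
import os
from pathlib import Path
from typing import Union


class MisconfigurationException(Exception):
    """Local stand-in for pytorch_lightning's MisconfigurationException."""


def load_deepspeed_config(config: Union[str, Path, dict]) -> dict:
    # Fail fast if the path does not point to a file, then load the JSON into a dict.
    if isinstance(config, (str, Path)):
        if not os.path.isfile(config):
            raise MisconfigurationException(
                f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
            )
        with open(config) as f:
            config = json.load(f)
    return config
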
@@ -198,7 +197,7 @@ class DeepSpeedPlugin(DDPPlugin):
         optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
             self.lightning_module
         )
-        if (len(optimizers) != 1) or len(schedulers) > 1:
+        if len(optimizers) > 1 or len(schedulers) > 1:
             raise MisconfigurationException(
                 "DeepSpeed currently only supports single optimizer, single optional scheduler."
             )
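
The relaxed guard now only rejects multiple optimizers or multiple schedulers, instead of requiring exactly one optimizer. As an illustration, an optimizer setup along these lines satisfies the single-optimizer, optional-scheduler constraint; the optimizer choice and hyperparameters below are illustrative, not taken from the commit:

import torch
from torch import nn


def build_optimizers(model: nn.Module):
    # One optimizer plus at most one scheduler: accepted by the relaxed check.
    # This mirrors what a LightningModule's configure_optimizers would return.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return [optimizer], [scheduler]


optimizers, schedulers = build_optimizers(nn.Linear(4, 2))
assert len(optimizers) == 1 and len(schedulers) <= 1
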
@@ -234,7 +233,7 @@ class DeepSpeedPlugin(DDPPlugin):
         self.model_to_device()

         self.pre_configure_ddp()
-        self._model = DistributedDataParallel(
+        self.model = DistributedDataParallel(
             model,
             device_ids=self.determine_ddp_device_ids(),
             **self._ddp_kwargs,
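
Only the attribute name changes here; the wrapping itself is still torch's DistributedDataParallel. A minimal sketch of that wrapping outside Lightning, assuming the process group has already been initialized; wrap_for_ddp and the local_rank handling are illustrative:

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel


def wrap_for_ddp(model: nn.Module, local_rank: int) -> DistributedDataParallel:
    # Assumes torch.distributed.init_process_group() has already been called
    # and that local_rank identifies this process's GPU.
    model = model.to(local_rank)
    return DistributedDataParallel(model, device_ids=[local_rank])
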
@@ -34,7 +34,6 @@ from pytorch_lightning.plugins import (
     DeepSpeedPrecisionPlugin,
     HorovodPlugin,
     NativeMixedPrecisionPlugin,
-    Plugin,
     PrecisionPlugin,
     ShardedNativeMixedPrecisionPlugin,
     SingleDevicePlugin,

@@ -147,7 +146,9 @@ class AcceleratorConnector(object):
         self.replace_sampler_ddp = replace_sampler_ddp

-    def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]):
+    def handle_given_plugins(
+        self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
+    ):
         plugins = plugins if plugins is not None else []

         if isinstance(plugins, str):
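
The tightened signature spells out the plugin kinds the connector actually accepts: training-type plugins, precision plugins, cluster environments, or a sequence of them. A hedged usage sketch of handing such plugins to the Trainer; the particular Trainer arguments are illustrative:

from pytorch_lightning import Trainer
from pytorch_lightning.plugins import DeepSpeedPlugin

# A single plugin or a list of plugins is forwarded to
# AcceleratorConnector.handle_given_plugins; strings are resolved
# through the isinstance(plugins, str) branch shown above.
trainer = Trainer(gpus=2, precision=16, plugins=[DeepSpeedPlugin()])
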