Address code review for deepspeed (#6042)

This commit is contained in:
Sean Naren 2021-02-17 22:53:20 +00:00 committed by GitHub
parent b7c2e0a80e
commit 8d7ac8f0f8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 8 additions and 8 deletions

View File

@ -158,13 +158,12 @@ class DeepSpeedPlugin(DDPPlugin):
rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
config = os.environ[self.DEEPSPEED_ENV_VAR]
if isinstance(config, str) or isinstance(config, Path):
if os.path.exists(config):
with open(config) as f:
config = json.load(f)
else:
if not os.path.isfile(config):
raise MisconfigurationException(
f"You passed in a path to a DeepSpeed config but the path does not exist: {config}"
)
with open(config) as f:
config = json.load(f)
return config
def pre_dispatch(self):
@ -198,7 +197,7 @@ class DeepSpeedPlugin(DDPPlugin):
optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
self.lightning_module
)
if (len(optimizers) != 1) or len(schedulers) > 1:
if len(optimizers) > 1 or len(schedulers) > 1:
raise MisconfigurationException(
"DeepSpeed currently only supports single optimizer, single optional scheduler."
)
@ -234,7 +233,7 @@ class DeepSpeedPlugin(DDPPlugin):
self.model_to_device()
self.pre_configure_ddp()
self._model = DistributedDataParallel(
self.model = DistributedDataParallel(
model,
device_ids=self.determine_ddp_device_ids(),
**self._ddp_kwargs,

View File

@ -34,7 +34,6 @@ from pytorch_lightning.plugins import (
DeepSpeedPrecisionPlugin,
HorovodPlugin,
NativeMixedPrecisionPlugin,
Plugin,
PrecisionPlugin,
ShardedNativeMixedPrecisionPlugin,
SingleDevicePlugin,
@ -147,7 +146,9 @@ class AcceleratorConnector(object):
self.replace_sampler_ddp = replace_sampler_ddp
def handle_given_plugins(self, plugins: Optional[Union[Plugin, Sequence]]):
def handle_given_plugins(
self, plugins: Optional[Union[ClusterEnvironment, TrainingTypePlugin, PrecisionPlugin, Sequence]]
):
plugins = plugins if plugins is not None else []
if isinstance(plugins, str):