Remove unneeded check

This commit is contained in:
SeanNaren 2020-11-27 14:22:17 +00:00
parent bd4223e951
commit 5598dce1a9
1 changed files with 2 additions and 9 deletions

View File

@ -51,7 +51,7 @@ class PrecisionConnector:
# no AMP requested, so we can leave now
return
using_sharded_plugin = self._check_sharded_plugin(plugins)
using_sharded_plugin = self._check_using_sharded_plugin(plugins)
amp_type = amp_type.lower()
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
if amp_type == 'native':
@ -94,14 +94,7 @@ class PrecisionConnector:
return model
def _check_sharded_plugin(self, plugins):
if plugins and self._sharded_in_plugins(plugins):
if not FAIRSCALE_AVAILABLE:
raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.')
return True
return False
def _sharded_in_plugins(self, plugins):
def _check_using_sharded_plugin(self, plugins):
for plugin in plugins:
if isinstance(plugin, DDPShardedPlugin):
return True