Remove unneeded check
This commit is contained in:
parent
bd4223e951
commit
5598dce1a9
|
@ -51,7 +51,7 @@ class PrecisionConnector:
|
|||
# no AMP requested, so we can leave now
|
||||
return
|
||||
|
||||
using_sharded_plugin = self._check_sharded_plugin(plugins)
|
||||
using_sharded_plugin = self._check_using_sharded_plugin(plugins)
|
||||
amp_type = amp_type.lower()
|
||||
assert amp_type in ('native', 'apex'), f'Unsupported amp type {amp_type}'
|
||||
if amp_type == 'native':
|
||||
|
@ -94,14 +94,7 @@ class PrecisionConnector:
|
|||
|
||||
return model
|
||||
|
||||
def _check_sharded_plugin(self, plugins):
|
||||
if plugins and self._sharded_in_plugins(plugins):
|
||||
if not FAIRSCALE_AVAILABLE:
|
||||
raise MisconfigurationException('Sharded DDP Plugin requires Fairscale to be installed.')
|
||||
return True
|
||||
return False
|
||||
|
||||
def _sharded_in_plugins(self, plugins):
|
||||
def _check_using_sharded_plugin(self, plugins):
|
||||
for plugin in plugins:
|
||||
if isinstance(plugin, DDPShardedPlugin):
|
||||
return True
|
||||
|
|
Loading…
Reference in New Issue