Refactor utilities/imports.py (#5874)
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
parent
373a31e63e
commit
ecd3678a1f
|
@ -12,50 +12,53 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""General utilities"""
|
"""General utilities"""
|
||||||
import importlib
|
|
||||||
import platform
|
import platform
|
||||||
from distutils.version import LooseVersion
|
from distutils.version import LooseVersion
|
||||||
|
from importlib.util import find_spec
|
||||||
|
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _module_available(module_path: str) -> bool:
|
def _module_available(module_path: str) -> bool:
|
||||||
"""Testing if given module is avalaible in your env
|
"""
|
||||||
|
Check if a path is available in your environment
|
||||||
|
|
||||||
>>> _module_available('os')
|
>>> _module_available('os')
|
||||||
True
|
True
|
||||||
>>> _module_available('bla.bla')
|
>>> _module_available('bla.bla')
|
||||||
False
|
False
|
||||||
"""
|
"""
|
||||||
# todo: find a better way than try / except
|
|
||||||
try:
|
try:
|
||||||
mods = module_path.split('.')
|
return find_spec(module_path) is not None
|
||||||
assert mods, 'nothing given to test'
|
|
||||||
# it has to be tested as per partets
|
|
||||||
for i in range(len(mods)):
|
|
||||||
module_path = '.'.join(mods[:i + 1])
|
|
||||||
if importlib.util.find_spec(module_path) is None:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
# Python 3.6
|
||||||
|
return False
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
# Python 3.7+
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _get_version(package: str) -> LooseVersion:
|
||||||
|
return LooseVersion(pkg_resources.get_distribution(package).version)
|
||||||
|
|
||||||
|
|
||||||
|
_IS_WINDOWS = platform.system() == "Windows"
|
||||||
|
_TORCH_GREATER_EQUAL_1_6 = _get_version("torch") >= LooseVersion("1.6.0")
|
||||||
|
|
||||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||||
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
|
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||||
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
|
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
|
||||||
|
_FAIRSCALE_PIPE_AVAILABLE = (
|
||||||
|
_FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6 and _get_version('fairscale') <= LooseVersion("0.1.3")
|
||||||
|
)
|
||||||
|
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
|
||||||
|
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||||
_HYDRA_AVAILABLE = _module_available("hydra")
|
_HYDRA_AVAILABLE = _module_available("hydra")
|
||||||
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
||||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
|
||||||
|
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
|
||||||
|
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
|
||||||
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
|
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
|
||||||
_XLA_AVAILABLE = _module_available("torch_xla")
|
|
||||||
_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
|
|
||||||
_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
|
|
||||||
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
|
|
||||||
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(
|
|
||||||
torch.__version__
|
|
||||||
) >= LooseVersion("1.6.0") and LooseVersion(pkg_resources.get_distribution('fairscale').version
|
|
||||||
) <= LooseVersion("0.1.3")
|
|
||||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
|
||||||
_TORCHVISION_AVAILABLE = _module_available('torchvision')
|
_TORCHVISION_AVAILABLE = _module_available('torchvision')
|
||||||
|
_XLA_AVAILABLE = _module_available("torch_xla")
|
||||||
|
|
Loading…
Reference in New Issue