Refactor utilities/imports.py (#5874)

Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
Carlos Mocholí 2021-02-11 00:32:47 +01:00 committed by GitHub
parent 373a31e63e
commit ecd3678a1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 26 additions and 23 deletions

View File

@ -12,50 +12,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities"""
import importlib
import platform
from distutils.version import LooseVersion
from importlib.util import find_spec
import pkg_resources
import torch
def _module_available(module_path: str) -> bool:
"""Testing if given module is avalaible in your env
"""
Check if a path is available in your environment
>>> _module_available('os')
True
>>> _module_available('bla.bla')
False
"""
# todo: find a better way than try / except
try:
mods = module_path.split('.')
assert mods, 'nothing given to test'
# it has to be tested as per partets
for i in range(len(mods)):
module_path = '.'.join(mods[:i + 1])
if importlib.util.find_spec(module_path) is None:
return False
return True
return find_spec(module_path) is not None
except AttributeError:
# Python 3.6
return False
except ModuleNotFoundError:
# Python 3.7+
return False
def _get_version(package: str) -> LooseVersion:
    """Return the installed version of *package* as a comparable ``LooseVersion``."""
    distribution = pkg_resources.get_distribution(package)
    return LooseVersion(distribution.version)
# Platform / version guards used by the availability flags below.
_IS_WINDOWS = platform.system() == "Windows"
_TORCH_GREATER_EQUAL_1_6 = _get_version("torch") >= LooseVersion("1.6.0")

# Optional-dependency availability flags, kept alphabetical for easy lookup.
# NOTE(review): the merged diff defined several of these twice (old + new
# lines); each flag is assigned exactly once here.
_APEX_AVAILABLE = _module_available("apex.amp")
_BOLTS_AVAILABLE = _module_available('pl_bolts')
# fairscale's data-parallel backend does not support Windows.
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
# Pipe support additionally requires torch >= 1.6 and fairscale <= 0.1.3.
_FAIRSCALE_PIPE_AVAILABLE = (
    _FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6 and _get_version('fairscale') <= LooseVersion("0.1.3")
)
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
_HYDRA_AVAILABLE = _module_available("hydra")
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
# native AMP needs both the module and the `autocast` context manager.
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")