Refactor utilities/imports.py (#5874)
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz> Co-authored-by: Nicki Skafte <skaftenicki@gmail.com>
This commit is contained in:
parent
373a31e63e
commit
ecd3678a1f
|
@ -12,50 +12,53 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""General utilities"""
|
||||
import importlib
|
||||
import platform
|
||||
from distutils.version import LooseVersion
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pkg_resources
|
||||
import torch
|
||||
|
||||
|
||||
def _module_available(module_path: str) -> bool:
|
||||
"""Testing if given module is avalaible in your env
|
||||
"""
|
||||
Check if a path is available in your environment
|
||||
|
||||
>>> _module_available('os')
|
||||
True
|
||||
>>> _module_available('bla.bla')
|
||||
False
|
||||
"""
|
||||
# todo: find a better way than try / except
|
||||
try:
|
||||
mods = module_path.split('.')
|
||||
assert mods, 'nothing given to test'
|
||||
# it has to be tested as per partets
|
||||
for i in range(len(mods)):
|
||||
module_path = '.'.join(mods[:i + 1])
|
||||
if importlib.util.find_spec(module_path) is None:
|
||||
return False
|
||||
return True
|
||||
return find_spec(module_path) is not None
|
||||
except AttributeError:
|
||||
# Python 3.6
|
||||
return False
|
||||
except ModuleNotFoundError:
|
||||
# Python 3.7+
|
||||
return False
|
||||
|
||||
|
||||
def _get_version(package: str) -> LooseVersion:
|
||||
return LooseVersion(pkg_resources.get_distribution(package).version)
|
||||
|
||||
|
||||
_IS_WINDOWS = platform.system() == "Windows"
|
||||
_TORCH_GREATER_EQUAL_1_6 = _get_version("torch") >= LooseVersion("1.6.0")
|
||||
|
||||
_APEX_AVAILABLE = _module_available("apex.amp")
|
||||
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
|
||||
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
_FAIRSCALE_AVAILABLE = not _IS_WINDOWS and _module_available('fairscale.nn.data_parallel')
|
||||
_FAIRSCALE_PIPE_AVAILABLE = (
|
||||
_FAIRSCALE_AVAILABLE and _TORCH_GREATER_EQUAL_1_6 and _get_version('fairscale') <= LooseVersion("0.1.3")
|
||||
)
|
||||
_GROUP_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.group')
|
||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_HYDRA_AVAILABLE = _module_available("hydra")
|
||||
_HYDRA_EXPERIMENTAL_AVAILABLE = _module_available("hydra.experimental")
|
||||
_HOROVOD_AVAILABLE = _module_available("horovod.torch")
|
||||
_NATIVE_AMP_AVAILABLE = _module_available("torch.cuda.amp") and hasattr(torch.cuda.amp, "autocast")
|
||||
_OMEGACONF_AVAILABLE = _module_available("omegaconf")
|
||||
_RPC_AVAILABLE = not _IS_WINDOWS and _module_available('torch.distributed.rpc')
|
||||
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
|
||||
_XLA_AVAILABLE = _module_available("torch_xla")
|
||||
_FAIRSCALE_AVAILABLE = platform.system() != 'Windows' and _module_available('fairscale.nn.data_parallel')
|
||||
_RPC_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.rpc')
|
||||
_GROUP_AVAILABLE = platform.system() != 'Windows' and _module_available('torch.distributed.group')
|
||||
_FAIRSCALE_PIPE_AVAILABLE = _FAIRSCALE_AVAILABLE and LooseVersion(
|
||||
torch.__version__
|
||||
) >= LooseVersion("1.6.0") and LooseVersion(pkg_resources.get_distribution('fairscale').version
|
||||
) <= LooseVersion("0.1.3")
|
||||
_BOLTS_AVAILABLE = _module_available('pl_bolts')
|
||||
_TORCHVISION_AVAILABLE = _module_available('torchvision')
|
||||
_XLA_AVAILABLE = _module_available("torch_xla")
|
||||
|
|
Loading…
Reference in New Issue