diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py index 69cf3ce1d4..c7ad708956 100644 --- a/pytorch_lightning/utilities/imports.py +++ b/pytorch_lightning/utilities/imports.py @@ -44,7 +44,7 @@ def _module_available(module_path: str) -> bool: return False -def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = True) -> bool: +def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool: """Compare package version with some requirements. >>> _compare_version("torch", operator.ge, "0.1") diff --git a/tests/utilities/test_imports.py b/tests/utilities/test_imports.py index bf2c2c4f70..75bcb51ffb 100644 --- a/tests/utilities/test_imports.py +++ b/tests/utilities/test_imports.py @@ -11,8 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import operator from pytorch_lightning.utilities import _module_available +from pytorch_lightning.utilities.imports import _compare_version def test_module_exists(): @@ -22,3 +24,24 @@ def test_module_exists(): assert not _module_available("torch.nn.asdf") assert not _module_available("asdf") assert not _module_available("asdf.bla.asdf") + + +def test_compare_version(monkeypatch): + from pytorch_lightning.utilities.imports import torch + + monkeypatch.setattr(torch, "__version__", "1.8.9") + assert not _compare_version("torch", operator.ge, "1.10.0") + assert _compare_version("torch", operator.lt, "1.10.0") + + monkeypatch.setattr(torch, "__version__", "1.10.0.dev123") + assert _compare_version("torch", operator.ge, "1.10.0.dev123") + assert not _compare_version("torch", operator.ge, "1.10.0.dev124") + + assert _compare_version("torch", operator.ge, "1.10.0.dev123", use_base_version=True) + assert _compare_version("torch", operator.ge, "1.10.0.dev124", use_base_version=True) + + monkeypatch.setattr(torch, "__version__", "1.10.0a0+0aef44c") # dev version before rc + assert _compare_version("torch", operator.ge, "1.10.0.rc0", use_base_version=True) + assert not _compare_version("torch", operator.ge, "1.10.0.rc0") + assert _compare_version("torch", operator.ge, "1.10.0", use_base_version=True) + assert not _compare_version("torch", operator.ge, "1.10.0")