Do not use the base version by default in `_compare_version` (#10051)

This commit is contained in:
Carlos Mocholí 2021-10-25 13:11:32 +02:00 committed by GitHub
parent 225989363b
commit f95ba20012
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 1 deletions

View File

@ -44,7 +44,7 @@ def _module_available(module_path: str) -> bool:
return False
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = True) -> bool:
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
"""Compare package version with some requirements.
>>> _compare_version("torch", operator.ge, "0.1")

View File

@ -11,8 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import operator
from pytorch_lightning.utilities import _module_available
from pytorch_lightning.utilities.imports import _compare_version
def test_module_exists():
@ -22,3 +24,24 @@ def test_module_exists():
assert not _module_available("torch.nn.asdf")
assert not _module_available("asdf")
assert not _module_available("asdf.bla.asdf")
def test_compare_version(monkeypatch):
from pytorch_lightning.utilities.imports import torch
monkeypatch.setattr(torch, "__version__", "1.8.9")
assert not _compare_version("torch", operator.ge, "1.10.0")
assert _compare_version("torch", operator.lt, "1.10.0")
monkeypatch.setattr(torch, "__version__", "1.10.0.dev123")
assert _compare_version("torch", operator.ge, "1.10.0.dev123")
assert not _compare_version("torch", operator.ge, "1.10.0.dev124")
assert _compare_version("torch", operator.ge, "1.10.0.dev123", use_base_version=True)
assert _compare_version("torch", operator.ge, "1.10.0.dev124", use_base_version=True)
monkeypatch.setattr(torch, "__version__", "1.10.0a0+0aef44c") # dev version before rc
assert _compare_version("torch", operator.ge, "1.10.0.rc0", use_base_version=True)
assert not _compare_version("torch", operator.ge, "1.10.0.rc0")
assert _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
assert not _compare_version("torch", operator.ge, "1.10.0")