Do not use the base version by default in `_compare_version` (#10051)
This commit is contained in:
parent
225989363b
commit
f95ba20012
|
@ -44,7 +44,7 @@ def _module_available(module_path: str) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = True) -> bool:
|
||||
def _compare_version(package: str, op: Callable, version: str, use_base_version: bool = False) -> bool:
|
||||
"""Compare package version with some requirements.
|
||||
|
||||
>>> _compare_version("torch", operator.ge, "0.1")
|
||||
|
|
|
@ -11,8 +11,10 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import operator
|
||||
|
||||
from pytorch_lightning.utilities import _module_available
|
||||
from pytorch_lightning.utilities.imports import _compare_version
|
||||
|
||||
|
||||
def test_module_exists():
|
||||
|
@ -22,3 +24,24 @@ def test_module_exists():
|
|||
assert not _module_available("torch.nn.asdf")
|
||||
assert not _module_available("asdf")
|
||||
assert not _module_available("asdf.bla.asdf")
|
||||
|
||||
|
||||
def test_compare_version(monkeypatch):
|
||||
from pytorch_lightning.utilities.imports import torch
|
||||
|
||||
monkeypatch.setattr(torch, "__version__", "1.8.9")
|
||||
assert not _compare_version("torch", operator.ge, "1.10.0")
|
||||
assert _compare_version("torch", operator.lt, "1.10.0")
|
||||
|
||||
monkeypatch.setattr(torch, "__version__", "1.10.0.dev123")
|
||||
assert _compare_version("torch", operator.ge, "1.10.0.dev123")
|
||||
assert not _compare_version("torch", operator.ge, "1.10.0.dev124")
|
||||
|
||||
assert _compare_version("torch", operator.ge, "1.10.0.dev123", use_base_version=True)
|
||||
assert _compare_version("torch", operator.ge, "1.10.0.dev124", use_base_version=True)
|
||||
|
||||
monkeypatch.setattr(torch, "__version__", "1.10.0a0+0aef44c") # dev version before rc
|
||||
assert _compare_version("torch", operator.ge, "1.10.0.rc0", use_base_version=True)
|
||||
assert not _compare_version("torch", operator.ge, "1.10.0.rc0")
|
||||
assert _compare_version("torch", operator.ge, "1.10.0", use_base_version=True)
|
||||
assert not _compare_version("torch", operator.ge, "1.10.0")
|
||||
|
|
Loading…
Reference in New Issue