diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index 0d09999a9..b38a50f71 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -8,6 +8,7 @@ from spacy import prefer_gpu, require_gpu, require_cpu from spacy.ml._precomputable_affine import PrecomputableAffine from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding from spacy.util import dot_to_object, SimpleFrozenList, import_file +from spacy.util import to_ternary_int from thinc.api import Config, Optimizer, ConfigValidationError, get_current_ops from thinc.api import set_current_ops from spacy.training.batchers import minibatch_by_words @@ -386,3 +387,18 @@ def make_dummy_component( nlp = English.from_config(config) nlp.add_pipe("dummy_component") nlp.initialize() + + +def test_to_ternary_int(): + assert to_ternary_int(True) == 1 + assert to_ternary_int(None) == 0 + assert to_ternary_int(False) == -1 + assert to_ternary_int(1) == 1 + assert to_ternary_int(1.0) == 1 + assert to_ternary_int(0) == 0 + assert to_ternary_int(0.0) == 0 + assert to_ternary_int(-1) == -1 + assert to_ternary_int(5) == -1 + assert to_ternary_int(-10) == -1 + assert to_ternary_int("string") == -1 + assert to_ternary_int([0, "string"]) == -1 diff --git a/spacy/util.py b/spacy/util.py index 512c6b742..84142d5d8 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -1533,11 +1533,15 @@ def to_ternary_int(val) -> int: attributes such as SENT_START: True/1/1.0 is 1 (True), None/0/0.0 is 0 (None), any other values are -1 (False). """ - if isinstance(val, float): - val = int(val) - if val is True or val is 1: + if val is True: return 1 - elif val is None or val is 0: + elif val is None: + return 0 + elif val is False: + return -1 + elif val == 1: + return 1 + elif val == 0: return 0 else: return -1