From 0576a1ff56e7584a6e6b198e59a6ded5944bd0b6 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sat, 14 Sep 2024 12:54:08 +0200 Subject: [PATCH] Fix numpy floats in meta.json --- spacy/language.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 0d9aab9e3..f2303c10c 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -9,6 +9,7 @@ from contextlib import ExitStack, contextmanager from copy import deepcopy from dataclasses import dataclass from itertools import chain, cycle +import numpy from pathlib import Path from timeit import default_timer as timer from typing import ( @@ -33,6 +34,7 @@ from typing import ( import srsly from cymem.cymem import Pool from thinc.api import Config, CupyOps, Optimizer, get_current_ops +from thinc.util import convert_recursive from . import about, ty, util from .compat import Literal @@ -2141,7 +2143,7 @@ class Language: serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] p, exclude=["vocab"] ) - serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) + serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda p: self.config.to_disk(p) for name, proc in self._components: if name in exclude: @@ -2255,7 +2257,7 @@ class Language: serializers: Dict[str, Callable[[], bytes]] = {} serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] - serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) + serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda: self.config.to_bytes() for name, proc in self._components: if name in exclude: @@ -2306,6 +2308,10 @@ class Language: return self +def _replace_numpy_floats(meta_dict: dict) -> dict: + return convert_recursive(lambda v: isinstance(v, numpy.floaty), lambda v: float(v), dict(meta_dict)) + + @dataclass class FactoryMeta: """Dataclass containing information about a component and its defaults