Fix numpy floats in meta.json

This commit is contained in:
Matthew Honnibal 2024-09-14 12:54:08 +02:00
parent 2f1e7ed09a
commit 0576a1ff56
1 changed files with 8 additions and 2 deletions

View File

@ -9,6 +9,7 @@ from contextlib import ExitStack, contextmanager
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain, cycle from itertools import chain, cycle
import numpy
from pathlib import Path from pathlib import Path
from timeit import default_timer as timer from timeit import default_timer as timer
from typing import ( from typing import (
@ -33,6 +34,7 @@ from typing import (
import srsly import srsly
from cymem.cymem import Pool from cymem.cymem import Pool
from thinc.api import Config, CupyOps, Optimizer, get_current_ops from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from thinc.util import convert_recursive
from . import about, ty, util from . import about, ty, util
from .compat import Literal from .compat import Literal
@ -2141,7 +2143,7 @@ class Language:
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"] p, exclude=["vocab"]
) )
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2255,7 +2257,7 @@ class Language:
serializers: Dict[str, Callable[[], bytes]] = {} serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2306,6 +2308,10 @@ class Language:
return self return self
def _replace_numpy_floats(meta_dict: dict) -> dict:
return convert_recursive(lambda v: isinstance(v, numpy.floaty), lambda v: float(v), dict(meta_dict))
@dataclass @dataclass
class FactoryMeta: class FactoryMeta:
"""Dataclass containing information about a component and its defaults """Dataclass containing information about a component and its defaults