spaCy/spacy/tests/test_misc.py

40 lines
992 B
Python
Raw Normal View History

2017-04-23 19:06:46 +00:00
# coding: utf-8
from __future__ import unicode_literals
from ..util import ensure_path
2017-05-28 23:37:57 +00:00
from .._ml import model_to_bytes, model_from_bytes
2017-04-23 19:06:46 +00:00
from pathlib import Path
import pytest
2017-05-28 23:37:57 +00:00
from thinc.neural import Maxout, Softmax
from thinc.api import chain
2017-04-23 19:06:46 +00:00
@pytest.mark.parametrize('text', ['hello/world', 'hello world'])
def test_util_ensure_path_succeeds(text):
    """ensure_path should turn a plain string into a pathlib.Path."""
    converted = ensure_path(text)
    assert isinstance(converted, Path)
2017-05-28 23:37:57 +00:00
def test_simple_model_roundtrip_bytes():
    """Serialize a Maxout layer, perturb it, and check deserialization restores it.

    The bias is shifted by +1 before ``model_to_bytes`` and by -1 afterwards.
    The assertions can therefore only pass if ``model_from_bytes`` writes the
    serialized values back into ``model`` in place.
    """
    model = Maxout(5, 10, pieces=2)
    model.b += 1
    # Snapshot the exact values that were serialized, so the final check
    # compares the whole bias array rather than a single entry.
    expected = model.b.copy()
    data = model_to_bytes(model)
    model.b -= 1
    model_from_bytes(model, data)
    assert model.b[0, 0] == 1
    # Every bias entry must be restored, not just the first; this catches a
    # partial restore that the single-element check would miss.
    assert (model.b == expected).all()
def test_multi_model_roundtrip_bytes():
    """Round-trip a two-layer chain and check each sublayer's bias is restored.

    The layers get distinct bias offsets (+1 and +2), so each layer's restored
    value identifies which saved parameters it received.
    """
    # NOTE(review): assumes thinc's Maxout(nO, nI, ...) argument order. The
    # second layer's input width is set to the first layer's output width (5)
    # so the chained layers are dimensionally consistent.
    model = chain(Maxout(5, 10, pieces=2), Maxout(2, 5))
    model._layers[0].b += 1
    model._layers[1].b += 2
    # Snapshot what gets serialized so every entry can be compared later.
    expected = [layer.b.copy() for layer in model._layers]
    data = model_to_bytes(model)
    model._layers[0].b -= 1
    model._layers[1].b -= 2
    model_from_bytes(model, data)
    assert model._layers[0].b[0, 0] == 1
    assert model._layers[1].b[0, 0] == 2
    # Check the full bias array of each sublayer, not just its first entry.
    for layer, saved in zip(model._layers, expected):
        assert (layer.b == saved).all()