Fix Python2/3 load bug

This commit is contained in:
Matthew Honnibal 2017-05-31 15:21:44 -05:00
parent 5385a06dd2
commit c8a58cfcf8
2 changed files with 7 additions and 2 deletions

View File

@ -59,6 +59,11 @@ elif is_python3:
json_dumps = lambda data: ujson.dumps(data, indent=2) json_dumps = lambda data: ujson.dumps(data, indent=2)
path2str = lambda path: str(path) path2str = lambda path: str(path)
def getattr_(obj, name, *default):
if is_python3 and isinstance(name, bytes):
name = name.decode('utf8')
return getattr(obj, name, *default)
def symlink_to(orig, dest): def symlink_to(orig, dest):
if is_python2 and is_windows: if is_python2 and is_windows:

View File

@ -22,7 +22,7 @@ import ujson
from .symbols import ORTH from .symbols import ORTH
from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_ from .compat import cupy, CudaStream, path2str, basestring_, input_, unicode_
from .compat import copy_array, normalize_string_keys from .compat import copy_array, normalize_string_keys, getattr_
LANGUAGES = {} LANGUAGES = {}
@ -499,7 +499,7 @@ def model_from_bytes(model, bytes_data):
for dim, value in weights[i]['dims'].items(): for dim, value in weights[i]['dims'].items():
setattr(layer, dim, value) setattr(layer, dim, value)
for param in weights[i]['params']: for param in weights[i]['params']:
dest = getattr(layer, param['name']) dest = getattr_(layer, param['name'])
copy_array(dest, param['value']) copy_array(dest, param['value'])
i += 1 i += 1
if hasattr(layer, '_layers'): if hasattr(layer, '_layers'):