Bugfix for deepcopy + small unrelated refactor (issue #938)

This commit is contained in:
Erez Sh 2021-07-09 22:44:31 +03:00
parent be4a7af62b
commit 7cb8acbe54
2 changed files with 19 additions and 7 deletions

View File

@ -1,4 +1,5 @@
from warnings import warn from warnings import warn
from copy import deepcopy
from .utils import Serialize from .utils import Serialize
from .lexer import TerminalDef from .lexer import TerminalDef
@ -31,6 +32,17 @@ class LexerConf(Serialize):
def _deserialize(self): def _deserialize(self):
self.terminals_by_name = {t.name: t for t in self.terminals} self.terminals_by_name = {t.name: t for t in self.terminals}
def __deepcopy__(self, memo=None):
return type(self)(
deepcopy(self.terminals, memo),
self.re_module,
deepcopy(self.ignore, memo),
deepcopy(self.postlex, memo),
deepcopy(self.callbacks, memo),
deepcopy(self.g_regex_flags, memo),
deepcopy(self.skip_validation, memo),
deepcopy(self.use_bytes, memo),
)
class ParserConf(Serialize): class ParserConf(Serialize):

View File

@ -73,14 +73,13 @@ class Serialize(object):
fields = getattr(self, '__serialize_fields__') fields = getattr(self, '__serialize_fields__')
res = {f: _serialize(getattr(self, f), memo) for f in fields} res = {f: _serialize(getattr(self, f), memo) for f in fields}
res['__type__'] = type(self).__name__ res['__type__'] = type(self).__name__
postprocess = getattr(self, '_serialize', None) if hasattr(self, '_serialize'):
if postprocess: self._serialize(res, memo)
postprocess(res, memo)
return res return res
@classmethod @classmethod
def deserialize(cls, data, memo): def deserialize(cls, data, memo):
namespace = getattr(cls, '__serialize_namespace__', {}) namespace = getattr(cls, '__serialize_namespace__', [])
namespace = {c.__name__:c for c in namespace} namespace = {c.__name__:c for c in namespace}
fields = getattr(cls, '__serialize_fields__') fields = getattr(cls, '__serialize_fields__')
@ -94,9 +93,10 @@ class Serialize(object):
setattr(inst, f, _deserialize(data[f], namespace, memo)) setattr(inst, f, _deserialize(data[f], namespace, memo))
except KeyError as e: except KeyError as e:
raise KeyError("Cannot find key for class", cls, e) raise KeyError("Cannot find key for class", cls, e)
postprocess = getattr(inst, '_deserialize', None)
if postprocess: if hasattr(inst, '_deserialize'):
postprocess() inst._deserialize()
return inst return inst