From 7cb8acbe54eb108b6e99859adfd41717df43e032 Mon Sep 17 00:00:00 2001 From: Erez Sh Date: Fri, 9 Jul 2021 22:44:31 +0300 Subject: [PATCH] Bugfix for deepcopy + small unrelated refactor (issue #938) --- lark/common.py | 12 ++++++++++++ lark/utils.py | 14 +++++++------- 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/lark/common.py b/lark/common.py index 467acf8..cb408d9 100644 --- a/lark/common.py +++ b/lark/common.py @@ -1,4 +1,5 @@ from warnings import warn +from copy import deepcopy from .utils import Serialize from .lexer import TerminalDef @@ -31,6 +32,17 @@ class LexerConf(Serialize): def _deserialize(self): self.terminals_by_name = {t.name: t for t in self.terminals} + def __deepcopy__(self, memo=None): + return type(self)( + deepcopy(self.terminals, memo), + self.re_module, + deepcopy(self.ignore, memo), + deepcopy(self.postlex, memo), + deepcopy(self.callbacks, memo), + deepcopy(self.g_regex_flags, memo), + deepcopy(self.skip_validation, memo), + deepcopy(self.use_bytes, memo), + ) class ParserConf(Serialize): diff --git a/lark/utils.py b/lark/utils.py index b9d7ac3..ea78801 100644 --- a/lark/utils.py +++ b/lark/utils.py @@ -73,14 +73,13 @@ class Serialize(object): fields = getattr(self, '__serialize_fields__') res = {f: _serialize(getattr(self, f), memo) for f in fields} res['__type__'] = type(self).__name__ - postprocess = getattr(self, '_serialize', None) - if postprocess: - postprocess(res, memo) + if hasattr(self, '_serialize'): + self._serialize(res, memo) return res @classmethod def deserialize(cls, data, memo): - namespace = getattr(cls, '__serialize_namespace__', {}) + namespace = getattr(cls, '__serialize_namespace__', []) namespace = {c.__name__:c for c in namespace} fields = getattr(cls, '__serialize_fields__') @@ -94,9 +93,10 @@ class Serialize(object): setattr(inst, f, _deserialize(data[f], namespace, memo)) except KeyError as e: raise KeyError("Cannot find key for class", cls, e) - postprocess = getattr(inst, '_deserialize', None) - if postprocess: - postprocess() + + if hasattr(inst, '_deserialize'): + inst._deserialize() + return inst