Optimizations and one intentional loophole by Jim Fulton.

The optimizations consist mostly of using local variables to cache
methods or instance variables used a lot (e.g. "self.write").

The loopholes allows marshalling extension types as long as they have
a __class__ attribute (in which case they may support the rest of the
class piclking protocol as well).  This allows pickling MESS extension
types.
This commit is contained in:
Guido van Rossum 1996-07-22 22:26:07 +00:00
parent 77c29a1734
commit c7c5e697c3
1 changed files with 82 additions and 58 deletions

View File

@ -126,7 +126,7 @@
I have no answers. Garbage Collection may also become a problem here.) I have no answers. Garbage Collection may also become a problem here.)
""" """
__version__ = "1.5" # Code version __version__ = "1.6" # Code version
from types import * from types import *
import string import string
@ -200,8 +200,11 @@ def save(self, object):
try: try:
f = self.dispatch[t] f = self.dispatch[t]
except KeyError: except KeyError:
raise PicklingError, \ if hasattr(object, '__class__'):
"can't pickle %s objects" % `t.__name__` f = self.dispatch[InstanceType]
else:
raise PicklingError, \
"can't pickle %s objects" % `t.__name__`
f(self, object) f(self, object)
def persistent_id(self, object): def persistent_id(self, object):
@ -234,66 +237,75 @@ def save_string(self, object):
def save_tuple(self, object): def save_tuple(self, object):
d = id(object) d = id(object)
self.write(MARK) write = self.write
save = self.save
has_key = self.memo.has_key
write(MARK)
n = len(object) n = len(object)
for k in range(n): for k in range(n):
self.save(object[k]) save(object[k])
if self.memo.has_key(d): if has_key(d):
# Saving object[k] has saved us! # Saving object[k] has saved us!
while k >= 0: while k >= 0:
self.write(POP) write(POP)
k = k-1 k = k-1
self.write(GET + `d` + '\n') write(GET + `d` + '\n')
break break
else: else:
self.write(TUPLE + PUT + `d` + '\n') write(TUPLE + PUT + `d` + '\n')
self.memo[d] = object self.memo[d] = object
dispatch[TupleType] = save_tuple dispatch[TupleType] = save_tuple
def save_list(self, object): def save_list(self, object):
d = id(object) d = id(object)
self.write(MARK) write = self.write
save = self.save
write(MARK)
n = len(object) n = len(object)
for k in range(n): for k in range(n):
item = object[k] item = object[k]
if not safe(item): if not safe(item):
break break
self.save(item) save(item)
else: else:
k = n k = n
self.write(LIST + PUT + `d` + '\n') write(LIST + PUT + `d` + '\n')
self.memo[d] = object self.memo[d] = object
for k in range(k, n): for k in range(k, n):
item = object[k] item = object[k]
self.save(item) save(item)
self.write(APPEND) write(APPEND)
dispatch[ListType] = save_list dispatch[ListType] = save_list
def save_dict(self, object): def save_dict(self, object):
d = id(object) d = id(object)
self.write(MARK) write = self.write
save = self.save
write(MARK)
items = object.items() items = object.items()
n = len(items) n = len(items)
for k in range(n): for k in range(n):
key, value = items[k] key, value = items[k]
if not safe(key) or not safe(value): if not safe(key) or not safe(value):
break break
self.save(key) save(key)
self.save(value) save(value)
else: else:
k = n k = n
self.write(DICT + PUT + `d` + '\n') self.write(DICT + PUT + `d` + '\n')
self.memo[d] = object self.memo[d] = object
for k in range(k, n): for k in range(k, n):
key, value = items[k] key, value = items[k]
self.save(key) save(key)
self.save(value) save(value)
self.write(SETITEM) write(SETITEM)
dispatch[DictionaryType] = save_dict dispatch[DictionaryType] = save_dict
def save_inst(self, object): def save_inst(self, object):
d = id(object) d = id(object)
cls = object.__class__ cls = object.__class__
write = self.write
save = self.save
module = whichmodule(cls) module = whichmodule(cls)
name = cls.__name__ name = cls.__name__
if hasattr(object, '__getinitargs__'): if hasattr(object, '__getinitargs__'):
@ -301,11 +313,11 @@ def save_inst(self, object):
len(args) # XXX Assert it's a sequence len(args) # XXX Assert it's a sequence
else: else:
args = () args = ()
self.write(MARK) write(MARK)
for arg in args: for arg in args:
self.save(arg) save(arg)
self.write(INST + module + '\n' + name + '\n' + write(INST + module + '\n' + name + '\n' +
PUT + `d` + '\n') PUT + `d` + '\n')
self.memo[d] = object self.memo[d] = object
try: try:
getstate = object.__getstate__ getstate = object.__getstate__
@ -313,8 +325,8 @@ def save_inst(self, object):
stuff = object.__dict__ stuff = object.__dict__
else: else:
stuff = getstate() stuff = getstate()
self.save(stuff) save(stuff)
self.write(BUILD) write(BUILD)
dispatch[InstanceType] = save_inst dispatch[InstanceType] = save_inst
def save_class(self, object): def save_class(self, object):
@ -361,16 +373,21 @@ def __init__(self, file):
def load(self): def load(self):
self.mark = ['spam'] # Any new unique object self.mark = ['spam'] # Any new unique object
self.stack = [] self.stack = []
self.append = self.stack.append
read = self.read
dispatch = self.dispatch
try: try:
while 1: while 1:
key = self.read(1) key = read(1)
self.dispatch[key](self) dispatch[key](self)
except STOP, value: except STOP, value:
return value return value
def marker(self): def marker(self):
k = len(self.stack)-1 stack = self.stack
while self.stack[k] != self.mark: k = k-1 mark = self.mark
k = len(stack)-1
while stack[k] is not mark: k = k-1
return k return k
dispatch = {} dispatch = {}
@ -381,27 +398,28 @@ def load_eof(self):
def load_persid(self): def load_persid(self):
pid = self.readline()[:-1] pid = self.readline()[:-1]
self.stack.append(self.persistent_load(pid)) self.append(self.persistent_load(pid))
dispatch[PERSID] = load_persid dispatch[PERSID] = load_persid
def load_none(self): def load_none(self):
self.stack.append(None) self.append(None)
dispatch[NONE] = load_none dispatch[NONE] = load_none
def load_int(self): def load_int(self):
self.stack.append(string.atoi(self.readline()[:-1], 0)) self.append(string.atoi(self.readline()[:-1], 0))
dispatch[INT] = load_int dispatch[INT] = load_int
def load_long(self): def load_long(self):
self.stack.append(string.atol(self.readline()[:-1], 0)) self.append(string.atol(self.readline()[:-1], 0))
dispatch[LONG] = load_long dispatch[LONG] = load_long
def load_float(self): def load_float(self):
self.stack.append(string.atof(self.readline()[:-1])) self.append(string.atof(self.readline()[:-1]))
dispatch[FLOAT] = load_float dispatch[FLOAT] = load_float
def load_string(self): def load_string(self):
self.stack.append(eval(self.readline()[:-1])) self.append(eval(self.readline()[:-1],
{'__builtins__': {}})) # Let's be careful
dispatch[STRING] = load_string dispatch[STRING] = load_string
def load_tuple(self): def load_tuple(self):
@ -433,14 +451,14 @@ def load_inst(self):
name = self.readline()[:-1] name = self.readline()[:-1]
klass = self.find_class(module, name) klass = self.find_class(module, name)
value = apply(klass, args) value = apply(klass, args)
self.stack.append(value) self.append(value)
dispatch[INST] = load_inst dispatch[INST] = load_inst
def load_class(self): def load_class(self):
module = self.readline()[:-1] module = self.readline()[:-1]
name = self.readline()[:-1] name = self.readline()[:-1]
klass = self.find_class(module, name) klass = self.find_class(module, name)
self.stack.append(klass) self.append(klass)
return klass return klass
dispatch[CLASS] = load_class dispatch[CLASS] = load_class
@ -453,7 +471,9 @@ def find_class(self, module, name):
"Failed to import class %s from module %s" % \ "Failed to import class %s from module %s" % \
(name, module) (name, module)
klass = env[name] klass = env[name]
if type(klass) != ClassType: # if type(klass) != ClassType:
if (type(klass) is FunctionType or
type(klass) is BuiltinFunctionType):
raise SystemError, \ raise SystemError, \
"Imported object %s from module %s is not a class" % \ "Imported object %s from module %s is not a class" % \
(name, module) (name, module)
@ -464,11 +484,11 @@ def load_pop(self):
dispatch[POP] = load_pop dispatch[POP] = load_pop
def load_dup(self): def load_dup(self):
stack.append(stack[-1]) self.append(stack[-1])
dispatch[DUP] = load_dup dispatch[DUP] = load_dup
def load_get(self): def load_get(self):
self.stack.append(self.memo[self.readline()[:-1]]) self.append(self.memo[self.readline()[:-1]])
dispatch[GET] = load_get dispatch[GET] = load_get
def load_put(self): def load_put(self):
@ -476,35 +496,39 @@ def load_put(self):
dispatch[PUT] = load_put dispatch[PUT] = load_put
def load_append(self): def load_append(self):
value = self.stack[-1] stack = self.stack
del self.stack[-1] value = stack[-1]
list = self.stack[-1] del stack[-1]
list = stack[-1]
list.append(value) list.append(value)
dispatch[APPEND] = load_append dispatch[APPEND] = load_append
def load_setitem(self): def load_setitem(self):
value = self.stack[-1] stack = self.stack
key = self.stack[-2] value = stack[-1]
del self.stack[-2:] key = stack[-2]
dict = self.stack[-1] del stack[-2:]
dict = stack[-1]
dict[key] = value dict[key] = value
dispatch[SETITEM] = load_setitem dispatch[SETITEM] = load_setitem
def load_build(self): def load_build(self):
value = self.stack[-1] stack = self.stack
del self.stack[-1] value = stack[-1]
inst = self.stack[-1] del stack[-1]
inst = stack[-1]
try: try:
setstate = inst.__setstate__ setstate = inst.__setstate__
except AttributeError: except AttributeError:
instdict = inst.__dict__
for key in value.keys(): for key in value.keys():
inst.__dict__[key] = value[key] instdict[key] = value[key]
else: else:
setstate(value) setstate(value)
dispatch[BUILD] = load_build dispatch[BUILD] = load_build
def load_mark(self): def load_mark(self):
self.stack.append(self.mark) self.append(self.mark)
dispatch[MARK] = load_mark dispatch[MARK] = load_mark
def load_stop(self): def load_stop(self):
@ -516,12 +540,13 @@ def load_stop(self):
# Shorthands # Shorthands
from StringIO import StringIO
def dump(object, file): def dump(object, file):
Pickler(file).dump(object) Pickler(file).dump(object)
def dumps(object): def dumps(object):
import StringIO file = StringIO()
file = StringIO.StringIO()
Pickler(file).dump(object) Pickler(file).dump(object)
return file.getvalue() return file.getvalue()
@ -529,8 +554,7 @@ def load(file):
return Unpickler(file).load() return Unpickler(file).load()
def loads(str): def loads(str):
import StringIO file = StringIO(str)
file = StringIO.StringIO(str)
return Unpickler(file).load() return Unpickler(file).load()
@ -545,7 +569,7 @@ def test():
c = C() c = C()
c.foo = 1 c.foo = 1
c.bar = 2L c.bar = 2L
x = [0,1,2,3] x = [0, 1, 2, 3]
y = ('abc', 'abc', c, c) y = ('abc', 'abc', c, c)
x.append(y) x.append(y)
x.append(y) x.append(y)