Yielding dict in coroutine

This commit is contained in:
Anton Ryzhov 2013-10-17 19:37:00 +04:00
parent a10c1646dc
commit 71c0a9b8b9
2 changed files with 37 additions and 2 deletions

View File

@ -404,6 +404,10 @@ class Multi(YieldPoint):
a list of ``YieldPoints``. a list of ``YieldPoints``.
""" """
def __init__(self, children): def __init__(self, children):
self.keys = None
if isinstance(children, dict):
self.keys = list(children.keys())
children = children.values()
self.children = [] self.children = []
for i in children: for i in children:
if isinstance(i, Future): if isinstance(i, Future):
@ -423,7 +427,11 @@ class Multi(YieldPoint):
return not self.unfinished_children return not self.unfinished_children
def get_result(self): def get_result(self):
return [i.get_result() for i in self.children] result = (i.get_result() for i in self.children)
if self.keys:
return dict(zip(self.keys, result))
else:
return list(result)
class _NullYieldPoint(YieldPoint): class _NullYieldPoint(YieldPoint):
@ -523,7 +531,7 @@ class Runner(object):
self.finished = True self.finished = True
self.yield_point = _null_yield_point self.yield_point = _null_yield_point
raise raise
if isinstance(yielded, list): if isinstance(yielded, (list, dict)):
yielded = Multi(yielded) yielded = Multi(yielded)
elif isinstance(yielded, Future): elif isinstance(yielded, Future):
yielded = YieldFuture(yielded) yielded = YieldFuture(yielded)

View File

@ -281,6 +281,16 @@ class GenEngineTest(AsyncTestCase):
self.stop() self.stop()
self.run_gen(f) self.run_gen(f)
def test_multi_dict(self):
@gen.engine
def f():
(yield gen.Callback("k1"))("v1")
(yield gen.Callback("k2"))("v2")
results = yield dict(foo=gen.Wait("k1"), bar=gen.Wait("k2"))
self.assertEqual(results, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
def test_multi_delayed(self): def test_multi_delayed(self):
@gen.engine @gen.engine
def f(): def f():
@ -293,6 +303,18 @@ class GenEngineTest(AsyncTestCase):
self.stop() self.stop()
self.run_gen(f) self.run_gen(f)
def test_multi_dict_delayed(self):
@gen.engine
def f():
# callbacks run at different times
responses = yield dict(
foo=gen.Task(self.delay_callback, 3, arg="v1"),
bar=gen.Task(self.delay_callback, 1, arg="v2"),
)
self.assertEqual(responses, dict(foo="v1", bar="v2"))
self.stop()
self.run_gen(f)
@skipOnTravis @skipOnTravis
@gen_test @gen_test
def test_multi_performance(self): def test_multi_performance(self):
@ -314,6 +336,11 @@ class GenEngineTest(AsyncTestCase):
results = yield [self.async_future(1), self.async_future(2)] results = yield [self.async_future(1), self.async_future(2)]
self.assertEqual(results, [1, 2]) self.assertEqual(results, [1, 2])
@gen_test
def test_multi_dict_future(self):
results = yield dict(foo=self.async_future(1), bar=self.async_future(2))
self.assertEqual(results, dict(foo=1, bar=2))
def test_arguments(self): def test_arguments(self):
@gen.engine @gen.engine
def f(): def f():