diff --git a/Makefile b/Makefile index 1c3259c..d09df6d 100644 --- a/Makefile +++ b/Makefile @@ -101,6 +101,22 @@ src = re.sub( src, flags=re.X) +src = re.sub( + r''' + \s* __Pyx_Coroutine_get_name\(__pyx_CoroutineObject\s+\*self\) + \s* { + \s* Py_INCREF\(self->gi_name\); + ''', + + r''' + __Pyx_Coroutine_get_name(__pyx_CoroutineObject *self) + { + if (self->gi_name == NULL) { return __pyx_empty_unicode; } + Py_INCREF(self->gi_name); + ''', + + src, flags=re.X) + with open('uvloop/loop.c', 'wt') as f: f.write(src) endef diff --git a/tests/test_base.py b/tests/test_base.py index b8dc46b..27defff 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -398,7 +398,13 @@ class _TestBase: class TestBaseUV(_TestBase, UVTestCase): def test_cython_coro_is_coroutine(self): + from asyncio.coroutines import _format_coroutine + coro = self.loop.create_server(object) + + self.assertEqual(_format_coroutine(coro), + 'Loop.create_server()') + self.assertEqual(self.loop.create_server.__qualname__, 'Loop.create_server') self.assertEqual(self.loop.create_server.__name__, @@ -411,6 +417,8 @@ class TestBaseUV(_TestBase, UVTestCase): self.loop.run_until_complete(fut) except asyncio.CancelledError: pass + + _format_coroutine(coro) # This line checks against Cython segfault coro.close() diff --git a/uvloop/__init__.py b/uvloop/__init__.py index 8d1c135..b57dfe1 100644 --- a/uvloop/__init__.py +++ b/uvloop/__init__.py @@ -3,6 +3,7 @@ import asyncio from asyncio.events import BaseDefaultEventLoopPolicy as __BasePolicy from . import includes as __includes +from . import _patch from .loop import Loop as __BaseLoop diff --git a/uvloop/_patch.py b/uvloop/_patch.py new file mode 100644 index 0000000..da3347d --- /dev/null +++ b/uvloop/_patch.py @@ -0,0 +1,19 @@ +import asyncio + +from asyncio import coroutines + + +def _format_coroutine(coro): + if asyncio.iscoroutine(coro) and not hasattr(coro, 'cr_code'): + # Most likely a Cython coroutine + coro_name = '{}()'.format(coro.__qualname__ or coro.__name__) + if coro.cr_running: + return '{} running'.format(coro_name) + else: + return coro_name + + return _old_format_coroutine(coro) + + +_old_format_coroutine = coroutines._format_coroutine +coroutines._format_coroutine = _format_coroutine