diff --git a/tests/test_base.py b/tests/test_base.py index 603720e..00ecbcd 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -541,6 +541,46 @@ class _TestBase: self.assertFalse(isinstance(task, MyTask)) self.loop.run_until_complete(task) + def test_set_task_name(self): + if self.implementation == 'asyncio' and sys.version_info < (3, 8, 0): + raise unittest.SkipTest('unsupported task name') + + self.loop._process_events = mock.Mock() + + result = None + + class MyTask(asyncio.Task): + def set_name(self, name): + nonlocal result + result = name + "!" + + def get_name(self): + return result + + async def coro(): + pass + + factory = lambda loop, coro: MyTask(coro, loop=loop) + + self.assertIsNone(self.loop.get_task_factory()) + task = self.loop.create_task(coro(), name="mytask") + self.assertFalse(isinstance(task, MyTask)) + if sys.version_info >= (3, 8, 0): + self.assertEqual(task.get_name(), "mytask") + self.loop.run_until_complete(task) + + self.loop.set_task_factory(factory) + self.assertIs(self.loop.get_task_factory(), factory) + + task = self.loop.create_task(coro(), name="mytask") + self.assertTrue(isinstance(task, MyTask)) + self.assertEqual(result, "mytask!") + self.assertEqual(task.get_name(), "mytask!") + self.loop.run_until_complete(task) + + self.loop.set_task_factory(None) + self.assertIsNone(self.loop.get_task_factory()) + def _compile_agen(self, src): try: g = {} diff --git a/uvloop/loop.pyx b/uvloop/loop.pyx index c9bb650..54b1194 100644 --- a/uvloop/loop.pyx +++ b/uvloop/loop.pyx @@ -1386,16 +1386,29 @@ cdef class Loop: """Create a Future object attached to the loop.""" return self._new_future() - def create_task(self, coro): + def create_task(self, coro, *, name=None): """Schedule a coroutine object. Return a task object. + + If name is not None, task.set_name(name) will be called if the task + object has the set_name attribute, true for default Task in Python 3.8. """ self._check_closed() if self._task_factory is None: task = aio_Task(coro, loop=self) else: task = self._task_factory(self, coro) + + # copied from asyncio.tasks._set_task_name (bpo-34270) + if name is not None: + try: + set_name = task.set_name + except AttributeError: + pass + else: + set_name(name) + return task def set_task_factory(self, factory):