mirror of https://github.com/python/cpython.git
Issue #28613: Fix get_event_loop() to return the current loop
when called from coroutines or callbacks.
This commit is contained in:
parent
1ea023e523
commit
600a349781
|
@ -393,7 +393,10 @@ def run_forever(self):
|
|||
"""Run until stop() is called."""
|
||||
self._check_closed()
|
||||
if self.is_running():
|
||||
raise RuntimeError('Event loop is running.')
|
||||
raise RuntimeError('This event loop is already running')
|
||||
if events._get_running_loop() is not None:
|
||||
raise RuntimeError(
|
||||
'Cannot run the event loop while another loop is running')
|
||||
self._set_coroutine_wrapper(self._debug)
|
||||
self._thread_id = threading.get_ident()
|
||||
if self._asyncgens is not None:
|
||||
|
@ -401,6 +404,7 @@ def run_forever(self):
|
|||
sys.set_asyncgen_hooks(firstiter=self._asyncgen_firstiter_hook,
|
||||
finalizer=self._asyncgen_finalizer_hook)
|
||||
try:
|
||||
events._set_running_loop(self)
|
||||
while True:
|
||||
self._run_once()
|
||||
if self._stopping:
|
||||
|
@ -408,6 +412,7 @@ def run_forever(self):
|
|||
finally:
|
||||
self._stopping = False
|
||||
self._thread_id = None
|
||||
events._set_running_loop(None)
|
||||
self._set_coroutine_wrapper(False)
|
||||
if self._asyncgens is not None:
|
||||
sys.set_asyncgen_hooks(*old_agen_hooks)
|
||||
|
|
|
@ -607,6 +607,30 @@ def new_event_loop(self):
|
|||
_lock = threading.Lock()
|
||||
|
||||
|
||||
# A TLS for the running event loop, used by _get_running_loop.
|
||||
class _RunningLoop(threading.local):
|
||||
_loop = None
|
||||
_running_loop = _RunningLoop()
|
||||
|
||||
|
||||
def _get_running_loop():
|
||||
"""Return the running event loop or None.
|
||||
|
||||
This is a low-level function intended to be used by event loops.
|
||||
This function is thread-specific.
|
||||
"""
|
||||
return _running_loop._loop
|
||||
|
||||
|
||||
def _set_running_loop(loop):
|
||||
"""Set the running event loop.
|
||||
|
||||
This is a low-level function intended to be used by event loops.
|
||||
This function is thread-specific.
|
||||
"""
|
||||
_running_loop._loop = loop
|
||||
|
||||
|
||||
def _init_event_loop_policy():
|
||||
global _event_loop_policy
|
||||
with _lock:
|
||||
|
@ -632,7 +656,17 @@ def set_event_loop_policy(policy):
|
|||
|
||||
|
||||
def get_event_loop():
|
||||
"""Equivalent to calling get_event_loop_policy().get_event_loop()."""
|
||||
"""Return an asyncio event loop.
|
||||
|
||||
When called from a coroutine or a callback (e.g. scheduled with call_soon
|
||||
or similar API), this function will always return the running event loop.
|
||||
|
||||
If there is no running event loop set, the function will return
|
||||
the result of `get_event_loop_policy().get_event_loop()` call.
|
||||
"""
|
||||
current_loop = _get_running_loop()
|
||||
if current_loop is not None:
|
||||
return current_loop
|
||||
return get_event_loop_policy().get_event_loop()
|
||||
|
||||
|
||||
|
|
|
@ -449,7 +449,13 @@ def new_test_loop(self, gen=None):
|
|||
self.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
def setUp(self):
|
||||
self._get_running_loop = events._get_running_loop
|
||||
events._get_running_loop = lambda: None
|
||||
|
||||
def tearDown(self):
|
||||
events._get_running_loop = self._get_running_loop
|
||||
|
||||
events.set_event_loop(None)
|
||||
|
||||
# Detect CPython bug #23353: ensure that yield/yield-from is not used
|
||||
|
|
|
@ -154,6 +154,7 @@ def test_ipaddr_info_no_inet_pton(self, m_socket):
|
|||
class BaseEventLoopTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = base_events.BaseEventLoop()
|
||||
self.loop._selector = mock.Mock()
|
||||
self.loop._selector.select.return_value = ()
|
||||
|
@ -976,6 +977,7 @@ def connection_lost(self, exc):
|
|||
class BaseEventLoopWithSelectorTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
@ -1692,5 +1694,23 @@ def stop_loop_coro(loop):
|
|||
"took .* seconds$")
|
||||
|
||||
|
||||
class RunningLoopTests(unittest.TestCase):
|
||||
|
||||
def test_running_loop_within_a_loop(self):
|
||||
@asyncio.coroutine
|
||||
def runner(loop):
|
||||
loop.run_forever()
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
outer_loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with self.assertRaisesRegex(RuntimeError,
|
||||
'while another loop is running'):
|
||||
outer_loop.run_until_complete(runner(loop))
|
||||
finally:
|
||||
loop.close()
|
||||
outer_loop.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -2233,6 +2233,7 @@ def noop(*args, **kwargs):
|
|||
class HandleTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = mock.Mock()
|
||||
self.loop.get_debug.return_value = True
|
||||
|
||||
|
@ -2411,6 +2412,7 @@ def __await__(self):
|
|||
class TimerTests(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = mock.Mock()
|
||||
|
||||
def test_hash(self):
|
||||
|
@ -2719,6 +2721,27 @@ def test_set_event_loop_policy(self):
|
|||
self.assertIs(policy, asyncio.get_event_loop_policy())
|
||||
self.assertIsNot(policy, old_policy)
|
||||
|
||||
def test_get_event_loop_returns_running_loop(self):
|
||||
class Policy(asyncio.DefaultEventLoopPolicy):
|
||||
def get_event_loop(self):
|
||||
raise NotImplementedError
|
||||
|
||||
loop = None
|
||||
|
||||
old_policy = asyncio.get_event_loop_policy()
|
||||
try:
|
||||
asyncio.set_event_loop_policy(Policy())
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
async def func():
|
||||
self.assertIs(asyncio.get_event_loop(), loop)
|
||||
|
||||
loop.run_until_complete(func())
|
||||
finally:
|
||||
asyncio.set_event_loop_policy(old_policy)
|
||||
if loop is not None:
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
|
@ -79,6 +79,7 @@ def __iter__(self):
|
|||
class DuckTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.addCleanup(self.loop.close)
|
||||
|
||||
|
@ -96,6 +97,7 @@ def test_ensure_future(self):
|
|||
class FutureTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.addCleanup(self.loop.close)
|
||||
|
||||
|
@ -468,6 +470,7 @@ def test_set_result_unless_cancelled(self):
|
|||
class FutureDoneCallbackTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def run_briefly(self):
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
class LockTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def test_ctor_loop(self):
|
||||
|
@ -235,6 +236,7 @@ def test_context_manager_no_yield(self):
|
|||
class EventTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def test_ctor_loop(self):
|
||||
|
@ -364,6 +366,7 @@ def c1(result):
|
|||
class ConditionTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def test_ctor_loop(self):
|
||||
|
@ -699,6 +702,7 @@ def test_ambiguous_loops(self):
|
|||
class SemaphoreTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def test_ctor_loop(self):
|
||||
|
|
|
@ -17,6 +17,7 @@
|
|||
class BaseTest(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.BaseEventLoop()
|
||||
self.loop._process_events = mock.Mock()
|
||||
self.loop._selector = mock.Mock()
|
||||
|
|
|
@ -24,6 +24,7 @@ def close_transport(transport):
|
|||
class ProactorSocketTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.addCleanup(self.loop.close)
|
||||
self.proactor = mock.Mock()
|
||||
|
@ -436,6 +437,8 @@ def test_dont_pause_writing(self):
|
|||
class BaseProactorEventLoopTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
self.sock = test_utils.mock_nonblocking_socket()
|
||||
self.proactor = mock.Mock()
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
class _QueueTestBase(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
|
||||
|
|
|
@ -51,6 +51,7 @@ def close_transport(transport):
|
|||
class BaseSelectorEventLoopTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.selector = mock.Mock()
|
||||
self.selector.select.return_value = []
|
||||
self.loop = TestBaseSelectorEventLoop(self.selector)
|
||||
|
@ -698,6 +699,7 @@ def test_accept_connection_multiple(self):
|
|||
class SelectorTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
|
||||
self.sock = mock.Mock(socket.socket)
|
||||
|
@ -793,6 +795,7 @@ def test_connection_lost(self):
|
|||
class SelectorSocketTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
|
||||
self.sock = mock.Mock(socket.socket)
|
||||
|
@ -1141,6 +1144,7 @@ def test_transport_close_remove_writer(self, m_log):
|
|||
class SelectorSslTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
|
||||
self.sock = mock.Mock(socket.socket)
|
||||
|
@ -1501,6 +1505,7 @@ def test_ssl_transport_requires_ssl_module(self):
|
|||
class SelectorDatagramTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
|
||||
self.sock = mock.Mock(spec_set=socket.socket)
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
class SslProtoHandshakeTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ class StreamReaderTests(test_utils.TestCase):
|
|||
DATA = b'line1\nline2\nline3\n'
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
|
|
@ -35,6 +35,7 @@ def _start(self, *args, **kwargs):
|
|||
|
||||
class SubprocessTransportTests(test_utils.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
@ -466,6 +467,7 @@ class SubprocessWatcherMixin(SubprocessMixin):
|
|||
Watcher = None
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
policy = asyncio.get_event_loop_policy()
|
||||
self.loop = policy.new_event_loop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
@ -490,6 +492,7 @@ class SubprocessFastWatcherTests(SubprocessWatcherMixin,
|
|||
class SubprocessProactorTests(SubprocessMixin, test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.ProactorEventLoop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
|
|
@ -75,6 +75,7 @@ def __call__(self, *args):
|
|||
class TaskTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
|
||||
def test_other_loop_future(self):
|
||||
|
@ -1933,6 +1934,7 @@ def cancelling_callback(_):
|
|||
class GatherTestsBase:
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.one_loop = self.new_test_loop()
|
||||
self.other_loop = self.new_test_loop()
|
||||
self.set_event_loop(self.one_loop, cleanup=False)
|
||||
|
@ -2216,6 +2218,7 @@ class RunCoroutineThreadsafeTests(test_utils.TestCase):
|
|||
"""Test case for asyncio.run_coroutine_threadsafe."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.set_event_loop(self.loop) # Will cleanup properly
|
||||
|
||||
|
@ -2306,12 +2309,14 @@ def test_run_coroutine_threadsafe_task_factory_exception(self):
|
|||
|
||||
class SleepTests(test_utils.TestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(None)
|
||||
|
||||
def tearDown(self):
|
||||
self.loop.close()
|
||||
self.loop = None
|
||||
super().tearDown()
|
||||
|
||||
def test_sleep_zero(self):
|
||||
result = 0
|
||||
|
|
|
@ -40,6 +40,7 @@ def close_pipe_transport(transport):
|
|||
class SelectorEventLoopSignalTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.SelectorEventLoop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
@ -234,6 +235,7 @@ def test_close(self, m_signal):
|
|||
class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.SelectorEventLoop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
@ -338,6 +340,7 @@ def test_create_unix_connection_ssl_noserverhost(self):
|
|||
class UnixReadPipeTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.Protocol)
|
||||
self.pipe = mock.Mock(spec_set=io.RawIOBase)
|
||||
|
@ -487,6 +490,7 @@ def test__call_connection_lost_with_err(self):
|
|||
class UnixWritePipeTransportTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.protocol = test_utils.make_test_protocol(asyncio.BaseProtocol)
|
||||
self.pipe = mock.Mock(spec_set=io.RawIOBase)
|
||||
|
@ -805,6 +809,7 @@ class ChildWatcherTestsMixin:
|
|||
ignore_warnings = mock.patch.object(log.logger, "warning")
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = self.new_test_loop()
|
||||
self.running = False
|
||||
self.zombies = {}
|
||||
|
|
|
@ -31,6 +31,7 @@ def data_received(self, data):
|
|||
class ProactorTests(test_utils.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.loop = asyncio.ProactorEventLoop()
|
||||
self.set_event_loop(self.loop)
|
||||
|
||||
|
|
Loading…
Reference in New Issue