diff --git a/mitmproxy/master.py b/mitmproxy/master.py index 0fcf312ef..31849a887 100644 --- a/mitmproxy/master.py +++ b/mitmproxy/master.py @@ -2,7 +2,6 @@ import threading import contextlib import asyncio import signal -import time from mitmproxy import addonmanager from mitmproxy import options @@ -37,11 +36,10 @@ class Master: The master handles mitmproxy's main event loop. """ def __init__(self, opts): - self.loop = asyncio.new_event_loop() - asyncio.set_event_loop(self.loop) + loop = asyncio.get_event_loop() for signame in ('SIGINT', 'SIGTERM'): - self.loop.add_signal_handler(getattr(signal, signame), self.shutdown) - self.event_queue = asyncio.Queue(loop=self.loop) + loop.add_signal_handler(getattr(signal, signame), self.shutdown) + self.event_queue = asyncio.Queue() self.options = opts or options.Options() # type: options.Options self.commands = command.CommandManager(self) @@ -57,9 +55,7 @@ class Master: @server.setter def server(self, server): - server.set_channel( - controller.Channel(self.loop, self.event_queue) - ) + server.set_channel(controller.Channel(asyncio.get_event_loop(), self.event_queue)) self._server = server @contextlib.contextmanager @@ -111,18 +107,16 @@ class Master: self.addons.trigger("running") while True: if self.should_exit.is_set(): - self.loop.stop() + asyncio.get_event_loop().stop() return self.addons.trigger("tick") - await asyncio.sleep(0.1, loop=self.loop) + await asyncio.sleep(0.1) - def run(self, inject=None): + def run(self): self.start() - asyncio.ensure_future(self.main(), loop=self.loop) - asyncio.ensure_future(self.tick(), loop=self.loop) - if inject: - asyncio.ensure_future(inject(), loop=self.loop) - self.loop.run_forever() + asyncio.ensure_future(self.main()) + asyncio.ensure_future(self.tick()) + asyncio.get_event_loop().run_forever() self.shutdown() self.addons.trigger("done") @@ -214,7 +208,7 @@ class Master: rt = http_replay.RequestReplayThread( self.options, f, - self.loop, + asyncio.get_event_loop(), self.event_queue, self.should_exit ) diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py index e27f6baf5..f7c64ed99 100644 --- a/test/mitmproxy/test_controller.py +++ b/test/mitmproxy/test_controller.py @@ -1,102 +1,31 @@ import asyncio -from threading import Thread, Event -from unittest.mock import Mock import queue import pytest -import sys from mitmproxy.exceptions import Kill, ControlException from mitmproxy import controller -from mitmproxy import master -from mitmproxy import proxy from mitmproxy.test import taddons -class TMsg: - pass +@pytest.mark.asyncio +async def test_master(): + class TMsg: + pass - -def test_master(): class tAddon: def log(self, _): ctx.master.should_exit.set() - with taddons.context() as ctx: - ctx.master.addons.add(tAddon()) + with taddons.context(tAddon()) as ctx: assert not ctx.master.should_exit.is_set() async def test(): msg = TMsg() msg.reply = controller.DummyReply() - await ctx.master.event_queue.put(("log", msg)) + await ctx.master.channel.tell("log", msg) - ctx.master.run(inject=test) - - -# class TestMaster: -# # def test_simple(self): -# # class tAddon: -# # def log(self, _): -# # ctx.master.should_exit.set() - -# # with taddons.context() as ctx: -# # ctx.master.addons.add(tAddon()) -# # assert not ctx.master.should_exit.is_set() -# # msg = TMsg() -# # msg.reply = controller.DummyReply() -# # ctx.master.event_queue.put(("log", msg)) -# # ctx.master.run() -# # assert ctx.master.should_exit.is_set() - -# # def test_server_simple(self): -# # m = master.Master(None) -# # m.server = proxy.DummyServer() -# # m.start() -# # m.shutdown() -# # m.start() -# # m.shutdown() -# pass - - -# class TestServerThread: -# def test_simple(self): -# m = Mock() -# t = master.ServerThread(m) -# t.run() -# assert m.serve_forever.called - - -# class TestChannel: -# def test_tell(self): -# q = queue.Queue() -# channel = controller.Channel(q, Event()) -# m = Mock(name="test_tell") -# channel.tell("test", m) -# assert q.get() == ("test", m) -# assert m.reply - -# def test_ask_simple(self): -# q = queue.Queue() - -# def reply(): -# m, obj = q.get() -# assert m == "test" -# obj.reply.send(42) -# obj.reply.take() -# obj.reply.commit() - -# Thread(target=reply).start() - -# channel = controller.Channel(q, Event()) -# assert channel.ask("test", Mock(name="test_ask_simple")) == 42 - -# def test_ask_shutdown(self): -# q = queue.Queue() -# done = Event() -# done.set() -# channel = controller.Channel(q, done) -# with pytest.raises(Kill): -# channel.ask("test", Mock(name="test_ask_shutdown")) + asyncio.ensure_future(test()) + assert not ctx.master.should_exit.is_set() class TestReply: