From 8d0f4009c30377bcee1d382beb4237f392d88415 Mon Sep 17 00:00:00 2001 From: Miroslav Date: Tue, 31 Jul 2018 18:21:51 +0300 Subject: [PATCH 1/5] Commands rework. Asyncio support for commands --- mitmproxy/command.py | 81 +++++++++++++++++----- mitmproxy/language/parser.py | 61 +++++++++++----- mitmproxy/language/traversal.py | 30 ++++++++ mitmproxy/optmanager.py | 4 +- mitmproxy/tools/console/commandexecutor.py | 43 +++++++----- mitmproxy/tools/console/consoleaddons.py | 71 ++++--------------- mitmproxy/tools/console/defaultkeys.py | 61 ++++++++-------- mitmproxy/tools/console/overlay.py | 14 ++-- test/mitmproxy/test_command.py | 9 +-- test/mitmproxy/test_optmanager.py | 5 +- 10 files changed, 231 insertions(+), 148 deletions(-) create mode 100644 mitmproxy/language/traversal.py diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 03ae454fa..36787f4af 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -1,6 +1,7 @@ """ This module manages and invokes typed commands. """ +import asyncio import inspect import types import typing @@ -10,7 +11,7 @@ import sys import mitmproxy.types from mitmproxy import exceptions -from mitmproxy.language import lexer, parser +from mitmproxy.language import lexer, parser, traversal def verify_arg_signature(f: typing.Callable, args: list, kwargs: dict) -> None: @@ -33,6 +34,25 @@ def typename(t: type) -> str: return to.display +class AsyncExectuionManager: + def __init__(self) -> None: + self.counter: int = 0 + self.running_commands: typing.Dict[int, asyncio.Task] = {} + + def add_command(self, command_task: asyncio.Task) -> None: + self.counter += 1 + self.running_commands[self.counter] = command_task + + def stop_command(self, cid: int) -> None: + try: + command_task = self.running_commands[cid] + except KeyError: + raise ValueError(f"There is not the command with id={cid}") + else: + command_task.cancel() + del self.running_commands[cid] + + class Command: def __init__(self, manager, path, func) -> None: self.path = path @@ -83,7 +103,10 @@ class Command: pargs = [] for arg, paramtype in zip(args, self.paramtypes): - pargs.append(parsearg(self.manager, arg, paramtype)) + if isinstance(arg, tuple) and arg[0] is mitmproxy.types.CommandTypes.get(paramtype, None): + pargs.append(arg[1]) + else: + pargs.append(parsearg(self.manager, arg, paramtype)) pargs.extend(remainder) return pargs @@ -95,6 +118,8 @@ class Command: ret = self.func(*self.prepare_args(args)) if ret is None and self.returntype is None: return + elif asyncio.iscoroutine(ret): + return ret typ = mitmproxy.types.CommandTypes.get(self.returntype) if not typ.is_valid(self.manager, typ, ret): raise exceptions.CommandError( @@ -102,7 +127,7 @@ class Command: self.path, typ.display ) ) - return ret + return typ, ret ParseResult = typing.NamedTuple( @@ -118,6 +143,7 @@ ParseResult = typing.NamedTuple( class CommandManager(mitmproxy.types._CommandBase): def __init__(self, master): self.master = master + self.async_manager = AsyncExectuionManager() self.command_parser = parser.create_parser(self) self.commands: typing.Dict[str, Command] = {} self.oneword_commands: typing.List[str] = [] @@ -199,29 +225,46 @@ class CommandManager(mitmproxy.types._CommandBase): return parse, remhelp + def get_command_by_path(self, path: str) -> Command: + """ + Returns command by its path. May raise CommandError. + """ + if path not in self.commands: + raise exceptions.CommandError(f"Unknown command: {path}") + return self.commands[path] + def call(self, path: str, *args: typing.Sequence[typing.Any]) -> typing.Any: """ Call a command with native arguments. May raise CommandError. """ - if path not in self.commands: - raise exceptions.CommandError("Unknown command: %s" % path) - return self.commands[path].func(*args) + return self.get_command_by_path(path).func(*args) def call_strings(self, path: str, args: typing.Sequence[str]) -> typing.Any: """ Call a command using a list of string arguments. May raise CommandError. """ - if path not in self.commands: - raise exceptions.CommandError("Unknown command: %s" % path) - return self.commands[path].call(args) + return self.get_command_by_path(path).call(args) - def execute(self, cmdstr: str): + def async_execute(self, cmdstr: str) -> asyncio.Task: + """ + Schedule a command to be executed. May raise CommandError. + """ + lex = lexer.create_lexer(cmdstr, self.oneword_commands) + self.command_parser.asynchoronous = True + parsed_cmd = self.command_parser.parse(lexer=lex) + + execution_coro = traversal.execute_parsed_line(parsed_cmd) + command_task = asyncio.ensure_future(execution_coro) + self.async_manager.add_command(command_task) + return command_task + + def execute(self, cmdstr: str) -> typing.Any: """ Execute a command string. May raise CommandError. """ lex = lexer.create_lexer(cmdstr, self.oneword_commands) - parser_return = self.command_parser.parse(lexer=lex) - return parser_return + parsed_cmd = self.command_parser.parse(lexer=lex) + return parsed_cmd def dump(self, out=sys.stdout) -> None: cmds = list(self.commands.values()) @@ -248,10 +291,16 @@ def parsearg(manager: CommandManager, spec: str, argtype: type) -> typing.Any: def command(path): def decorator(function): - @functools.wraps(function) - def wrapper(*args, **kwargs): - verify_arg_signature(function, args, kwargs) - return function(*args, **kwargs) + if asyncio.iscoroutinefunction(function): + @functools.wraps(function) + async def wrapper(*args, **kwargs): + verify_arg_signature(function, args, kwargs) + return await function(*args, **kwargs) + else: + @functools.wraps(function) + def wrapper(*args, **kwargs): + verify_arg_signature(function, args, kwargs) + return function(*args, **kwargs) wrapper.__dict__["command_path"] = path return wrapper return decorator diff --git a/mitmproxy/language/parser.py b/mitmproxy/language/parser.py index fd9dcd1c3..55d2ecdb1 100644 --- a/mitmproxy/language/parser.py +++ b/mitmproxy/language/parser.py @@ -1,4 +1,5 @@ import typing +import asyncio import ply.lex as lex import ply.yacc as yacc @@ -8,21 +9,34 @@ from mitmproxy import exceptions from mitmproxy.language.lexer import CommandLanguageLexer +ParsedEntity = typing.Union[str, "ParsedCommand"] + + +ParsedCommand = typing.NamedTuple( + "ParsedCommand", + [ + ("command", "mitmproxy.command.Command"), + ("args", typing.List[ParsedEntity]) + ] +) + + class CommandLanguageParser: # the list of possible tokens is always required tokens = CommandLanguageLexer.tokens def __init__(self, command_manager: "mitmproxy.command.CommandManager") -> None: - self.return_value: typing.Any = None - self._pipe_value: typing.Any = None + self.parsed_line: ParsedEntity = None + self._parsed_pipe_elem: ParsedCommand = None + self.asynchoronous: bool = False self.command_manager = command_manager # Grammar rules def p_command_line(self, p): """command_line : starting_expression pipes_chain""" - self.return_value = self._pipe_value + self.parsed_line = self._parsed_pipe_elem def p_starting_expression(self, p): """starting_expression : PLAIN_STR @@ -31,7 +45,7 @@ class CommandLanguageParser: | command_call_no_parentheses | command_call_with_parentheses""" p[0] = p[1] - self._pipe_value = p[0] + self._parsed_pipe_elem = p[0] def p_pipes_chain(self, p): """pipes_chain : empty @@ -43,19 +57,19 @@ class CommandLanguageParser: """pipe_expression : PIPE COMMAND argument_list pipe_expression : PIPE COMMAND LPAREN argument_list RPAREN""" if len(p) == 4: - new_args = [self._pipe_value, *p[3]] + new_args = [self._parsed_pipe_elem, *p[3]] else: - new_args = [self._pipe_value, *p[4]] - p[0] = self.command_manager.call_strings(p[2], new_args) - self._pipe_value = p[0] + new_args = [self._parsed_pipe_elem, *p[4]] + p[0] = self._call_command(p[2], new_args) + self._parsed_pipe_elem = p[0] - def p_call_command_no_parentheses(self, p): + def p_command_call_no_parentheses(self, p): """command_call_no_parentheses : COMMAND argument_list""" - p[0] = self.command_manager.call_strings(p[1], p[2]) + p[0] = self._call_command(p[1], p[2]) - def p_call_command_with_parentheses(self, p): + def p_command_call_with_parentheses(self, p): """command_call_with_parentheses : COMMAND LPAREN argument_list RPAREN""" - p[0] = self.command_manager.call_strings(p[1], p[3]) + p[0] = self._call_command(p[1], p[3]) def p_argument_list(self, p): """argument_list : empty @@ -73,7 +87,7 @@ class CommandLanguageParser: def p_array(self, p): """array : LBRACE argument_list RBRACE""" - p[0] = ",".join(p[2]) if p[2] else "" + p[0] = p[2] def p_quoted_str(self, p): """quoted_str : QUOTED_STR""" @@ -88,8 +102,22 @@ class CommandLanguageParser: else: raise exceptions.CommandError(f"Syntax error at '{p.value}'") + # Supporting methods + + def _call_command(self, command: str, + args: typing.List[ParsedEntity]) -> ParsedCommand: + if self.asynchoronous: + c = self.command_manager.get_command_by_path(command) + ret = ParsedCommand(c, args) + else: + ret = self.command_manager.call_strings(command, args) + if asyncio.iscoroutine(ret): + raise ValueError(f"You are trying to run async " + f"command {command} through sync executor.") + return ret + @staticmethod - def _create_list(p: yacc.YaccProduction) -> typing.List[typing.Any]: + def _create_list(p: yacc.YaccProduction) -> typing.List[ParsedEntity]: if len(p) == 2: p[0] = [] if p[1] is None else [p[1]] else: @@ -103,8 +131,9 @@ class CommandLanguageParser: def parse(self, lexer: lex.Lexer, **kwargs) -> typing.Any: self.parser.parse(lexer=lexer, **kwargs) - self._pipe_value = None - return self.return_value + self._parsed_pipe_elem = None + self.parser.asynchoronous = False + return self.parsed_line def create_parser( diff --git a/mitmproxy/language/traversal.py b/mitmproxy/language/traversal.py new file mode 100644 index 000000000..63c316ee9 --- /dev/null +++ b/mitmproxy/language/traversal.py @@ -0,0 +1,30 @@ +import typing +import asyncio + +import mitmproxy.command # noqa +from mitmproxy.language.parser import ParsedCommand, ParsedEntity + + +async def execute_parsed_line(line: ParsedEntity): + if isinstance(line, ParsedCommand): + return await traverse_entity(line.command, line.args) + else: + return line + + +async def traverse_entity(command: typing.Optional["mitmproxy.command.Command"], + args: typing.List[ParsedEntity]): + for i, arg in enumerate(args): + if isinstance(arg, ParsedCommand): + args[i] = await traverse_entity(arg.command, arg.args) + elif isinstance(arg, list): + args[i] = await traverse_entity(args=arg, command=command) + + if command is not None: + ret = command.call(args) + if asyncio.iscoroutine(ret): + return await ret + else: + return ret + else: + return args diff --git a/mitmproxy/optmanager.py b/mitmproxy/optmanager.py index 06e696c01..9d31eccd5 100644 --- a/mitmproxy/optmanager.py +++ b/mitmproxy/optmanager.py @@ -234,8 +234,8 @@ class OptManager: if attr not in self._options: raise KeyError("No such option: %s" % attr) - def setter(x): - setattr(self, attr, x) + def setter(future_x): + setattr(self, attr, future_x.result()) return setter def toggler(self, attr): diff --git a/mitmproxy/tools/console/commandexecutor.py b/mitmproxy/tools/console/commandexecutor.py index 3db03d3e6..bb728be4c 100644 --- a/mitmproxy/tools/console/commandexecutor.py +++ b/mitmproxy/tools/console/commandexecutor.py @@ -17,21 +17,30 @@ class CommandExecutor: ret = self.master.commands.execute(cmd) except exceptions.CommandError as v: signals.status_message.send(message=str(v)) + except ValueError: + # Asynchronous launch + command_task = self.master.commands.async_execute(cmd) + command_task.add_done_callback(self.check_return) else: - if ret: - if type(ret) == typing.Sequence[flow.Flow]: - signals.status_message.send( - message="Command returned %s flows" % len(ret) - ) - elif type(ret) == flow.Flow: - signals.status_message.send( - message="Command returned 1 flow" - ) - else: - self.master.overlay( - overlay.DataViewerOverlay( - self.master, - ret, - ), - valign="top" - ) \ No newline at end of file + self.check_return(ret=ret) + + def check_return(self, task=None, ret=None): + if task is not None: + ret = task.result() + if ret: + if type(ret) == typing.Sequence[flow.Flow]: + signals.status_message.send( + message=f"Command returned {len(ret)} flows" + ) + elif type(ret) == flow.Flow: + signals.status_message.send( + message="Command returned 1 flow" + ) + else: + self.master.overlay( + overlay.DataViewerOverlay( + self.master, + ret, + ), + valign="top" + ) \ No newline at end of file diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 54fe11c49..1ced557c1 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -1,5 +1,4 @@ import csv -import shlex import typing from mitmproxy import ctx @@ -147,7 +146,7 @@ class ConsoleAddon: fv = self.master.window.current("options") if not fv: raise exceptions.CommandError("Not viewing options.") - self.master.commands.execute("options.reset.one %s" % fv.current_name()) + self.master.commands.execute(f"options.reset.one {fv.current_name()}") @command.command("console.nav.start") def nav_start(self) -> None: @@ -220,65 +219,25 @@ class ConsoleAddon: self.master.inject_key("right") @command.command("console.choose") - def console_choose( + async def console_choose( self, prompt: str, - choices: typing.Sequence[str], - cmd: mitmproxy.types.Cmd, - *args: mitmproxy.types.Arg - ) -> None: + choices: typing.Sequence[str] + ) -> str: """ - Prompt the user to choose from a specified list of strings, then - invoke another command with all occurrences of {choice} replaced by - the choice the user made. + Prompt the user to choose from a list of strings. + Wait until user makes choice. Returns it. """ - def callback(opt): - # We're now outside of the call context... - repl = cmd + " " + " ".join(args) - repl = repl.replace("{choice}", opt) - try: - self.master.commands.execute(repl) - except exceptions.CommandError as e: - signals.status_message.send(message=str(e)) - - self.master.overlay( - overlay.Chooser(self.master, prompt, choices, "", callback) - ) - - @command.command("console.choose.cmd") - def console_choose_cmd( - self, - prompt: str, - choicecmd: mitmproxy.types.Cmd, - subcmd: mitmproxy.types.Cmd, - *args: mitmproxy.types.Arg - ) -> None: - """ - Prompt the user to choose from a list of strings returned by a - command, then invoke another command with all occurrences of {choice} - replaced by the choice the user made. - """ - choices = ctx.master.commands.call_strings(choicecmd, []) - - def callback(opt): - # We're now outside of the call context... - repl = shlex.quote(" ".join(args)) - repl = repl.replace("{choice}", opt) - try: - self.master.commands.execute(subcmd + " " + repl) - except exceptions.CommandError as e: - signals.status_message.send(message=str(e)) - - self.master.overlay( - overlay.Chooser(self.master, prompt, choices, "", callback) - ) + chooser = overlay.Chooser(self.master, prompt, choices, "") + self.master.overlay(chooser) + return await chooser.get_choice() @command.command("console.command") - def console_command(self, *partial: str) -> None: + def console_command(self, command_parts: typing.Sequence[str]) -> None: """ Prompt the user to edit a command with a (possibly empty) starting value. """ - signals.status_prompt_command.send(partial=" ".join(partial)) # type: ignore + signals.status_prompt_command.send(partial=" ".join(command_parts)) # type: ignore @command.command("console.command.set") def console_command_set(self, option: str) -> None: @@ -288,7 +247,7 @@ class ConsoleAddon: option_value = getattr(self.master.options, option, None) current_value = option_value if option_value else "" self.master.commands.execute( - "console.command set %s=%s" % (option, current_value) + f"console.command [set {option}={current_value}]" ) @command.command("console.view.keybindings") @@ -430,7 +389,7 @@ class ConsoleAddon: self.master.switch_view("edit_focus_setcookies") elif part in ["url", "method", "status_code", "reason"]: self.master.commands.execute( - "console.command flow.set @focus %s " % part + f"console.command [flow.set @focus {part}]" ) def _grideditor(self): @@ -516,7 +475,7 @@ class ConsoleAddon: try: self.master.commands.call_strings( "view.setval", - ["@focus", "flowview_mode_%s" % idx, mode] + ["@focus", f"flowview_mode_{idx}", mode] ) except exceptions.CommandError as e: signals.status_message.send(message=str(e)) @@ -541,7 +500,7 @@ class ConsoleAddon: "view.getval", [ "@focus", - "flowview_mode_%s" % idx, + f"flowview_mode_{idx}", self.master.options.console_default_contentview, ] ) diff --git a/mitmproxy/tools/console/defaultkeys.py b/mitmproxy/tools/console/defaultkeys.py index 7f65c1f7d..ce87572a9 100644 --- a/mitmproxy/tools/console/defaultkeys.py +++ b/mitmproxy/tools/console/defaultkeys.py @@ -1,6 +1,6 @@ def map(km): - km.add(":", "console.command ", ["global"], "Command prompt") + km.add(":", "console.command []", ["global"], "Command prompt") km.add("?", "console.view.help", ["global"], "View help") km.add("B", "browser.start", ["global"], "Start an attached browser") km.add("C", "console.view.commands", ["global"], "View commands") @@ -32,7 +32,7 @@ def map(km): km.add("A", "flow.resume @all", ["flowlist", "flowview"], "Resume all intercepted flows") km.add("a", "flow.resume @focus", ["flowlist", "flowview"], "Resume this intercepted flow") km.add( - "b", "console.command cut.save @focus response.content ", + "b", "console.command [cut.save @focus response.content]", ["flowlist", "flowview"], "Save response body to file" ) @@ -41,8 +41,10 @@ def map(km): km.add( "e", """ - console.choose.cmd Format export.formats - console.command export.file {choice} @focus + console.command [ + export.file + console.choose(Format export.formats()) @focus + ] """, ["flowlist", "flowview"], "Export this flow to file" @@ -51,40 +53,37 @@ def map(km): km.add("F", "set console_focus_follow=toggle", ["flowlist"], "Set focus follow") km.add( "ctrl l", - "console.command cut.clip ", + "console.command [cut.clip]", ["flowlist", "flowview"], "Send cuts to clipboard" ) - km.add("L", "console.command view.load ", ["flowlist"], "Load flows from file") + km.add("L", "console.command [view.load]", ["flowlist"], "Load flows from file") km.add("m", "flow.mark.toggle @focus", ["flowlist"], "Toggle mark on this flow") km.add("M", "view.marked.toggle", ["flowlist"], "Toggle viewing marked flows") km.add( "n", - "console.command view.create get https://example.com/", + "console.command [view.create get https://example.com/]", ["flowlist"], "Create a new flow" ) km.add( "o", - """ - console.choose.cmd Order view.order.options - set view_order={choice} - """, + "set view_order=console.choose(Order view.order.options())", ["flowlist"], "Set flow list order" ) km.add("r", "replay.client @focus", ["flowlist", "flowview"], "Replay this flow") - km.add("S", "console.command replay.server ", ["flowlist"], "Start server replay") + km.add("S", "console.command [replay.server]", ["flowlist"], "Start server replay") km.add("v", "set view_order_reversed=toggle", ["flowlist"], "Reverse flow list order") km.add("U", "flow.mark @all false", ["flowlist"], "Un-set all marks") - km.add("w", "console.command save.file @shown ", ["flowlist"], "Save listed flows to file") + km.add("w", "console.command [save.file @shown]", ["flowlist"], "Save listed flows to file") km.add("V", "flow.revert @focus", ["flowlist", "flowview"], "Revert changes to this flow") km.add("X", "flow.kill @focus", ["flowlist"], "Kill this flow") km.add("z", "view.remove @all", ["flowlist"], "Clear flow list") km.add("Z", "view.remove @hidden", ["flowlist"], "Purge all flows not showing") km.add( "|", - "console.command script.run @focus ", + "console.command [script.run @focus]", ["flowlist", "flowview"], "Run a script on this flow" ) @@ -92,8 +91,8 @@ def map(km): km.add( "e", """ - console.choose.cmd Part console.edit.focus.options - console.edit.focus {choice} + console.edit.focus + console.choose(Part console.edit.focus.options()) """, ["flowview"], "Edit a flow component" @@ -104,14 +103,14 @@ def map(km): ["flowview"], "Toggle viewing full contents on this flow", ) - km.add("w", "console.command save.file @focus ", ["flowview"], "Save flow to file") + km.add("w", "console.command [save.file @focus]", ["flowview"], "Save flow to file") km.add("space", "view.focus.next", ["flowview"], "Go to next flow") km.add( "v", """ - console.choose "View Part" request,response - console.bodyview @focus {choice} + @focus | + console.bodyview console.choose("View Part" [request response]) """, ["flowview"], "View flow body in an external viewer" @@ -120,8 +119,8 @@ def map(km): km.add( "m", """ - console.choose.cmd Mode console.flowview.mode.options - console.flowview.mode.set {choice} + console.flowview.mode.set + console.choose(Mode console.flowview.mode.options()) """, ["flowview"], "Set flow view mode" @@ -129,15 +128,15 @@ def map(km): km.add( "z", """ - console.choose "Part" request,response - flow.encode.toggle @focus {choice} + @focus | + flow.encode.toggle console.choose(Part [request response]) """, ["flowview"], "Encode/decode flow body" ) - km.add("L", "console.command options.load ", ["options"], "Load from file") - km.add("S", "console.command options.save ", ["options"], "Save to file") + km.add("L", "console.command [options.load]", ["options"], "Load from file") + km.add("S", "console.command [options.save]", ["options"], "Save to file") km.add("D", "options.reset", ["options"], "Reset all options") km.add("d", "console.options.reset.focus", ["options"], "Reset this option") @@ -146,20 +145,20 @@ def map(km): km.add("d", "console.grideditor.delete", ["grideditor"], "Delete this row") km.add( "r", - "console.command console.grideditor.load", + "console.command [console.grideditor.load]", ["grideditor"], "Read unescaped data into the current cell from file" ) km.add( "R", - "console.command console.grideditor.load_escaped", + "console.command [console.grideditor.load_escaped]", ["grideditor"], "Load a Python-style escaped string into the current cell from file" ) km.add("e", "console.grideditor.editor", ["grideditor"], "Edit in external editor") km.add( "w", - "console.command console.grideditor.save ", + "console.command [console.grideditor.save]", ["grideditor"], "Save data to file as CSV" ) @@ -169,8 +168,10 @@ def map(km): km.add( "a", """ - console.choose.cmd "Context" console.key.contexts - console.command console.key.bind {choice} + console.command [ + console.key.bind + console.choose(Context console.key.contexts()) + ] """, ["keybindings"], "Add a key binding" diff --git a/mitmproxy/tools/console/overlay.py b/mitmproxy/tools/console/overlay.py index 8b195703c..46a008917 100644 --- a/mitmproxy/tools/console/overlay.py +++ b/mitmproxy/tools/console/overlay.py @@ -1,4 +1,5 @@ import math +import asyncio import urwid @@ -104,10 +105,12 @@ class ChooserListWalker(urwid.ListWalker): class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): keyctx = "chooser" - def __init__(self, master, title, choices, current, callback): + def __init__(self, master, title, choices, current, callback=None): self.master = master self.choices = choices - self.callback = callback + self._future_choice = asyncio.get_event_loop().create_future() + if callback: + self._future_choice.add_done_callback(callback) choicewidth = max([len(i) for i in choices]) self.width = max(choicewidth, len(title)) + 7 @@ -125,6 +128,9 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): ) ) + async def get_choice(self): + return await self._future_choice + def selectable(self): return True @@ -132,11 +138,11 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): key = self.master.keymap.handle_only("chooser", key) choice = self.walker.choice_by_shortcut(key) if choice: - self.callback(choice) + self._future_choice.set_result(choice) signals.pop_view_state.send(self) return if key == "m_select": - self.callback(self.choices[self.walker.index]) + self._future_choice.set_result(self.choices[self.walker.index]) signals.pop_view_state.send(self) return elif key in ["q", "esc"]: diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index 802ebd1f9..b26ddac96 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -33,8 +33,8 @@ class TAddon: return choices @command.command("cmd6") - def cmd6(self, one: str, two: str) -> str: - return f"{one} {two}" + def cmd6(self, pipe_value: str) -> str: + return pipe_value @command.command("subcommand") def subcommand(self, cmd: mitmproxy.types.Cmd, *args: mitmproxy.types.Arg) -> str: @@ -306,10 +306,7 @@ def test_simple(): assert(c.execute("one.two foo") == "ret foo") assert (c.execute("one.two(foo)") == "ret foo") assert (c.execute("array.command [1 2 3]") == ["1", "2", "3"]) - assert (c.execute("foo | one.two | one.two") == "ret ret foo") - assert (c.execute("one | pipe.command(two) |" - " pipe.command(three)") == "one two three") - + assert (c.execute("foo | pipe.command") == "foo") assert(c.execute("one.two \"foo\"") == "ret foo") assert(c.execute("one.two 'foo'") == "ret foo") assert(c.call("one.two", "foo") == "ret foo") diff --git a/test/mitmproxy/test_optmanager.py b/test/mitmproxy/test_optmanager.py index 1e4f09d47..6b0c9f0a7 100644 --- a/test/mitmproxy/test_optmanager.py +++ b/test/mitmproxy/test_optmanager.py @@ -1,6 +1,7 @@ import copy import pytest import typing +import asyncio import argparse from mitmproxy import options @@ -114,7 +115,9 @@ def test_options(): def test_setter(): o = TO() f = o.setter("two") - f(99) + opt_future = asyncio.Future() + opt_future.set_result(99) + f(opt_future) assert o.two == 99 with pytest.raises(Exception, match="No such option"): o.setter("nonexistent") From 1c3c39cadfd05c55c2d7510731aa19be67d7081c Mon Sep 17 00:00:00 2001 From: Miroslav Date: Sat, 4 Aug 2018 02:18:10 +0300 Subject: [PATCH 2/5] Asyncio bugs fixed. --- mitmproxy/command.py | 79 +++++++++++++++++----- mitmproxy/exceptions.py | 4 ++ mitmproxy/language/lexer.py | 9 ++- mitmproxy/language/parser.py | 33 +++++---- mitmproxy/language/traversal.py | 10 ++- mitmproxy/tools/console/commandexecutor.py | 14 ++-- mitmproxy/tools/console/commands.py | 1 + mitmproxy/tools/console/consoleaddons.py | 18 +++-- mitmproxy/tools/console/keymap.py | 7 ++ mitmproxy/tools/console/overlay.py | 3 +- 10 files changed, 125 insertions(+), 53 deletions(-) diff --git a/mitmproxy/command.py b/mitmproxy/command.py index 36787f4af..ab6063f99 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -34,23 +34,43 @@ def typename(t: type) -> str: return to.display +RunningCommand = typing.NamedTuple( + "RunningCommand", + [ + ("cmdstr", str), + ("task", asyncio.Task) + ], +) + + class AsyncExectuionManager: def __init__(self) -> None: self.counter: int = 0 - self.running_commands: typing.Dict[int, asyncio.Task] = {} + self.running_cmds: typing.Dict[int, RunningCommand] = {} - def add_command(self, command_task: asyncio.Task) -> None: + def add_command(self, cmd: RunningCommand) -> None: self.counter += 1 - self.running_commands[self.counter] = command_task + cmd.task.add_done_callback(functools.partial(self._delete_callback, + cid=self.counter)) + self.running_cmds[self.counter] = cmd def stop_command(self, cid: int) -> None: try: - command_task = self.running_commands[cid] + cmd = self.running_cmds[cid] except KeyError: raise ValueError(f"There is not the command with id={cid}") else: - command_task.cancel() - del self.running_commands[cid] + cmd.task.cancel() + del self.running_cmds[cid] + + def get_running(self) -> typing.List[typing.Tuple[int, str]]: + running = [] + for cid in sorted(self.running_cmds): + running.append((cid, self.running_cmds[cid].cmdstr)) + return running + + def _delete_callback(self, task: asyncio.Task, cid: int) -> None: + del self.running_cmds[cid] class Command: @@ -58,6 +78,7 @@ class Command: self.path = path self.manager = manager self.func = func + self.asyncf = True if asyncio.iscoroutinefunction(func) else False sig = inspect.signature(self.func) self.help = None if func.__doc__: @@ -102,24 +123,29 @@ class Command: args = args[:len(self.paramtypes) - 1] pargs = [] + for arg, paramtype in zip(args, self.paramtypes): - if isinstance(arg, tuple) and arg[0] is mitmproxy.types.CommandTypes.get(paramtype, None): - pargs.append(arg[1]) + if not isinstance(arg, str): + t = mitmproxy.types.CommandTypes.get(paramtype, None) + if t.is_valid(self.manager, t, arg): + pargs.append(arg) + else: + raise exceptions.CommandError( + f"{arg} is unexpected data for {paramtype.display} type" + ) else: pargs.append(parsearg(self.manager, arg, paramtype)) pargs.extend(remainder) return pargs - def call(self, args: typing.Sequence[str]) -> typing.Any: + def call(self, args: typing.Sequence[typing.Any]) -> typing.Any: """ - Call the command with a list of arguments. At this point, all - arguments are strings. + Call the command with a list of arguments. """ ret = self.func(*self.prepare_args(args)) + if ret is None and self.returntype is None: return - elif asyncio.iscoroutine(ret): - return ret typ = mitmproxy.types.CommandTypes.get(self.returntype) if not typ.is_valid(self.manager, typ, ret): raise exceptions.CommandError( @@ -127,7 +153,24 @@ class Command: self.path, typ.display ) ) - return typ, ret + return ret + + async def async_call(self, args: typing.Sequence[typing.Any]) -> typing.Any: + """ + Call the command with a list of arguments asynchronously. + """ + ret = await self.func(*self.prepare_args(args)) + + if ret is None and self.returntype is None: + return + typ = mitmproxy.types.CommandTypes.get(self.returntype) + if not typ.is_valid(self.manager, typ, ret): + raise exceptions.CommandError( + "%s returned unexpected data - expected %s" % ( + self.path, typ.display + ) + ) + return ret ParseResult = typing.NamedTuple( @@ -178,7 +221,7 @@ class CommandManager(mitmproxy.types._CommandBase): """ Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items. """ - parts: typing.List[str] = lexer.get_tokens(cmdstr) + parts: typing.List[str] = [t.value for t in lexer.get_tokens(cmdstr)] if not parts: parts = [""] elif parts[-1].isspace(): @@ -250,12 +293,11 @@ class CommandManager(mitmproxy.types._CommandBase): Schedule a command to be executed. May raise CommandError. """ lex = lexer.create_lexer(cmdstr, self.oneword_commands) - self.command_parser.asynchoronous = True - parsed_cmd = self.command_parser.parse(lexer=lex) + parsed_cmd = self.command_parser.parse(lexer=lex, async_exec=True) execution_coro = traversal.execute_parsed_line(parsed_cmd) command_task = asyncio.ensure_future(execution_coro) - self.async_manager.add_command(command_task) + self.async_manager.add_command(RunningCommand(cmdstr, command_task)) return command_task def execute(self, cmdstr: str) -> typing.Any: @@ -263,6 +305,7 @@ class CommandManager(mitmproxy.types._CommandBase): Execute a command string. May raise CommandError. """ lex = lexer.create_lexer(cmdstr, self.oneword_commands) + self.command_parser.async_exec = False parsed_cmd = self.command_parser.parse(lexer=lex) return parsed_cmd diff --git a/mitmproxy/exceptions.py b/mitmproxy/exceptions.py index d568898be..c182dc299 100644 --- a/mitmproxy/exceptions.py +++ b/mitmproxy/exceptions.py @@ -97,6 +97,10 @@ class CommandError(Exception): pass +class ExecutionError(CommandError): + pass + + class OptionsError(MitmproxyException): pass diff --git a/mitmproxy/language/lexer.py b/mitmproxy/language/lexer.py index 9318ba753..96bca1ea1 100644 --- a/mitmproxy/language/lexer.py +++ b/mitmproxy/language/lexer.py @@ -62,8 +62,11 @@ def create_lexer(cmdstr: str, oneword_commands: typing.Sequence[str]) -> lex.Lex return command_lexer.lexer -def get_tokens(cmdstr: str, state="interactive") -> typing.List[str]: - lexer = create_lexer(cmdstr, []) +def get_tokens(cmdstr: str, state="interactive", + oneword_commands=None) -> typing.List[lex.LexToken]: + if oneword_commands is None: + oneword_commands = [] + lexer = create_lexer(cmdstr, oneword_commands) # Switching to the other state lexer.begin(state) - return [token.value for token in lexer] + return list(lexer) diff --git a/mitmproxy/language/parser.py b/mitmproxy/language/parser.py index 55d2ecdb1..f85e01ba5 100644 --- a/mitmproxy/language/parser.py +++ b/mitmproxy/language/parser.py @@ -1,5 +1,4 @@ import typing -import asyncio import ply.lex as lex import ply.yacc as yacc @@ -9,7 +8,7 @@ from mitmproxy import exceptions from mitmproxy.language.lexer import CommandLanguageLexer -ParsedEntity = typing.Union[str, "ParsedCommand"] +ParsedEntity = typing.Union[str, list, "ParsedCommand"] ParsedCommand = typing.NamedTuple( @@ -29,7 +28,7 @@ class CommandLanguageParser: command_manager: "mitmproxy.command.CommandManager") -> None: self.parsed_line: ParsedEntity = None self._parsed_pipe_elem: ParsedCommand = None - self.asynchoronous: bool = False + self.async_exec: bool = False self.command_manager = command_manager # Grammar rules @@ -97,6 +96,7 @@ class CommandLanguageParser: """empty :""" def p_error(self, p): + self._reset_internals() if p is None: raise exceptions.CommandError("Syntax error at EOF") else: @@ -106,16 +106,22 @@ class CommandLanguageParser: def _call_command(self, command: str, args: typing.List[ParsedEntity]) -> ParsedCommand: - if self.asynchoronous: - c = self.command_manager.get_command_by_path(command) - ret = ParsedCommand(c, args) + cmd = self.command_manager.get_command_by_path(command) + if self.async_exec: + ret = ParsedCommand(cmd, args) else: - ret = self.command_manager.call_strings(command, args) - if asyncio.iscoroutine(ret): - raise ValueError(f"You are trying to run async " - f"command {command} through sync executor.") + if cmd.asyncf: + self._reset_internals() + raise exceptions.ExecutionError(f"You are trying to run async " + f"command '{command}' through sync executor.") + else: + ret = cmd.call(args) return ret + def _reset_internals(self): + self._parsed_pipe_elem = None + self.async_exec = False + @staticmethod def _create_list(p: yacc.YaccProduction) -> typing.List[ParsedEntity]: if len(p) == 2: @@ -129,10 +135,11 @@ class CommandLanguageParser: self.parser = yacc.yacc(module=self, errorlog=yacc.NullLogger(), **kwargs) - def parse(self, lexer: lex.Lexer, **kwargs) -> typing.Any: + def parse(self, lexer: lex.Lexer, + async_exec=False, **kwargs) -> typing.Any: + self.async_exec = async_exec self.parser.parse(lexer=lexer, **kwargs) - self._parsed_pipe_elem = None - self.parser.asynchoronous = False + self._reset_internals() return self.parsed_line diff --git a/mitmproxy/language/traversal.py b/mitmproxy/language/traversal.py index 63c316ee9..cc50bfa3d 100644 --- a/mitmproxy/language/traversal.py +++ b/mitmproxy/language/traversal.py @@ -1,5 +1,4 @@ import typing -import asyncio import mitmproxy.command # noqa from mitmproxy.language.parser import ParsedCommand, ParsedEntity @@ -18,13 +17,12 @@ async def traverse_entity(command: typing.Optional["mitmproxy.command.Command"], if isinstance(arg, ParsedCommand): args[i] = await traverse_entity(arg.command, arg.args) elif isinstance(arg, list): - args[i] = await traverse_entity(args=arg, command=command) + args[i] = await traverse_entity(command=None, args=arg) if command is not None: - ret = command.call(args) - if asyncio.iscoroutine(ret): - return await ret + if command.asyncf: + return await command.async_call(args) else: - return ret + return command.call(args) else: return args diff --git a/mitmproxy/tools/console/commandexecutor.py b/mitmproxy/tools/console/commandexecutor.py index bb728be4c..d11f672ae 100644 --- a/mitmproxy/tools/console/commandexecutor.py +++ b/mitmproxy/tools/console/commandexecutor.py @@ -1,4 +1,5 @@ import typing +import asyncio from mitmproxy import exceptions from mitmproxy import flow @@ -15,18 +16,21 @@ class CommandExecutor: if cmd.strip(): try: ret = self.master.commands.execute(cmd) - except exceptions.CommandError as v: - signals.status_message.send(message=str(v)) - except ValueError: + except exceptions.ExecutionError: # Asynchronous launch command_task = self.master.commands.async_execute(cmd) command_task.add_done_callback(self.check_return) + except exceptions.CommandError as v: + signals.status_message.send(message=str(v)) else: self.check_return(ret=ret) def check_return(self, task=None, ret=None): if task is not None: - ret = task.result() + try: + ret = task.result() + except asyncio.CancelledError: + return if ret: if type(ret) == typing.Sequence[flow.Flow]: signals.status_message.send( @@ -43,4 +47,4 @@ class CommandExecutor: ret, ), valign="top" - ) \ No newline at end of file + ) diff --git a/mitmproxy/tools/console/commands.py b/mitmproxy/tools/console/commands.py index 0f35742b3..9d8215b3c 100644 --- a/mitmproxy/tools/console/commands.py +++ b/mitmproxy/tools/console/commands.py @@ -18,6 +18,7 @@ class CommandItem(urwid.WidgetWrap): def get_widget(self): parts = [ ("focus", ">> " if self.focused else " "), + ("text", "async " if self.cmd.asyncf else " "), ("title", self.cmd.path), ("text", " "), ("text", " ".join(self.cmd.paramnames())), diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 1ced557c1..54682bc88 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -574,15 +574,19 @@ class ConsoleAddon: @command.command("console.key.edit.focus") def key_edit_focus(self) -> None: """ - Execute the currently focused key binding. + Edit the currently focused key binding. """ b = self._keyfocus() - self.console_command( - "console.key.bind", - ",".join(b.contexts), - b.key, - b.command, - ) + print(f"""console.command [ console.key.bind + [{" ".join(b.contexts)}] + {b.key} + "{b.command}" ]""") + # self.console_command( + # "console.key.bind", + # ",".join(b.contexts), + # b.key, + # b.command, + # ) def running(self): self.started = True diff --git a/mitmproxy/tools/console/keymap.py b/mitmproxy/tools/console/keymap.py index d22420bfd..2fac8601b 100644 --- a/mitmproxy/tools/console/keymap.py +++ b/mitmproxy/tools/console/keymap.py @@ -4,6 +4,7 @@ import os import ruamel.yaml from mitmproxy import command +from mitmproxy.language import lexer from mitmproxy.tools.console import commandexecutor from mitmproxy.tools.console import signals from mitmproxy import ctx @@ -55,6 +56,7 @@ class Binding: class Keymap: def __init__(self, master): + self.oneword_commands = master.commands.oneword_commands self.executor = commandexecutor.CommandExecutor(master) self.keys = {} for c in Contexts: @@ -153,6 +155,11 @@ class Keymap: return self.executor(b.command) return key + def _get_braced_command(self, command): + tokens = lexer.get_tokens(command, self.oneword_commands) + + + keyAttrs = { "key": lambda x: isinstance(x, str), diff --git a/mitmproxy/tools/console/overlay.py b/mitmproxy/tools/console/overlay.py index 46a008917..86da20642 100644 --- a/mitmproxy/tools/console/overlay.py +++ b/mitmproxy/tools/console/overlay.py @@ -108,7 +108,7 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): def __init__(self, master, title, choices, current, callback=None): self.master = master self.choices = choices - self._future_choice = asyncio.get_event_loop().create_future() + self._future_choice = asyncio.Future() if callback: self._future_choice.add_done_callback(callback) choicewidth = max([len(i) for i in choices]) @@ -146,6 +146,7 @@ class Chooser(urwid.WidgetWrap, layoutwidget.LayoutWidget): signals.pop_view_state.send(self) return elif key in ["q", "esc"]: + self._future_choice.cancel() signals.pop_view_state.send(self) return From c0ed80e02b7d624bda922914e1aaa61377a89b24 Mon Sep 17 00:00:00 2001 From: Miroslav Date: Sat, 4 Aug 2018 11:56:53 +0300 Subject: [PATCH 3/5] Excess changes were canceled --- mitmproxy/command.py | 3 +-- mitmproxy/language/lexer.py | 7 ++----- mitmproxy/language/parser.py | 2 +- mitmproxy/tools/console/consoleaddons.py | 16 ++++++---------- mitmproxy/tools/console/keymap.py | 5 ----- 5 files changed, 10 insertions(+), 23 deletions(-) diff --git a/mitmproxy/command.py b/mitmproxy/command.py index ab6063f99..df32f9fea 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -221,7 +221,7 @@ class CommandManager(mitmproxy.types._CommandBase): """ Parse a possibly partial command. Return a sequence of ParseResults and a sequence of remainder type help items. """ - parts: typing.List[str] = [t.value for t in lexer.get_tokens(cmdstr)] + parts: typing.List[str] = lexer.get_tokens(cmdstr) if not parts: parts = [""] elif parts[-1].isspace(): @@ -305,7 +305,6 @@ class CommandManager(mitmproxy.types._CommandBase): Execute a command string. May raise CommandError. """ lex = lexer.create_lexer(cmdstr, self.oneword_commands) - self.command_parser.async_exec = False parsed_cmd = self.command_parser.parse(lexer=lex) return parsed_cmd diff --git a/mitmproxy/language/lexer.py b/mitmproxy/language/lexer.py index 96bca1ea1..c204b82f3 100644 --- a/mitmproxy/language/lexer.py +++ b/mitmproxy/language/lexer.py @@ -62,11 +62,8 @@ def create_lexer(cmdstr: str, oneword_commands: typing.Sequence[str]) -> lex.Lex return command_lexer.lexer -def get_tokens(cmdstr: str, state="interactive", - oneword_commands=None) -> typing.List[lex.LexToken]: - if oneword_commands is None: - oneword_commands = [] - lexer = create_lexer(cmdstr, oneword_commands) +def get_tokens(cmdstr: str, state="interactive") -> typing.List[lex.LexToken]: + lexer = create_lexer(cmdstr, []) # Switching to the other state lexer.begin(state) return list(lexer) diff --git a/mitmproxy/language/parser.py b/mitmproxy/language/parser.py index f85e01ba5..c3923ccfa 100644 --- a/mitmproxy/language/parser.py +++ b/mitmproxy/language/parser.py @@ -136,7 +136,7 @@ class CommandLanguageParser: errorlog=yacc.NullLogger(), **kwargs) def parse(self, lexer: lex.Lexer, - async_exec=False, **kwargs) -> typing.Any: + async_exec: bool=False, **kwargs) -> typing.Any: self.async_exec = async_exec self.parser.parse(lexer=lexer, **kwargs) self._reset_internals() diff --git a/mitmproxy/tools/console/consoleaddons.py b/mitmproxy/tools/console/consoleaddons.py index 54682bc88..be170f567 100644 --- a/mitmproxy/tools/console/consoleaddons.py +++ b/mitmproxy/tools/console/consoleaddons.py @@ -577,16 +577,12 @@ class ConsoleAddon: Edit the currently focused key binding. """ b = self._keyfocus() - print(f"""console.command [ console.key.bind - [{" ".join(b.contexts)}] - {b.key} - "{b.command}" ]""") - # self.console_command( - # "console.key.bind", - # ",".join(b.contexts), - # b.key, - # b.command, - # ) + self.console_command( + "console.key.bind", + ",".join(b.contexts), + b.key, + b.command, + ) def running(self): self.started = True diff --git a/mitmproxy/tools/console/keymap.py b/mitmproxy/tools/console/keymap.py index 2fac8601b..d85d7d68b 100644 --- a/mitmproxy/tools/console/keymap.py +++ b/mitmproxy/tools/console/keymap.py @@ -155,11 +155,6 @@ class Keymap: return self.executor(b.command) return key - def _get_braced_command(self, command): - tokens = lexer.get_tokens(command, self.oneword_commands) - - - keyAttrs = { "key": lambda x: isinstance(x, str), From c78f10f8596114c37a663c3c46e42e7d1416b721 Mon Sep 17 00:00:00 2001 From: Miroslav Date: Sat, 4 Aug 2018 12:50:49 +0300 Subject: [PATCH 4/5] get_tokens changes go back --- mitmproxy/language/lexer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mitmproxy/language/lexer.py b/mitmproxy/language/lexer.py index c204b82f3..2e44d7df2 100644 --- a/mitmproxy/language/lexer.py +++ b/mitmproxy/language/lexer.py @@ -66,4 +66,4 @@ def get_tokens(cmdstr: str, state="interactive") -> typing.List[lex.LexToken]: lexer = create_lexer(cmdstr, []) # Switching to the other state lexer.begin(state) - return list(lexer) + return [token.value for token in lexer] From a69eea67fbb7a58aad0a3fd513bd2a70a5d581ce Mon Sep 17 00:00:00 2001 From: Miroslav Date: Sun, 12 Aug 2018 16:00:39 +0300 Subject: [PATCH 5/5] Going into mergeable state. Some typing fixes since typing module doesn't fully support recursive types. a few typos were fixed test coverage, new tests for the new code new syntax feature - assignment for 'set' command. --- mitmproxy/command.py | 22 +++-- mitmproxy/language/lexer.py | 4 +- mitmproxy/language/parser.py | 23 ++++-- mitmproxy/tools/console/keymap.py | 2 - test/mitmproxy/language/test_traversal.py | 46 +++++++++++ test/mitmproxy/test_command.py | 80 ++++++++++++++++++- .../tools/console/test_defaultkeys.py | 13 ++- 7 files changed, 158 insertions(+), 32 deletions(-) create mode 100644 test/mitmproxy/language/test_traversal.py diff --git a/mitmproxy/command.py b/mitmproxy/command.py index df32f9fea..5c07f0ad5 100644 --- a/mitmproxy/command.py +++ b/mitmproxy/command.py @@ -38,12 +38,12 @@ RunningCommand = typing.NamedTuple( "RunningCommand", [ ("cmdstr", str), - ("task", asyncio.Task) + ("task", asyncio.Future) ], ) -class AsyncExectuionManager: +class AsyncExecutionManager: def __init__(self) -> None: self.counter: int = 0 self.running_cmds: typing.Dict[int, RunningCommand] = {} @@ -114,7 +114,7 @@ class Command: ret = " -> " + ret return "%s %s%s" % (self.path, params, ret) - def prepare_args(self, args: typing.Sequence[str]) -> typing.List[typing.Any]: + def prepare_args(self, args: typing.Sequence[typing.Any]) -> typing.List[typing.Any]: verify_arg_signature(self.func, list(args), {}) remainder: typing.Sequence[str] = [] @@ -131,7 +131,7 @@ class Command: pargs.append(arg) else: raise exceptions.CommandError( - f"{arg} is unexpected data for {paramtype.display} type" + f"{arg} is unexpected data for {t.display} type" ) else: pargs.append(parsearg(self.manager, arg, paramtype)) @@ -149,9 +149,7 @@ class Command: typ = mitmproxy.types.CommandTypes.get(self.returntype) if not typ.is_valid(self.manager, typ, ret): raise exceptions.CommandError( - "%s returned unexpected data - expected %s" % ( - self.path, typ.display - ) + f"{self.path} returned unexpected data - expected {typ.display}" ) return ret @@ -166,9 +164,7 @@ class Command: typ = mitmproxy.types.CommandTypes.get(self.returntype) if not typ.is_valid(self.manager, typ, ret): raise exceptions.CommandError( - "%s returned unexpected data - expected %s" % ( - self.path, typ.display - ) + f"{self.path} returned unexpected data - expected {typ.display}" ) return ret @@ -186,7 +182,7 @@ ParseResult = typing.NamedTuple( class CommandManager(mitmproxy.types._CommandBase): def __init__(self, master): self.master = master - self.async_manager = AsyncExectuionManager() + self.async_manager = AsyncExecutionManager() self.command_parser = parser.create_parser(self) self.commands: typing.Dict[str, Command] = {} self.oneword_commands: typing.List[str] = [] @@ -288,7 +284,7 @@ class CommandManager(mitmproxy.types._CommandBase): """ return self.get_command_by_path(path).call(args) - def async_execute(self, cmdstr: str) -> asyncio.Task: + def async_execute(self, cmdstr: str) -> asyncio.Future: """ Schedule a command to be executed. May raise CommandError. """ @@ -324,7 +320,7 @@ def parsearg(manager: CommandManager, spec: str, argtype: type) -> typing.Any: """ t = mitmproxy.types.CommandTypes.get(argtype, None) if not t: - raise exceptions.CommandError("Unsupported argument type: %s" % argtype) + raise exceptions.CommandError(f"Unsupported argument type: {argtype}") try: return t.parse(manager, argtype, spec) # type: ignore except exceptions.TypeError as e: diff --git a/mitmproxy/language/lexer.py b/mitmproxy/language/lexer.py index 2e44d7df2..bc76e6f05 100644 --- a/mitmproxy/language/lexer.py +++ b/mitmproxy/language/lexer.py @@ -8,6 +8,7 @@ class CommandLanguageLexer: tokens = ( "WHITESPACE", "PIPE", + "EQUAL_SIGN", "LPAREN", "RPAREN", "LBRACE", "RBRACE", "PLAIN_STR", "QUOTED_STR", @@ -23,12 +24,13 @@ class CommandLanguageLexer: # Main(INITIAL) state t_ignore_WHITESPACE = r"\s+" t_PIPE = r"\|" + t_EQUAL_SIGN = r"\=" t_LPAREN = r"\(" t_RPAREN = r"\)" t_LBRACE = r"\[" t_RBRACE = r"\]" - special_symbols = re.escape("()[]|") + special_symbols = re.escape("()[]|=") plain_str = rf"[^{special_symbols}\s]+" def t_COMMAND(self, t): diff --git a/mitmproxy/language/parser.py b/mitmproxy/language/parser.py index c3923ccfa..cc7824095 100644 --- a/mitmproxy/language/parser.py +++ b/mitmproxy/language/parser.py @@ -1,4 +1,5 @@ import typing +import collections import ply.lex as lex import ply.yacc as yacc @@ -11,12 +12,8 @@ from mitmproxy.language.lexer import CommandLanguageLexer ParsedEntity = typing.Union[str, list, "ParsedCommand"] -ParsedCommand = typing.NamedTuple( - "ParsedCommand", - [ - ("command", "mitmproxy.command.Command"), - ("args", typing.List[ParsedEntity]) - ] +ParsedCommand = collections.namedtuple( + "ParsedCommand", ["command", "args"] ) @@ -41,8 +38,7 @@ class CommandLanguageParser: """starting_expression : PLAIN_STR | quoted_str | array - | command_call_no_parentheses - | command_call_with_parentheses""" + | command_call""" p[0] = p[1] self._parsed_pipe_elem = p[0] @@ -62,6 +58,11 @@ class CommandLanguageParser: p[0] = self._call_command(p[2], new_args) self._parsed_pipe_elem = p[0] + def p_command_call(self, p): + """command_call : command_call_no_parentheses + | command_call_with_parentheses""" + p[0] = p[1] + def p_command_call_no_parentheses(self, p): """command_call_no_parentheses : COMMAND argument_list""" p[0] = self._call_command(p[1], p[2]) @@ -76,11 +77,17 @@ class CommandLanguageParser: | argument_list argument""" p[0] = self._create_list(p) + def p_assignment(self, p): + """assignment : PLAIN_STR EQUAL_SIGN starting_expression + | QUOTED_STR EQUAL_SIGN starting_expression""" + p[0] = f"{p[1]}{p[2]}{p[3]}" + def p_argument(self, p): """argument : PLAIN_STR | quoted_str | array | COMMAND + | assignment | command_call_with_parentheses""" p[0] = p[1] diff --git a/mitmproxy/tools/console/keymap.py b/mitmproxy/tools/console/keymap.py index d85d7d68b..d22420bfd 100644 --- a/mitmproxy/tools/console/keymap.py +++ b/mitmproxy/tools/console/keymap.py @@ -4,7 +4,6 @@ import os import ruamel.yaml from mitmproxy import command -from mitmproxy.language import lexer from mitmproxy.tools.console import commandexecutor from mitmproxy.tools.console import signals from mitmproxy import ctx @@ -56,7 +55,6 @@ class Binding: class Keymap: def __init__(self, master): - self.oneword_commands = master.commands.oneword_commands self.executor = commandexecutor.CommandExecutor(master) self.keys = {} for c in Contexts: diff --git a/test/mitmproxy/language/test_traversal.py b/test/mitmproxy/language/test_traversal.py new file mode 100644 index 000000000..5f9e1c9b3 --- /dev/null +++ b/test/mitmproxy/language/test_traversal.py @@ -0,0 +1,46 @@ +import typing +import asyncio + +from mitmproxy import command +from mitmproxy.test import taddons +from mitmproxy.language import lexer, parser, traversal + +import pytest + + +class TAddon: + @command.command("cmd1") + def cmd1(self, foo: typing.Sequence[str]) -> str: + return " ".join(foo) + + @command.command("cmd2") + def cmd2(self, foo: str) -> str: + return foo + + @command.command("cmd3") + async def cmd3(self, foo: str) -> str: + await asyncio.sleep(0.01) + return foo + + +@pytest.mark.asyncio +async def test_execute_parsed_line(): + test_commands = ["""join.cmd1 [str.cmd2(abc) + str.cmd2(strasync.cmd3("def"))]""", + "[1 2 3]", "str.cmd2 abc | strasync.cmd3()"] + results = ["abc def", ['1', '2', '3'], "abc"] + + with taddons.context() as tctx: + cm = command.CommandManager(tctx.master) + a = TAddon() + cm.add("join.cmd1", a.cmd1) + cm.add("str.cmd2", a.cmd2) + cm.add("strasync.cmd3", a.cmd3) + + command_parser = parser.create_parser(cm) + for cmd, exp_res in zip(test_commands, results): + lxr = lexer.create_lexer(cmd, cm.oneword_commands) + parsed = command_parser.parse(lxr, async_exec=True) + + result = await traversal.execute_parsed_line(parsed) + assert result == exp_res diff --git a/test/mitmproxy/test_command.py b/test/mitmproxy/test_command.py index b26ddac96..44fb1a06e 100644 --- a/test/mitmproxy/test_command.py +++ b/test/mitmproxy/test_command.py @@ -1,11 +1,15 @@ import typing import inspect +import asyncio +from unittest import mock + +import mitmproxy.types from mitmproxy import command from mitmproxy import flow from mitmproxy import exceptions from mitmproxy.test import tflow from mitmproxy.test import taddons -import mitmproxy.types + import io import pytest @@ -36,6 +40,15 @@ class TAddon: def cmd6(self, pipe_value: str) -> str: return pipe_value + @command.command("cmd7") + async def cmd7(self, foo: str) -> str: + await asyncio.sleep(0.01) + return foo + + @command.command("cmd8") + async def cmd8(self, foo: str) -> str: + return 99 + @command.command("subcommand") def subcommand(self, cmd: mitmproxy.types.Cmd, *args: mitmproxy.types.Arg) -> str: return "ok" @@ -44,6 +57,10 @@ class TAddon: def empty(self) -> None: pass + @command.command("empty") + async def asyncempty(self) -> None: + pass + @command.command("varargs") def varargs(self, one: str, *var: str) -> typing.Sequence[str]: return list(var) @@ -82,6 +99,35 @@ class TypeErrAddon: pass +class TestAsyncExecutionManager: + def test_add_command(self): + aem = command.AsyncExecutionManager() + dummy_command = command.RunningCommand("addon.command", mock.Mock()) + aem.add_command(dummy_command) + assert aem.running_cmds == {1: dummy_command} + + def test_stop_command(self): + aem = command.AsyncExecutionManager() + dummy_command = command.RunningCommand("addon.command", mock.Mock()) + aem.add_command(dummy_command) + with pytest.raises(ValueError, match="There is not the command"): + aem.stop_command(100) + + aem.stop_command(1) + assert aem.running_cmds == {} + + def test_get_runnings(self): + aem = command.AsyncExecutionManager() + expected_res = [] + for i in range(3): + cmd = f"addon.command{i}" + dummy = command.RunningCommand(cmd, mock.Mock()) + aem.add_command(dummy) + expected_res.append((i + 1, cmd)) + + assert aem.get_running() == expected_res + + class TestCommand: def test_typecheck(self): with taddons.context(loadcore=False) as tctx: @@ -115,9 +161,25 @@ class TestCommand: with pytest.raises(exceptions.CommandError): c.call(["foo"]) + with pytest.raises(exceptions.CommandError, match="unexpected data"): + c.call([123]) + c = command.Command(cm, "cmd.three", a.cmd3) assert c.call(["1"]) == 1 + @pytest.mark.asyncio + async def test_async_call(self): + with taddons.context() as tctx: + cm = command.CommandManager(tctx.master) + a = TAddon() + + c = command.Command(cm, "async.empty", a.asyncempty) + await c.async_call([]) + + c = command.Command(cm, "asynccmd.two", a.cmd8) + with pytest.raises(exceptions.CommandError, match="unexpected data"): + await c.async_call(["foo"]) + def test_parse_partial(self): tests = [ [ @@ -301,6 +363,7 @@ def test_simple(): c.add("one.two", a.cmd1) c.add("array.command", a.cmd5) c.add("pipe.command", a.cmd6) + c.add("strasync.command", a.cmd7) assert c.commands["one.two"].help == "cmd1 help" assert(c.execute("one.two foo") == "ret foo") @@ -320,6 +383,8 @@ def test_simple(): c.execute("") with pytest.raises(exceptions.CommandError, match="argument mismatch"): c.execute("one.two too many args") + with pytest.raises(exceptions.ExecutionError, match="sync executor"): + c.execute("strasync.command abc") with pytest.raises(exceptions.CommandError, match="Unknown"): c.call("nonexistent") @@ -331,6 +396,19 @@ def test_simple(): assert fp.getvalue() +@pytest.mark.asyncio +async def test_async_execute(): + with taddons.context() as tctx: + c = command.CommandManager(tctx.master) + a = TAddon() + c.add("strasync.command", a.cmd7) + + c.async_execute("strasync.command abc") + assert c.async_manager.get_running() == [(1, "strasync.command abc")] + assert "abc" == await c.async_manager.running_cmds[1].task + assert c.async_manager.get_running() == [] + + def test_typename(): assert command.typename(str) == "str" assert command.typename(typing.Sequence[flow.Flow]) == "[flow]" diff --git a/test/mitmproxy/tools/console/test_defaultkeys.py b/test/mitmproxy/tools/console/test_defaultkeys.py index f60174aa7..fe32a1646 100644 --- a/test/mitmproxy/tools/console/test_defaultkeys.py +++ b/test/mitmproxy/tools/console/test_defaultkeys.py @@ -2,7 +2,7 @@ from mitmproxy.test.tflow import tflow from mitmproxy.tools.console import defaultkeys from mitmproxy.tools.console import keymap from mitmproxy.tools.console import master -from mitmproxy.language import lexer +from mitmproxy.language import lexer, parser import pytest @@ -15,12 +15,11 @@ async def test_commands_exist(): m = master.ConsoleMaster(None) await m.load_flow(tflow()) - for binding in km.bindings: - cmd, *args = lexer.get_tokens(binding.command, state="INITIAL") - assert cmd in m.commands.commands + command_parser = parser.create_parser(m.commands) - cmd_obj = m.commands.commands[cmd] + for binding in km.bindings: + lxr = lexer.create_lexer(binding.command, m.commands.oneword_commands) try: - cmd_obj.prepare_args(args) + command_parser.parse(lxr, async_exec=True) except Exception as e: - raise ValueError("Invalid command: {}".format(binding.command)) from e + raise ValueError(f"Invalid command: '{binding.command}'") from e