From de781fdc26c4ca8f652775e7751515c0ea2a68d1 Mon Sep 17 00:00:00 2001 From: Hood Chatham Date: Tue, 7 Sep 2021 09:42:51 -0700 Subject: [PATCH] Handle syntax error in find_imports (#1819) --- docs/project/changelog.md | 4 ++++ src/py/_pyodide/_base.py | 8 ++++++-- src/tests/test_pyodide.py | 24 +++++++++++++++++------- 3 files changed, 27 insertions(+), 9 deletions(-) diff --git a/docs/project/changelog.md b/docs/project/changelog.md index 7215eedad..ee00ad45c 100644 --- a/docs/project/changelog.md +++ b/docs/project/changelog.md @@ -57,6 +57,10 @@ substitutions: - {{Fix}} Fixed a use after free bug in the error handling code. {pr}`1816` +- {{Enhancement}} If `find_imports` is used on code that contains a syntax + error, it will return an empty list instead of raising a `SyntaxError`. + {pr}`1819` + ## Version 0.18.0 _August 3rd, 2021_ diff --git a/src/py/_pyodide/_base.py b/src/py/_pyodide/_base.py index 7fb361a7f..b857d5c9c 100644 --- a/src/py/_pyodide/_base.py +++ b/src/py/_pyodide/_base.py @@ -515,7 +515,8 @@ def find_imports(source: str) -> List[str]: Returns ------- ``List[str]`` - A list of module names that are imported in ``source``. + A list of module names that are imported in ``source``. If ``source`` is not + syntactically correct Python code (after dedenting), returns an empty list. Examples -------- @@ -527,7 +528,10 @@ def find_imports(source: str) -> List[str]: # handle mis-indented input from multi-line strings source = dedent(source) - mod = ast.parse(source) + try: + mod = ast.parse(source) + except SyntaxError: + return [] imports = set() for node in ast.walk(mod): if isinstance(node, ast.Import): diff --git a/src/tests/test_pyodide.py b/src/tests/test_pyodide.py index ca1af4c2b..db8098cd2 100644 --- a/src/tests/test_pyodide.py +++ b/src/tests/test_pyodide.py @@ -11,16 +11,26 @@ from pyodide import find_imports, eval_code, CodeRunner, should_quiet # noqa: E def test_find_imports(): res = find_imports( - dedent( - """ - import numpy as np - from scipy import sparse - import matplotlib.pyplot as plt - """ - ) + """ + import numpy as np + from scipy import sparse + import matplotlib.pyplot as plt + """ ) assert set(res) == {"numpy", "scipy", "matplotlib"} + # If there is a syntax error in the code, find_imports should return empty + # list. + res = find_imports( + """ + import numpy as np + from scipy import sparse + import matplotlib.pyplot as plt + for x in [1,2,3] + """ + ) + assert res == [] + def test_code_runner(): assert should_quiet("1+1;")