diff --git a/Lib/contextlib.py b/Lib/contextlib.py index 33d83a6e14c..33c302dd96f 100644 --- a/Lib/contextlib.py +++ b/Lib/contextlib.py @@ -91,7 +91,6 @@ def nested(*contexts): """ exits = [] vars = [] - exc = (None, None, None) try: try: for context in contexts: @@ -103,6 +102,8 @@ def nested(*contexts): yield vars except: exc = sys.exc_info() + else: + exc = (None, None, None) finally: while exits: exit = exits.pop() @@ -110,6 +111,8 @@ def nested(*contexts): exit(*exc) except: exc = sys.exc_info() + else: + exc = (None, None, None) if exc != (None, None, None): raise diff --git a/Lib/test/test_contextlib.py b/Lib/test/test_contextlib.py index 8c8d887a553..f8db88cc58a 100644 --- a/Lib/test/test_contextlib.py +++ b/Lib/test/test_contextlib.py @@ -107,6 +107,60 @@ def b(): else: self.fail("Didn't raise ZeroDivisionError") + def test_nested_b_swallows(self): + @contextmanager + def a(): + yield + @contextmanager + def b(): + try: + yield + except: + # Swallow the exception + pass + try: + with nested(a(), b()): + 1/0 + except ZeroDivisionError: + self.fail("Didn't swallow ZeroDivisionError") + + def test_nested_break(self): + @contextmanager + def a(): + yield + state = 0 + while True: + state += 1 + with nested(a(), a()): + break + state += 10 + self.assertEqual(state, 1) + + def test_nested_continue(self): + @contextmanager + def a(): + yield + state = 0 + while state < 3: + state += 1 + with nested(a(), a()): + continue + state += 10 + self.assertEqual(state, 3) + + def test_nested_return(self): + @contextmanager + def a(): + try: + yield + except: + pass + def foo(): + with nested(a(), a()): + return 1 + return 10 + self.assertEqual(foo(), 1) + class ClosingTestCase(unittest.TestCase): # XXX This needs more work