From ed59a76f64e052c2003c13086951d496cf1b8ee3 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Sun, 24 May 2020 12:37:40 +0100 Subject: [PATCH 1/2] deadlock fix --- rich/progress.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/rich/progress.py b/rich/progress.py index cc822d2a..448d5a32 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -352,7 +352,6 @@ class _RefreshThread(Thread): def stop(self) -> None: self.done.set() - self.join() def run(self) -> None: while not self.done.wait(1.0 / self.refresh_per_second): @@ -432,19 +431,20 @@ class Progress: def stop(self) -> None: """Stop the progress display.""" + assert self._refresh_thread is not None with self._lock: if not self._started: return self._started = False try: - if self.auto_refresh and self._refresh_thread is not None: + if self.auto_refresh: self._refresh_thread.stop() - self._refresh_thread = None self.refresh() if self.console.is_terminal: self.console.line() finally: self.console.show_cursor(True) + self._refresh_thread.join() def __enter__(self) -> "Progress": with self._lock: @@ -456,10 +456,13 @@ class Progress: return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: + stopping = False with self._lock: self._enter_count -= 1 if not self._enter_count: - self.stop() + stopping = True + if stopping: + self.stop() def track( self, From 033aaa1c6c7d33905cfa23b74f1082f1c244f717 Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Sun, 24 May 2020 12:47:29 +0100 Subject: [PATCH 2/2] simplify context manager --- rich/progress.py | 23 ++++++----------------- 1 file changed, 6 insertions(+), 17 deletions(-) diff --git a/rich/progress.py b/rich/progress.py index 448d5a32..c7437791 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -378,6 +378,7 @@ class Progress: get_time: GetTimeCallable = monotonic, ) -> None: assert refresh_per_second > 0, "refresh_per_second must be > 0" + self._lock = RLock() self.columns = columns or ( TextColumn("[progress.description]{task.description}"), BarColumn(), @@ -392,16 +393,15 @@ class Progress: self._tasks: Dict[TaskID, Task] = {} self._live_render = LiveRender(self.get_renderable()) self._task_index: TaskID = TaskID(0) - self._lock = RLock() self._refresh_thread: Optional[_RefreshThread] = None self._refresh_count = 0 - self._enter_count = 0 self._started = False @property def tasks(self) -> List[Task]: """Get a list of Task instances.""" - return list(self._tasks.values()) + with self._lock: + return list(self._tasks.values()) @property def task_ids(self) -> List[TaskID]: @@ -447,22 +447,11 @@ class Progress: self._refresh_thread.join() def __enter__(self) -> "Progress": - with self._lock: - if self._enter_count: - self._enter_count += 1 - return self - self.start() - self._enter_count += 1 - return self + self.start() + return self def __exit__(self, exc_type, exc_val, exc_tb) -> None: - stopping = False - with self._lock: - self._enter_count -= 1 - if not self._enter_count: - stopping = True - if stopping: - self.stop() + self.stop() def track( self,