From e7ff0dac77d404907eed7c2a3c772739a52b52ae Mon Sep 17 00:00:00 2001 From: Will McGugan Date: Thu, 30 Jul 2020 18:31:45 +0100 Subject: [PATCH] Replace iter_track --- CHANGELOG.md | 6 ++ pyproject.toml | 2 +- rich/progress.py | 123 +++++++++++++++-------------------------- tests/test_progress.py | 31 ----------- 4 files changed, 51 insertions(+), 111 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f59f176d..136edef9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [4.2.2] - Unreleased + +### Changed + +- Added thread to automatically call updated in progress.track(). Replacing previous adaptive algorithm. + ## [4.2.1] - 2020-07-29 ### Added diff --git a/pyproject.toml b/pyproject.toml index 92a297ab..14498ef6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ name = "rich" homepage = "https://github.com/willmcgugan/rich" documentation = "https://rich.readthedocs.io/en/latest/" -version = "4.2.1" +version = "4.2.2" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" authors = ["Will McGugan "] license = "MIT" diff --git a/rich/progress.py b/rich/progress.py index 38b40d6e..51747ad2 100644 --- a/rich/progress.py +++ b/rich/progress.py @@ -53,50 +53,37 @@ ProgressType = TypeVar("ProgressType") GetTimeCallable = Callable[[], float] -def iter_track( - values: Iterable[ProgressType], - total: int, - update_period: float = 0.05, - get_time: Callable[[], float] = None, -) -> Iterable[Iterable[ProgressType]]: - """Break a sequence in to chunks based on time. +class _TrackThread(Thread): + """A thread to periodically update progress.""" - Args: - values (ProgressType): An iterable of values. + def __init__(self, progress: "Progress", task_id: "TaskID", update_period: float): + self.progress = progress + self.task_id = task_id + self.update_period = update_period + self.done = Event() - Returns: - Iterable[List[ProgressType]]: An iterable containing lists of values from sequence. + self.completed = 0 + super().__init__() - """ + def run(self) -> None: + task_id = self.task_id + progress = self.progress + update_period = self.update_period + last_completed = 0 + while not self.done.wait(update_period): + completed = self.completed + if last_completed != completed: + progress.advance(task_id, completed - last_completed) + last_completed = completed + progress.update(self.task_id, completed=self.completed, refresh=True) - if update_period == 0: - for value in values: - yield [value] - return + def __enter__(self) -> "_TrackThread": + self.start() + return self - get_time = get_time or perf_counter - period_size = 1.0 - - def gen_values( - iter_values: Iterator[ProgressType], size: int - ) -> Iterable[ProgressType]: - try: - for _ in range(size): - yield next(iter_values) - except StopIteration: - pass - - iter_values = iter(values) - value_count = 0 - - while value_count < total: - _count = max(1, min(int(period_size), total - value_count)) - start_time = get_time() - yield gen_values(iter_values, _count) - time_taken = get_time() - start_time - value_count += _count - if abs(time_taken - update_period) > 0.2 * update_period: - period_size = period_size * (1.5 if time_taken < update_period else 0.8) + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.done.set() + self.join() def track( @@ -112,7 +99,7 @@ def track( complete_style: StyleType = "bar.complete", finished_style: StyleType = "bar.finished", pulse_style: StyleType = "bar.pulse", - update_period: float = 0.025, + update_period: float = 0.1, ) -> Iterable[ProgressType]: """Track progress by iterating over a sequence. @@ -128,7 +115,7 @@ def track( complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete". finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.done". pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse". - update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.05. + update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1. Returns: Iterable[ProgressType]: An iterable of the values in the sequence. @@ -158,30 +145,9 @@ def track( refresh_per_second=refresh_per_second, ) - task_total = total - if task_total is None: - if isinstance(sequence, Sized): - task_total = len(sequence) - else: - raise ValueError( - f"unable to get size of {sequence!r}, please specify 'total'" - ) - - task_id = progress.add_task(description, total=task_total) - advance = progress.advance - with progress: - for values in iter_track( - sequence, task_total, update_period=update_period if auto_refresh else 0 - ): - advance_total = 0 - for value in values: - yield value - advance_total += 1 - if advance_total == 0: - break - advance(task_id, advance_total) - if not progress.auto_refresh: - progress.refresh() + yield from progress.track( + sequence, total=total, description=description, update_period=update_period + ) class ProgressColumn(ABC): @@ -669,7 +635,7 @@ class Progress(JupyterMixin, RenderHook): total: int = None, task_id: Optional[TaskID] = None, description="Working...", - update_period: float = 0.025, + update_period: float = 0.1, ) -> Iterable[ProgressType]: """Track progress by iterating over a sequence. @@ -678,7 +644,7 @@ class Progress(JupyterMixin, RenderHook): total: (int, optional): Total number of steps. Default is len(sequence). task_id: (TaskID): Task to track. Default is new task. description: (str, optional): Description of task, if new task is created. - update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.05. + update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1. Returns: Iterable[ProgressType]: An iterable of values taken from the provided sequence. @@ -698,19 +664,18 @@ class Progress(JupyterMixin, RenderHook): else: self.update(task_id, total=task_total) with self: - advance = self.advance - for values in iter_track( - sequence, - task_total, - update_period=update_period if self.auto_refresh else 0, - ): - advance_total = 0 - for value in values: + if self.auto_refresh: + with _TrackThread(self, task_id, update_period) as track_thread: + for value in sequence: + yield value + track_thread.completed += 1 + else: + advance = self.advance + refresh = self.refresh + for value in sequence: yield value - advance_total += 1 - advance(task_id, advance_total) - if not self.auto_refresh: - self.refresh() + advance(task_id, 1) + refresh() def start_task(self, task_id: TaskID) -> None: """Start a task. diff --git a/tests/test_progress.py b/tests/test_progress.py index 962dcf7f..c7bed14e 100644 --- a/tests/test_progress.py +++ b/tests/test_progress.py @@ -11,7 +11,6 @@ from rich.highlighter import NullHighlighter from rich.progress import ( BarColumn, FileSizeColumn, - iter_track, TotalFileSizeColumn, DownloadColumn, TransferSpeedColumn, @@ -44,36 +43,6 @@ class MockClock: self.time += advance -def test_iter_track(): - mock_clock = MockClock(auto=False) - result = [] - for values in iter_track( - range(1000), update_period=0.1, total=100, get_time=mock_clock - ): - chunk = [] - for n in values: - mock_clock.tick(0.01) - chunk.append(n) - result.append(chunk) - expected = [ - [0], - [1], - [2, 3], - [4, 5, 6], - [7, 8, 9, 10, 11], - [12, 13, 14, 15, 16, 17, 18], - [19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40], - [41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51], - [52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62], - [63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73], - [74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84], - [85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], - [96, 97, 98, 99], - ] - assert result == expected - - def test_bar_columns(): bar_column = BarColumn(100) assert bar_column.bar_width == 100