Replace iter_track

This commit is contained in:
Will McGugan 2020-07-30 18:31:45 +01:00
parent e338ab1457
commit e7ff0dac77
4 changed files with 51 additions and 111 deletions

View File

@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [4.2.2] - Unreleased
### Changed
- Added thread to automatically call updated in progress.track(). Replacing previous adaptive algorithm.
## [4.2.1] - 2020-07-29 ## [4.2.1] - 2020-07-29
### Added ### Added

View File

@ -2,7 +2,7 @@
name = "rich" name = "rich"
homepage = "https://github.com/willmcgugan/rich" homepage = "https://github.com/willmcgugan/rich"
documentation = "https://rich.readthedocs.io/en/latest/" documentation = "https://rich.readthedocs.io/en/latest/"
version = "4.2.1" version = "4.2.2"
description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" description = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal"
authors = ["Will McGugan <willmcgugan@gmail.com>"] authors = ["Will McGugan <willmcgugan@gmail.com>"]
license = "MIT" license = "MIT"

View File

@ -53,50 +53,37 @@ ProgressType = TypeVar("ProgressType")
GetTimeCallable = Callable[[], float] GetTimeCallable = Callable[[], float]
def iter_track( class _TrackThread(Thread):
values: Iterable[ProgressType], """A thread to periodically update progress."""
total: int,
update_period: float = 0.05,
get_time: Callable[[], float] = None,
) -> Iterable[Iterable[ProgressType]]:
"""Break a sequence in to chunks based on time.
Args: def __init__(self, progress: "Progress", task_id: "TaskID", update_period: float):
values (ProgressType): An iterable of values. self.progress = progress
self.task_id = task_id
self.update_period = update_period
self.done = Event()
Returns: self.completed = 0
Iterable[List[ProgressType]]: An iterable containing lists of values from sequence. super().__init__()
""" def run(self) -> None:
task_id = self.task_id
progress = self.progress
update_period = self.update_period
last_completed = 0
while not self.done.wait(update_period):
completed = self.completed
if last_completed != completed:
progress.advance(task_id, completed - last_completed)
last_completed = completed
progress.update(self.task_id, completed=self.completed, refresh=True)
if update_period == 0: def __enter__(self) -> "_TrackThread":
for value in values: self.start()
yield [value] return self
return
get_time = get_time or perf_counter def __exit__(self, exc_type, exc_val, exc_tb) -> None:
period_size = 1.0 self.done.set()
self.join()
def gen_values(
iter_values: Iterator[ProgressType], size: int
) -> Iterable[ProgressType]:
try:
for _ in range(size):
yield next(iter_values)
except StopIteration:
pass
iter_values = iter(values)
value_count = 0
while value_count < total:
_count = max(1, min(int(period_size), total - value_count))
start_time = get_time()
yield gen_values(iter_values, _count)
time_taken = get_time() - start_time
value_count += _count
if abs(time_taken - update_period) > 0.2 * update_period:
period_size = period_size * (1.5 if time_taken < update_period else 0.8)
def track( def track(
@ -112,7 +99,7 @@ def track(
complete_style: StyleType = "bar.complete", complete_style: StyleType = "bar.complete",
finished_style: StyleType = "bar.finished", finished_style: StyleType = "bar.finished",
pulse_style: StyleType = "bar.pulse", pulse_style: StyleType = "bar.pulse",
update_period: float = 0.025, update_period: float = 0.1,
) -> Iterable[ProgressType]: ) -> Iterable[ProgressType]:
"""Track progress by iterating over a sequence. """Track progress by iterating over a sequence.
@ -128,7 +115,7 @@ def track(
complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete". complete_style (StyleType, optional): Style for the completed bar. Defaults to "bar.complete".
finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.done". finished_style (StyleType, optional): Style for a finished bar. Defaults to "bar.done".
pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse". pulse_style (StyleType, optional): Style for pulsing bars. Defaults to "bar.pulse".
update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.05. update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1.
Returns: Returns:
Iterable[ProgressType]: An iterable of the values in the sequence. Iterable[ProgressType]: An iterable of the values in the sequence.
@ -158,30 +145,9 @@ def track(
refresh_per_second=refresh_per_second, refresh_per_second=refresh_per_second,
) )
task_total = total yield from progress.track(
if task_total is None: sequence, total=total, description=description, update_period=update_period
if isinstance(sequence, Sized): )
task_total = len(sequence)
else:
raise ValueError(
f"unable to get size of {sequence!r}, please specify 'total'"
)
task_id = progress.add_task(description, total=task_total)
advance = progress.advance
with progress:
for values in iter_track(
sequence, task_total, update_period=update_period if auto_refresh else 0
):
advance_total = 0
for value in values:
yield value
advance_total += 1
if advance_total == 0:
break
advance(task_id, advance_total)
if not progress.auto_refresh:
progress.refresh()
class ProgressColumn(ABC): class ProgressColumn(ABC):
@ -669,7 +635,7 @@ class Progress(JupyterMixin, RenderHook):
total: int = None, total: int = None,
task_id: Optional[TaskID] = None, task_id: Optional[TaskID] = None,
description="Working...", description="Working...",
update_period: float = 0.025, update_period: float = 0.1,
) -> Iterable[ProgressType]: ) -> Iterable[ProgressType]:
"""Track progress by iterating over a sequence. """Track progress by iterating over a sequence.
@ -678,7 +644,7 @@ class Progress(JupyterMixin, RenderHook):
total: (int, optional): Total number of steps. Default is len(sequence). total: (int, optional): Total number of steps. Default is len(sequence).
task_id: (TaskID): Task to track. Default is new task. task_id: (TaskID): Task to track. Default is new task.
description: (str, optional): Description of task, if new task is created. description: (str, optional): Description of task, if new task is created.
update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.05. update_period (float, optional): Minimum time (in seconds) between calls to update(). Defaults to 0.1.
Returns: Returns:
Iterable[ProgressType]: An iterable of values taken from the provided sequence. Iterable[ProgressType]: An iterable of values taken from the provided sequence.
@ -698,19 +664,18 @@ class Progress(JupyterMixin, RenderHook):
else: else:
self.update(task_id, total=task_total) self.update(task_id, total=task_total)
with self: with self:
advance = self.advance if self.auto_refresh:
for values in iter_track( with _TrackThread(self, task_id, update_period) as track_thread:
sequence, for value in sequence:
task_total, yield value
update_period=update_period if self.auto_refresh else 0, track_thread.completed += 1
): else:
advance_total = 0 advance = self.advance
for value in values: refresh = self.refresh
for value in sequence:
yield value yield value
advance_total += 1 advance(task_id, 1)
advance(task_id, advance_total) refresh()
if not self.auto_refresh:
self.refresh()
def start_task(self, task_id: TaskID) -> None: def start_task(self, task_id: TaskID) -> None:
"""Start a task. """Start a task.

View File

@ -11,7 +11,6 @@ from rich.highlighter import NullHighlighter
from rich.progress import ( from rich.progress import (
BarColumn, BarColumn,
FileSizeColumn, FileSizeColumn,
iter_track,
TotalFileSizeColumn, TotalFileSizeColumn,
DownloadColumn, DownloadColumn,
TransferSpeedColumn, TransferSpeedColumn,
@ -44,36 +43,6 @@ class MockClock:
self.time += advance self.time += advance
def test_iter_track():
mock_clock = MockClock(auto=False)
result = []
for values in iter_track(
range(1000), update_period=0.1, total=100, get_time=mock_clock
):
chunk = []
for n in values:
mock_clock.tick(0.01)
chunk.append(n)
result.append(chunk)
expected = [
[0],
[1],
[2, 3],
[4, 5, 6],
[7, 8, 9, 10, 11],
[12, 13, 14, 15, 16, 17, 18],
[19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40],
[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51],
[52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62],
[63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73],
[74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84],
[85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95],
[96, 97, 98, 99],
]
assert result == expected
def test_bar_columns(): def test_bar_columns():
bar_column = BarColumn(100) bar_column = BarColumn(100)
assert bar_column.bar_width == 100 assert bar_column.bar_width == 100