11from abc import ABC , abstractmethod
22from enum import auto
3- from typing import TYPE_CHECKING , Dict , NamedTuple , Optional
3+ from typing import Dict , NamedTuple , Optional
44
5- import click
5+ from rich . progress import BarColumn , Progress , SpinnerColumn , TaskProgressColumn , TextColumn , TimeElapsedColumn
66
77from cycode .cli .utils .enum_utils import AutoCountEnum
88from cycode .logger import get_logger
99
10- if TYPE_CHECKING :
11- from click ._termui_impl import ProgressBar
12- from click .termui import V as ProgressBarValue
13-
1410# use LOGGING_LEVEL=DEBUG env var to see debug logs of this module
1511logger = get_logger ('Progress Bar' , control_level_in_runtime = False )
1612
@@ -32,6 +28,14 @@ class ProgressBarSectionInfo(NamedTuple):
3228
3329
3430_PROGRESS_BAR_LENGTH = 100
31+ _PROGRESS_BAR_COLUMNS = (
32+ SpinnerColumn (),
33+ TextColumn ('[progress.description]{task.description}' ),
34+ TextColumn ('{task.fields[right_side_label]}' ),
35+ BarColumn (bar_width = None ),
36+ TaskProgressColumn (),
37+ TimeElapsedColumn (),
38+ )
3539
3640ProgressBarSections = Dict [ProgressBarSection , ProgressBarSectionInfo ]
3741
@@ -91,12 +95,6 @@ class BaseProgressBar(ABC):
9195 def __init__ (self , * args , ** kwargs ) -> None :
9296 pass
9397
94- @abstractmethod
95- def __enter__ (self ) -> 'BaseProgressBar' : ...
96-
97- @abstractmethod
98- def __exit__ (self , * args , ** kwargs ) -> None : ...
99-
10098 @abstractmethod
10199 def start (self ) -> None : ...
102100
@@ -110,19 +108,13 @@ def set_section_length(self, section: 'ProgressBarSection', length: int = 0) ->
110108 def update (self , section : 'ProgressBarSection' ) -> None : ...
111109
112110 @abstractmethod
113- def update_label (self , label : Optional [str ] = None ) -> None : ...
111+ def update_right_side_label (self , label : Optional [str ] = None ) -> None : ...
114112
115113
116114class DummyProgressBar (BaseProgressBar ):
117115 def __init__ (self , * args , ** kwargs ) -> None :
118116 super ().__init__ (* args , ** kwargs )
119117
120- def __enter__ (self ) -> 'DummyProgressBar' :
121- return self
122-
123- def __exit__ (self , * args , ** kwargs ) -> None :
124- pass
125-
126118 def start (self ) -> None :
127119 pass
128120
@@ -135,46 +127,49 @@ def set_section_length(self, section: 'ProgressBarSection', length: int = 0) ->
135127 def update (self , section : 'ProgressBarSection' ) -> None :
136128 pass
137129
138- def update_label (self , label : Optional [str ] = None ) -> None :
130+ def update_right_side_label (self , label : Optional [str ] = None ) -> None :
139131 pass
140132
141133
142134class CompositeProgressBar (BaseProgressBar ):
143135 def __init__ (self , progress_bar_sections : ProgressBarSections ) -> None :
144136 super ().__init__ ()
145137
146- self ._progress_bar_sections = progress_bar_sections
147-
148- self ._progress_bar_context_manager = click .progressbar (
149- length = _PROGRESS_BAR_LENGTH ,
150- item_show_func = self ._progress_bar_item_show_func ,
151- update_min_steps = 0 ,
152- )
153- self ._progress_bar : Optional ['ProgressBar' ] = None
154138 self ._run = False
139+ self ._progress_bar_sections = progress_bar_sections
155140
156141 self ._section_lengths : Dict [ProgressBarSection , int ] = {}
157142 self ._section_values : Dict [ProgressBarSection , int ] = {}
158143
159144 self ._current_section_value = 0
160145 self ._current_section : ProgressBarSectionInfo = _get_initial_section (self ._progress_bar_sections )
146+ self ._current_right_side_label = ''
161147
162- def __enter__ (self ) -> 'CompositeProgressBar' :
163- self ._progress_bar = self ._progress_bar_context_manager .__enter__ ()
164- self ._run = True
165- return self
148+ self ._progress_bar = Progress (
149+ * _PROGRESS_BAR_COLUMNS ,
150+ transient = True ,
151+ )
152+ self ._progress_bar_task_id = self ._progress_bar .add_task (
153+ description = self ._current_section .label ,
154+ total = _PROGRESS_BAR_LENGTH ,
155+ right_side_label = self ._current_right_side_label ,
156+ )
166157
167- def __exit__ (self , * args , ** kwargs ) -> None :
168- self ._progress_bar_context_manager .__exit__ (* args , ** kwargs )
169- self ._run = False
158+ def _progress_bar_update (self , advance : int = 0 ) -> None :
159+ self ._progress_bar .update (
160+ self ._progress_bar_task_id ,
161+ advance = advance ,
162+ description = self ._current_section .label ,
163+ right_side_label = self ._current_right_side_label ,
164+ )
170165
171166 def start (self ) -> None :
172167 if not self ._run :
173- self .__enter__ ()
168+ self ._progress_bar . start ()
174169
175170 def stop (self ) -> None :
176171 if self ._run :
177- self .__exit__ ( None , None , None )
172+ self ._progress_bar . stop ( )
178173
179174 def set_section_length (self , section : 'ProgressBarSection' , length : int = 0 ) -> None :
180175 logger .debug ('Calling set_section_length, %s' , {'section' : str (section ), 'length' : length })
@@ -190,7 +185,7 @@ def _get_section_length(self, section: 'ProgressBarSection') -> int:
190185 return section_info .stop_percent - section_info .start_percent
191186
192187 def _skip_section (self , section : 'ProgressBarSection' ) -> None :
193- self ._progress_bar . update (self ._get_section_length (section ))
188+ self ._progress_bar_update (self ._get_section_length (section ))
194189 self ._maybe_update_current_section ()
195190
196191 def _increment_section_value (self , section : 'ProgressBarSection' , value : int ) -> None :
@@ -205,13 +200,13 @@ def _increment_section_value(self, section: 'ProgressBarSection', value: int) ->
205200
206201 def _rerender_progress_bar (self ) -> None :
207202 """Used to update label right after changing the progress bar section."""
208- self ._progress_bar . update ( 0 )
203+ self ._progress_bar_update ( )
209204
210205 def _increment_progress (self , section : 'ProgressBarSection' ) -> None :
211206 increment_value = self ._get_increment_progress_value (section )
212207
213208 self ._current_section_value += increment_value
214- self ._progress_bar . update (increment_value )
209+ self ._progress_bar_update (increment_value )
215210
216211 def _maybe_update_current_section (self ) -> None :
217212 if not self ._current_section .section .has_next ():
@@ -237,13 +232,7 @@ def _get_increment_progress_value(self, section: 'ProgressBarSection') -> int:
237232
238233 return expected_value - self ._current_section_value
239234
240- def _progress_bar_item_show_func (self , _ : Optional ['ProgressBarValue' ] = None ) -> str :
241- return self ._current_section .label
242-
243235 def update (self , section : 'ProgressBarSection' , value : int = 1 ) -> None :
244- if not self ._progress_bar :
245- raise ValueError ('Progress bar is not initialized. Call start() first or use "with" statement.' )
246-
247236 if section not in self ._section_lengths :
248237 raise ValueError (f'{ section } section is not initialized. Call set_section_length() first.' )
249238 if section is not self ._current_section .section :
@@ -255,12 +244,9 @@ def update(self, section: 'ProgressBarSection', value: int = 1) -> None:
255244 self ._increment_progress (section )
256245 self ._maybe_update_current_section ()
257246
258- def update_label (self , label : Optional [str ] = None ) -> None :
259- if not self ._progress_bar :
260- raise ValueError ('Progress bar is not initialized. Call start() first or use "with" statement.' )
261-
262- self ._progress_bar .label = label or ''
263- self ._progress_bar .render_progress ()
247+ def update_right_side_label (self , label : Optional [str ] = None ) -> None :
248+ self ._current_right_side_label = f'({ label } )' or ''
249+ self ._progress_bar_update ()
264250
265251
266252def get_progress_bar (* , hidden : bool , sections : ProgressBarSections ) -> BaseProgressBar :
@@ -284,9 +270,9 @@ def get_progress_bar(*, hidden: bool, sections: ProgressBarSections) -> BaseProg
284270
285271 for _i in range (section_capacity ):
286272 time .sleep (0.01 )
287- bar .update_label (f'{ bar_section } { _i } /{ section_capacity } ' )
273+ bar .update_right_side_label (f'{ bar_section } { _i } /{ section_capacity } ' )
288274 bar .update (bar_section )
289275
290- bar .update_label ()
276+ bar .update_right_side_label ()
291277
292278 bar .stop ()
0 commit comments