@@ -106,11 +106,15 @@ impl PyDataFrame {
106106 }
107107
108108 fn _repr_html_ ( & self , py : Python ) -> PyDataFusionResult < String > {
109- let ( batch , mut has_more) =
110- wait_for_future ( py, get_first_record_batch ( self . df . as_ref ( ) . clone ( ) ) ) ?;
111- let Some ( batch ) = batch else {
109+ let ( batches , mut has_more) =
110+ wait_for_future ( py, get_first_few_record_batches ( self . df . as_ref ( ) . clone ( ) ) ) ?;
111+ let Some ( batches ) = batches else {
112112 return Ok ( "No data to display" . to_string ( ) ) ;
113113 } ;
114+ if batches. is_empty ( ) {
115+ // This should not be reached, but do it for safety since we index into the vector below
116+ return Ok ( "No data to display" . to_string ( ) ) ;
117+ }
114118
115119 let table_uuid = uuid:: Uuid :: new_v4 ( ) . to_string ( ) ;
116120
@@ -142,12 +146,11 @@ impl PyDataFrame {
142146 }
143147 </style>
144148
145-
146149 <div style=\" width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\" >
147150 <table style=\" border-collapse: collapse; min-width: 100%\" >
148151 <thead>\n " . to_string ( ) ;
149152
150- let schema = batch . schema ( ) ;
153+ let schema = batches [ 0 ] . schema ( ) ;
151154
152155 let mut header = Vec :: new ( ) ;
153156 for field in schema. fields ( ) {
@@ -156,52 +159,75 @@ impl PyDataFrame {
156159 let header_str = header. join ( "" ) ;
157160 html_str. push_str ( & format ! ( "<tr>{}</tr></thead><tbody>\n " , header_str) ) ;
158161
159- let formatters = batch
160- . columns ( )
162+ let batch_formatters = batches
161163 . iter ( )
162- . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
163- . map ( |c| c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) ) )
164+ . map ( |batch| {
165+ batch
166+ . columns ( )
167+ . iter ( )
168+ . map ( |c| ArrayFormatter :: try_new ( c. as_ref ( ) , & FormatOptions :: default ( ) ) )
169+ . map ( |c| {
170+ c. map_err ( |e| PyValueError :: new_err ( format ! ( "Error: {:?}" , e. to_string( ) ) ) )
171+ } )
172+ . collect :: < Result < Vec < _ > , _ > > ( )
173+ } )
164174 . collect :: < Result < Vec < _ > , _ > > ( ) ?;
165175
166- let batch_size = batch. get_array_memory_size ( ) ;
167- let num_rows_to_display = match batch_size > MAX_TABLE_BYTES_TO_DISPLAY {
176+ let total_memory: usize = batches
177+ . iter ( )
178+ . map ( |batch| batch. get_array_memory_size ( ) )
179+ . sum ( ) ;
180+ let rows_per_batch = batches. iter ( ) . map ( |batch| batch. num_rows ( ) ) ;
181+ let total_rows = rows_per_batch. clone ( ) . sum ( ) ;
182+
183+ // let (total_memory, total_rows) = batches.iter().fold((0, 0), |acc, batch| {
184+ // (acc.0 + batch.get_array_memory_size(), acc.1 + batch.num_rows())
185+ // });
186+
187+ let num_rows_to_display = match total_memory > MAX_TABLE_BYTES_TO_DISPLAY {
168188 true => {
169- let num_batch_rows = batch. num_rows ( ) ;
170- let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / batch_size as f32 ;
171- let mut reduced_row_num = ( num_batch_rows as f32 * ratio) . round ( ) as usize ;
189+ let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / total_memory as f32 ;
190+ let mut reduced_row_num = ( total_rows as f32 * ratio) . round ( ) as usize ;
172191 if reduced_row_num < MIN_TABLE_ROWS_TO_DISPLAY {
173- reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( num_batch_rows ) ;
192+ reduced_row_num = MIN_TABLE_ROWS_TO_DISPLAY . min ( total_rows ) ;
174193 }
175194
176- has_more = has_more || reduced_row_num < num_batch_rows ;
195+ has_more = has_more || reduced_row_num < total_rows ;
177196 reduced_row_num
178197 }
179- false => batch . num_rows ( ) ,
198+ false => total_rows ,
180199 } ;
181200
182- for row in 0 ..num_rows_to_display {
183- let mut cells = Vec :: new ( ) ;
184- for ( col, formatter) in formatters. iter ( ) . enumerate ( ) {
185- let cell_data = formatter. value ( row) . to_string ( ) ;
186- // From testing, primitive data types do not typically get larger than 21 characters
187- if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
188- let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
189- cells. push ( format ! ( "
190- <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
191- <div class=\" expandable-container\" >
192- <span class=\" expandable\" id=\" {table_uuid}-min-text-{row}-{col}\" >{short_cell_data}</span>
193- <span class=\" full-text\" id=\" {table_uuid}-full-text-{row}-{col}\" >{cell_data}</span>
194- <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{row},{col})\" >...</button>
195- </div>
196- </td>" ) ) ;
197- } else {
198- cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( row) ) ) ;
201+ // We need to build up row by row for html
202+ let mut table_row = 0 ;
203+ for ( batch_formatter, num_rows_in_batch) in batch_formatters. iter ( ) . zip ( rows_per_batch) {
204+ for batch_row in 0 ..num_rows_in_batch {
205+ table_row += 1 ;
206+ if table_row > num_rows_to_display {
207+ break ;
208+ }
209+ let mut cells = Vec :: new ( ) ;
210+ for ( col, formatter) in batch_formatter. iter ( ) . enumerate ( ) {
211+ let cell_data = formatter. value ( batch_row) . to_string ( ) ;
212+ // From testing, primitive data types do not typically get larger than 21 characters
213+ if cell_data. len ( ) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
214+ let short_cell_data = & cell_data[ 0 ..MAX_LENGTH_CELL_WITHOUT_MINIMIZE ] ;
215+ cells. push ( format ! ( "
216+ <td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
217+ <div class=\" expandable-container\" >
218+ <span class=\" expandable\" id=\" {table_uuid}-min-text-{table_row}-{col}\" >{short_cell_data}</span>
219+ <span class=\" full-text\" id=\" {table_uuid}-full-text-{table_row}-{col}\" >{cell_data}</span>
220+ <button class=\" expand-btn\" onclick=\" toggleDataFrameCellText('{table_uuid}',{table_row},{col})\" >...</button>
221+ </div>
222+ </td>" ) ) ;
223+ } else {
224+ cells. push ( format ! ( "<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>" , formatter. value( batch_row) ) ) ;
225+ }
199226 }
227+ let row_str = cells. join ( "" ) ;
228+ html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
200229 }
201- let row_str = cells. join ( "" ) ;
202- html_str. push_str ( & format ! ( "<tr>{}</tr>\n " , row_str) ) ;
203230 }
204-
205231 html_str. push_str ( "</tbody></table></div>\n " ) ;
206232
207233 html_str. push_str ( "
@@ -824,24 +850,36 @@ fn record_batch_into_schema(
824850/// It additionally returns a bool, which indicates if there are more record batches available.
825851/// We do this so we can determine if we should indicate to the user that the data has been
826852/// truncated.
827- async fn get_first_record_batch (
853+ async fn get_first_few_record_batches (
828854 df : DataFrame ,
829- ) -> Result < ( Option < RecordBatch > , bool ) , DataFusionError > {
855+ ) -> Result < ( Option < Vec < RecordBatch > > , bool ) , DataFusionError > {
830856 let mut stream = df. execute_stream ( ) . await ?;
831- loop {
857+ let mut size_estimate_so_far = 0 ;
858+ let mut record_batches = Vec :: default ( ) ;
859+ while size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY {
832860 let rb = match stream. next ( ) . await {
833- None => return Ok ( ( None , false ) ) ,
861+ None => {
862+ break ;
863+ }
834864 Some ( Ok ( r) ) => r,
835865 Some ( Err ( e) ) => return Err ( e) ,
836866 } ;
837867
838868 if rb. num_rows ( ) > 0 {
839- let has_more = match stream. try_next ( ) . await {
840- Ok ( None ) => false , // reached end
841- Ok ( Some ( _) ) => true ,
842- Err ( _) => false , // Stream disconnected
843- } ;
844- return Ok ( ( Some ( rb) , has_more) ) ;
869+ size_estimate_so_far += rb. get_array_memory_size ( ) ;
870+ record_batches. push ( rb) ;
845871 }
846872 }
873+
874+ if record_batches. is_empty ( ) {
875+ return Ok ( ( None , false ) ) ;
876+ }
877+
878+ let has_more = match stream. try_next ( ) . await {
879+ Ok ( None ) => false , // reached end
880+ Ok ( Some ( _) ) => true ,
881+ Err ( _) => false , // Stream disconnected
882+ } ;
883+
884+ Ok ( ( Some ( record_batches) , has_more) )
847885}
0 commit comments