Improve collection during repr and repr_html #1036


Merged · 8 commits · Mar 22, 2025
23 changes: 14 additions & 9 deletions python/tests/test_dataframe.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import os
import re
from typing import Any

import pyarrow as pa
@@ -1245,13 +1246,17 @@ def add_with_parameter(df_internal, value: Any) -> DataFrame:
def test_dataframe_repr_html(df) -> None:
output = df._repr_html_()
Comment on lines 1246 to 1247
@kosiew (Contributor), Mar 19, 2025

Would it be a good idea to test for other df fixtures too?

In addition, maybe add an empty_df fixture for tests too, e.g.:

@pytest.fixture
def empty_df():
    ctx = SessionContext()

    # Create an empty RecordBatch with the same schema as df
    batch = pa.RecordBatch.from_arrays(
        [
            pa.array([], type=pa.int64()),
            pa.array([], type=pa.int64()),
            pa.array([], type=pa.int64()),
        ],
        names=["a", "b", "c"],
    )

    return ctx.from_arrow(batch)


@pytest.mark.parametrize(
    "dataframe_fixture",
    ["empty_df", "df", "nested_df", "struct_df", "partitioned_df", "aggregate_df"],
)
def test_dataframe_repr_html(request, dataframe_fixture) -> None:
    df = request.getfixturevalue(dataframe_fixture)
    output = df._repr_html_()


ref_html = """<table border='1'>
<tr><th>a</td><th>b</td><th>c</td></tr>
<tr><td>1</td><td>4</td><td>8</td></tr>
<tr><td>2</td><td>5</td><td>5</td></tr>
<tr><td>3</td><td>6</td><td>8</td></tr>
</table>
"""
# Since we've added a fair bit of processing to the html output, lets just verify
# the values we are expecting in the table exist. Use regex and ignore everything
# between the <th></th> and <td></td>. We also don't want the closing > on the
# td and th segments because that is where the formatting data is written.

# Ignore whitespace just to make this test look cleaner
assert output.replace(" ", "") == ref_html.replace(" ", "")
headers = ["a", "b", "c"]
headers = [f"<th(.*?)>{v}</th>" for v in headers]
header_pattern = "(.*?)".join(headers)
assert len(re.findall(header_pattern, output, re.DOTALL)) == 1

body_data = [[1, 4, 8], [2, 5, 5], [3, 6, 8]]
body_lines = [f"<td(.*?)>{v}</td>" for inner in body_data for v in inner]
body_pattern = "(.*?)".join(body_lines)
assert len(re.findall(body_pattern, output, re.DOTALL)) == 1
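
The regex assertions above tolerate whatever styling the new renderer injects between the tags, while still pinning down the header and cell values in order. A minimal self-contained sketch of the technique (the HTML literal here is illustrative, not the project's actual output):

```python
import re

# Illustrative styled HTML of the kind _repr_html_ now emits: attributes
# sit inside the <th>/<td> tags, so exact string comparison no longer works.
html = (
    "<table><thead><tr>"
    "<th style='padding: 8px'>a</th><th style='padding: 8px'>b</th>"
    "</tr></thead><tbody>"
    "<tr><td style='padding: 8px'>1</td><td style='padding: 8px'>4</td></tr>"
    "<tr><td style='padding: 8px'>2</td><td style='padding: 8px'>5</td></tr>"
    "</tbody></table>"
)

# Join the expected values with lazy wildcards so everything between them
# (attributes, row markup) is ignored, exactly as the new test does.
headers = [f"<th(.*?)>{v}</th>" for v in ["a", "b"]]
header_pattern = "(.*?)".join(headers)
assert len(re.findall(header_pattern, html, re.DOTALL)) == 1

# Flatten the expected body values row by row into one lazy pattern;
# re.DOTALL lets the wildcards cross the </tr><tr> boundaries.
body = [f"<td(.*?)>{v}</td>" for row in [[1, 4], [2, 5]] for v in row]
body_pattern = "(.*?)".join(body)
assert len(re.findall(body_pattern, html, re.DOTALL)) == 1
```

Each pattern matching exactly once guarantees both the values and their relative order, without coupling the test to the styling.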
240 changes: 210 additions & 30 deletions src/dataframe.rs
@@ -31,9 +31,11 @@ use datafusion::common::UnnestOptions;
use datafusion::config::{CsvOptions, TableParquetOptions};
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
use datafusion::datasource::TableProvider;
use datafusion::error::DataFusionError;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
use datafusion::prelude::*;
use futures::{StreamExt, TryStreamExt};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::pybacked::PyBackedStr;
@@ -70,6 +72,9 @@ impl PyTableProvider {
PyTable::new(table_provider)
}
}
const MAX_TABLE_BYTES_TO_DISPLAY: usize = 2 * 1024 * 1024; // 2 MB
Review comment:

How about make this configurable? 2MB still can mean lots of rows as the upper bound is usize::MAX

Contributor Author:

How about I open an issue to enhance this to be configurable as well as the follow on part about disabling the styling? I'd like to get this in so we fix explain and add some useful functionality now and then we can get these things tightened up in the next iteration.

Contributor Author:

Added to issue #1078

const MIN_TABLE_ROWS_TO_DISPLAY: usize = 20;
const MAX_LENGTH_CELL_WITHOUT_MINIMIZE: usize = 25;

/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
@@ -111,56 +116,151 @@ impl PyDataFrame {
}

fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
let df = self.df.as_ref().clone().limit(0, Some(10))?;
let batches = wait_for_future(py, df.collect())?;
let batches_as_string = pretty::pretty_format_batches(&batches);
match batches_as_string {
Ok(batch) => Ok(format!("DataFrame()\n{batch}")),
Err(err) => Ok(format!("Error: {:?}", err.to_string())),
let (batches, has_more) = wait_for_future(
py,
collect_record_batches_to_display(self.df.as_ref().clone(), 10, 10),
)?;
if batches.is_empty() {
// This should not be reached, but do it for safety since we index into the vector below
return Ok("No data to display".to_string());
}
}

fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
let mut html_str = "<table border='1'>\n".to_string();
let batches_as_displ =
pretty::pretty_format_batches(&batches).map_err(py_datafusion_err)?;

let additional_str = match has_more {
true => "\nData truncated.",
false => "",
};

let df = self.df.as_ref().clone().limit(0, Some(10))?;
let batches = wait_for_future(py, df.collect())?;
Ok(format!("DataFrame()\n{batches_as_displ}{additional_str}"))
}

fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
let (batches, has_more) = wait_for_future(
py,
collect_record_batches_to_display(
self.df.as_ref().clone(),
MIN_TABLE_ROWS_TO_DISPLAY,
usize::MAX,
),
)?;
Comment on lines +139 to +147
Contributor:

Extracting some variables into helper functions could make this more readable and easier to maintain eg:

    fn _repr_html_(&self, py: Python) -> PyDataFusionResult<String> {
        let (batches, has_more) = wait_for_future(
            py,
            collect_record_batches_to_display(
                self.df.as_ref().clone(),
                MIN_TABLE_ROWS_TO_DISPLAY,
                usize::MAX,
            ),
        )?;
        if batches.is_empty() {
            // This should not be reached, but do it for safety since we index into the vector below
            return Ok("No data to display".to_string());
        }

        let table_uuid = uuid::Uuid::new_v4().to_string();
        let schema = batches[0].schema();
        
        // Get table formatters for displaying cell values
        let batch_formatters = get_batch_formatters(&batches)?;
        let rows_per_batch = batches.iter().map(|batch| batch.num_rows());

        // Generate HTML components
        let mut html_str = generate_html_table_header(&schema);
        html_str.push_str(&generate_table_rows(
            &batch_formatters, 
            rows_per_batch, 
            &table_uuid
        )?);
        html_str.push_str("</tbody></table></div>\n");
        html_str.push_str(&generate_javascript());

        if has_more {
            html_str.push_str("Data truncated due to size.");
        }

        Ok(html_str)
    }

Contributor Author:

Added to follow on issue #1078

if batches.is_empty() {
html_str.push_str("</table>\n");
return Ok(html_str);
// This should not be reached, but do it for safety since we index into the vector below
return Ok("No data to display".to_string());
}

let table_uuid = uuid::Uuid::new_v4().to_string();

let mut html_str = "
<style>
.expandable-container {
display: inline-block;
max-width: 200px;
}
.expandable {
white-space: nowrap;
overflow: hidden;
text-overflow: ellipsis;
display: block;
}
.full-text {
display: none;
white-space: normal;
}
.expand-btn {
cursor: pointer;
color: blue;
text-decoration: underline;
border: none;
background: none;
font-size: inherit;
display: block;
margin-top: 5px;
}
</style>

<div style=\"width: 100%; max-width: 1000px; max-height: 300px; overflow: auto; border: 1px solid #ccc;\">
Review comment:

I don't feel so positive to add hardcoded styles especially absolute width/margin. If Jupyter modify their style/layout or users apply customization based on Jupyter UI, there might be incompatibility. At least, there should be a switch to turn off it.

Contributor Author:

Added to issue #1078

<table style=\"border-collapse: collapse; min-width: 100%\">
<thead>\n".to_string();

let schema = batches[0].schema();

let mut header = Vec::new();
for field in schema.fields() {
header.push(format!("<th>{}</td>", field.name()));
header.push(format!("<th style='border: 1px solid black; padding: 8px; text-align: left; background-color: #f2f2f2; white-space: nowrap; min-width: fit-content; max-width: fit-content;'>{}</th>", field.name()));
}
let header_str = header.join("");
html_str.push_str(&format!("<tr>{}</tr>\n", header_str));

for batch in batches {
let formatters = batch
.columns()
.iter()
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
.map(|c| {
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
})
.collect::<Result<Vec<_>, _>>()?;

for row in 0..batch.num_rows() {
html_str.push_str(&format!("<tr>{}</tr></thead><tbody>\n", header_str));

let batch_formatters = batches
.iter()
.map(|batch| {
batch
.columns()
.iter()
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
.map(|c| {
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
})
.collect::<Result<Vec<_>, _>>()
})
.collect::<Result<Vec<_>, _>>()?;

let rows_per_batch = batches.iter().map(|batch| batch.num_rows());

// We need to build up row by row for html
let mut table_row = 0;
for (batch_formatter, num_rows_in_batch) in batch_formatters.iter().zip(rows_per_batch) {
for batch_row in 0..num_rows_in_batch {
table_row += 1;
let mut cells = Vec::new();
for formatter in &formatters {
cells.push(format!("<td>{}</td>", formatter.value(row)));
for (col, formatter) in batch_formatter.iter().enumerate() {
let cell_data = formatter.value(batch_row).to_string();
// From testing, primitive data types do not typically get larger than 21 characters
if cell_data.len() > MAX_LENGTH_CELL_WITHOUT_MINIMIZE {
let short_cell_data = &cell_data[0..MAX_LENGTH_CELL_WITHOUT_MINIMIZE];
cells.push(format!("
<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>
<div class=\"expandable-container\">
<span class=\"expandable\" id=\"{table_uuid}-min-text-{table_row}-{col}\">{short_cell_data}</span>
<span class=\"full-text\" id=\"{table_uuid}-full-text-{table_row}-{col}\">{cell_data}</span>
<button class=\"expand-btn\" onclick=\"toggleDataFrameCellText('{table_uuid}',{table_row},{col})\">...</button>
</div>
</td>"));
} else {
cells.push(format!("<td style='border: 1px solid black; padding: 8px; text-align: left; white-space: nowrap;'>{}</td>", formatter.value(batch_row)));
}
}
let row_str = cells.join("");
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
}
}
html_str.push_str("</tbody></table></div>\n");

html_str.push_str("
<script>
function toggleDataFrameCellText(table_uuid, row, col) {
var shortText = document.getElementById(table_uuid + \"-min-text-\" + row + \"-\" + col);
var fullText = document.getElementById(table_uuid + \"-full-text-\" + row + \"-\" + col);
var button = event.target;

if (fullText.style.display === \"none\") {
shortText.style.display = \"none\";
fullText.style.display = \"inline\";
button.textContent = \"(less)\";
} else {
shortText.style.display = \"inline\";
fullText.style.display = \"none\";
button.textContent = \"...\";
}
}
</script>
");

html_str.push_str("</table>\n");
if has_more {
html_str.push_str("Data truncated due to size.");
}

Ok(html_str)
}
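
The cell-minimize rule in the loop above (values longer than 25 characters are shown truncated behind an expand button) can be sketched in Python, stripped of the HTML wrapping; `format_cell` is an illustrative helper, not part of the PR:

```python
# Mirrors MAX_LENGTH_CELL_WITHOUT_MINIMIZE from src/dataframe.rs.
MAX_LENGTH_CELL_WITHOUT_MINIMIZE = 25

def format_cell(cell: str) -> str:
    """Long values get a truncated preview (the "..." stands in for the
    expand button the real HTML renders); short values pass through."""
    if len(cell) > MAX_LENGTH_CELL_WITHOUT_MINIMIZE:
        return cell[:MAX_LENGTH_CELL_WITHOUT_MINIMIZE] + "..."
    return cell

# Primitive values rarely exceed ~21 characters, so they are untouched.
assert format_cell("12345.6789") == "12345.6789"
# A long string is cut to the first 25 characters plus the toggle marker.
assert format_cell("x" * 30) == "x" * 25 + "..."
```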
@@ -771,3 +871,83 @@ fn record_batch_into_schema(

RecordBatch::try_new(schema, data_arrays)
}

/// This is a helper function to return the first non-empty record batches from executing a DataFrame.
/// It additionally returns a bool, which indicates if there are more record batches available.
/// We do this so we can determine if we should indicate to the user that the data has been
/// truncated. This collects until we have achieved both of these two conditions:
///
/// - We have collected our minimum number of rows
/// - We have reached our limit, either data size or maximum number of rows
///
/// Otherwise it will return when the stream is exhausted. If you want a specific number of
/// rows, set min_rows == max_rows.
async fn collect_record_batches_to_display(
df: DataFrame,
min_rows: usize,
max_rows: usize,
) -> Result<(Vec<RecordBatch>, bool), DataFusionError> {
let partitioned_stream = df.execute_stream_partitioned().await?;
let mut stream = futures::stream::iter(partitioned_stream).flatten();
let mut size_estimate_so_far = 0;
let mut rows_so_far = 0;
let mut record_batches = Vec::default();
let mut has_more = false;

while (size_estimate_so_far < MAX_TABLE_BYTES_TO_DISPLAY && rows_so_far < max_rows)
|| rows_so_far < min_rows
{
let mut rb = match stream.next().await {
None => {
break;
}
Some(Ok(r)) => r,
Some(Err(e)) => return Err(e),
};

let mut rows_in_rb = rb.num_rows();
if rows_in_rb > 0 {
size_estimate_so_far += rb.get_array_memory_size();

if size_estimate_so_far > MAX_TABLE_BYTES_TO_DISPLAY {
let ratio = MAX_TABLE_BYTES_TO_DISPLAY as f32 / size_estimate_so_far as f32;
let total_rows = rows_in_rb + rows_so_far;

let mut reduced_row_num = (total_rows as f32 * ratio).round() as usize;
Review comment:

This estimation is not so accurate if some rows skew in size.

Contributor Author:

Yes. And the data size is an estimate as well. The point is to get a general ballpark and not necessarily an exact measure. It should still indicate that the data has been truncated.

if reduced_row_num < min_rows {
reduced_row_num = min_rows.min(total_rows);
}

let limited_rows_this_rb = reduced_row_num - rows_so_far;
if limited_rows_this_rb < rows_in_rb {
rows_in_rb = limited_rows_this_rb;
rb = rb.slice(0, limited_rows_this_rb);
has_more = true;
}
}

if rows_in_rb + rows_so_far > max_rows {
rb = rb.slice(0, max_rows - rows_so_far);
has_more = true;
}

rows_so_far += rb.num_rows();
record_batches.push(rb);
}
}

if record_batches.is_empty() {
return Ok((Vec::default(), false));
}

if !has_more {
// Data was not already truncated, so check to see if more record batches remain
has_more = match stream.try_next().await {
Ok(None) => false, // reached end
Ok(Some(_)) => true,
Err(_) => false, // Stream disconnected
};
}

Ok((record_batches, has_more))
}
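
The ratio-based truncation inside the function above, which the review thread calls a ballpark estimate, reduces the displayed row count in proportion to how far the estimated byte size overshoots the budget, with a floor at the minimum row count. A hedged Python sketch of just that arithmetic (`reduced_rows` is an illustrative helper; the constants mirror the Rust ones):

```python
# Mirror the constants from src/dataframe.rs.
MAX_TABLE_BYTES_TO_DISPLAY = 2 * 1024 * 1024  # 2 MB
MIN_TABLE_ROWS_TO_DISPLAY = 20

def reduced_rows(total_rows: int, size_estimate: int,
                 min_rows: int = MIN_TABLE_ROWS_TO_DISPLAY) -> int:
    """Scale the row count by the ratio of the byte budget to the estimated
    size, but never drop below min_rows (or below the rows available)."""
    ratio = MAX_TABLE_BYTES_TO_DISPLAY / size_estimate
    reduced = round(total_rows * ratio)
    if reduced < min_rows:
        reduced = min(min_rows, total_rows)
    return reduced

# 1000 rows estimated at 8 MB: the 2 MB budget keeps roughly a quarter.
assert reduced_rows(1000, 8 * 1024 * 1024) == 250
# A huge estimate would scale below the floor; 20 rows are still shown.
assert reduced_rows(1000, 2_000 * 1024 * 1024) == 20
```

As the author notes, rows that skew in size make this imprecise, but any truncation still sets `has_more` so the "Data truncated" notice is displayed.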
2 changes: 1 addition & 1 deletion src/utils.rs
@@ -42,7 +42,7 @@ pub(crate) fn get_tokio_runtime() -> &'static TokioRuntime {
#[inline]
pub(crate) fn get_global_ctx() -> &'static SessionContext {
static CTX: OnceLock<SessionContext> = OnceLock::new();
CTX.get_or_init(|| SessionContext::new())
CTX.get_or_init(SessionContext::new)
}

/// Utility to collect rust futures with GIL released