Skip to content

Commit 909b809

Browse files
timsaucerandygrove
andauthored
Set of small features (#839)
* Add repr_html to give nice displays in notebooks when using display(df) * Allow get_item to get index of an array or a field in a struct * add test for getting array elements * Small typo in array * Add DataFrame transform * Update index in unit test * Add dataframe transform unit test * Add unit test for repr_html * Updating documentation * fix typo --------- Co-authored-by: Andy Grove <[email protected]>
1 parent e8ebc4f commit 909b809

File tree

7 files changed

+181
-4
lines changed

7 files changed

+181
-4
lines changed

docs/source/user-guide/common-operations/expressions.rst

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,43 @@ examples for the and, or, and not operations.
6060
heavy_red_units = (col("color") == lit("red")) & (col("weight") > lit(42))
6161
not_red_units = ~(col("color") == lit("red"))
6262
63+
Arrays
64+
------
65+
66+
For columns that contain arrays of values, you can access individual elements of the array by index
67+
using bracket indexing. This is similar to calling the function
68+
:py:func:`datafusion.functions.array_element`, except that array indexing using brackets is 0 based,
69+
similar to Python lists, whereas ``array_element`` uses 1 based indexing to be compatible with other SQL
70+
approaches.
71+
72+
.. ipython:: python
73+
74+
from datafusion import SessionContext, col
75+
76+
ctx = SessionContext()
77+
df = ctx.from_pydict({"a": [[1, 2, 3], [4, 5, 6]]})
78+
df.select(col("a")[0].alias("a0"))
79+
80+
81+
.. warning::
82+
83+
Indexing an element of an array via ``[]`` starts at index 0 whereas
84+
:py:func:`~datafusion.functions.array_element` starts at index 1.
85+
86+
Structs
87+
-------
88+
89+
Columns that contain struct elements can be accessed using the bracket notation as if they were
90+
Python dictionary style objects. This expects a string key as the parameter passed.
91+
92+
.. ipython:: python
93+
94+
ctx = SessionContext()
95+
data = {"a": [{"size": 15, "color": "green"}, {"size": 10, "color": "blue"}]}
96+
df = ctx.from_pydict(data)
97+
df.select(col("a")["size"].alias("a_size"))
98+
99+
63100
Functions
64101
---------
65102

python/datafusion/dataframe.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import pandas as pd
3131
import polars as pl
3232
import pathlib
33+
from typing import Callable
3334

3435
from datafusion._internal import DataFrame as DataFrameInternal
3536
from datafusion.expr import Expr
@@ -72,6 +73,9 @@ def __repr__(self) -> str:
7273
"""
7374
return self.df.__repr__()
7475

76+
def _repr_html_(self) -> str:
    """Return the HTML rendering produced by the wrapped internal DataFrame.

    This is the hook used to give nicer displays in notebooks when calling
    ``display(df)``.
    """
    html = self.df._repr_html_()
    return html
78+
7579
def describe(self) -> DataFrame:
7680
"""Return the statistics for this DataFrame.
7781
@@ -539,3 +543,25 @@ def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any:
539543
Arrow PyCapsule object.
540544
"""
541545
return self.df.__arrow_c_stream__(requested_schema)
546+
547+
def transform(self, func: Callable[..., DataFrame], *args: Any) -> DataFrame:
    """Apply a function to the current DataFrame which returns another DataFrame.

    This is useful for chaining together multiple functions. For example::

        def add_3(df: DataFrame) -> DataFrame:
            return df.with_column("modified", lit(3))

        def within_limit(df: DataFrame, limit: int) -> DataFrame:
            return df.filter(col("a") < lit(limit)).distinct()

        df = df.transform(add_3).transform(within_limit, 4)

    Args:
        func: A callable function that takes a DataFrame as its first argument
        args: Zero or more arguments to pass to `func`

    Returns:
        DataFrame: After applying func to the original dataframe.
    """
    # ``self`` is always passed first; any extra positional arguments follow it.
    return func(self, *args)

python/datafusion/expr.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@
2222

2323
from __future__ import annotations
2424

25-
from ._internal import expr as expr_internal, LogicalPlan
25+
from ._internal import (
26+
expr as expr_internal,
27+
LogicalPlan,
28+
functions as functions_internal,
29+
)
2630
from datafusion.common import NullTreatment, RexType, DataTypeMap
2731
from typing import Any, Optional
2832
import pyarrow as pa
@@ -257,8 +261,17 @@ def __invert__(self) -> Expr:
257261
"""Binary not (~)."""
258262
return Expr(self.expr.__invert__())
259263

260-
def __getitem__(self, key: str) -> Expr:
261-
"""For struct data types, return the field indicated by ``key``."""
264+
def __getitem__(self, key: str | int) -> Expr:
    """Retrieve sub-object.

    If ``key`` is a string, returns the subfield of the struct.
    If ``key`` is an integer, retrieves the element in the array. Note that the
    element index begins at ``0``, unlike `array_element` which begins at ``1``.
    Negative integers count back from the end of the array, as with Python
    lists.
    """
    if isinstance(key, int):
        # ``array_element`` is 1 based for positions counted from the front,
        # but negative positions already count from the end (-1 is the last
        # element). Shifting a negative key by one would turn ``-1`` into
        # ``0``, which matches no element, so only shift non-negative keys.
        index = key + 1 if key >= 0 else key
        return Expr(
            functions_internal.array_element(self.expr, Expr.literal(index).expr)
        )
    return Expr(self.expr.__getitem__(key))
263276

264277
def __eq__(self, rhs: Any) -> Expr:

python/datafusion/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1023,7 +1023,7 @@ def array(*args: Expr) -> Expr:
10231023
10241024
This is an alias for :py:func:`make_array`.
10251025
"""
1026-
return make_array(args)
1026+
return make_array(*args)
10271027

10281028

10291029
def range(start: Expr, stop: Expr, step: Expr) -> Expr:

python/datafusion/tests/test_dataframe.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import os
18+
from typing import Any
1819

1920
import pyarrow as pa
2021
from pyarrow.csv import write_csv
@@ -970,3 +971,34 @@ def test_dataframe_export(df) -> None:
970971
except Exception:
971972
failed_convert = True
972973
assert failed_convert
974+
975+
976+
def test_dataframe_transform(df):
    """Chain two ``transform`` calls, one without and one with an extra argument."""

    def with_string_col(df_internal) -> DataFrame:
        return df_internal.with_column("string_col", literal("string data"))

    def with_new_col(df_internal, value: Any) -> DataFrame:
        return df_internal.with_column("new_col", literal(value))

    transformed = df.transform(with_string_col)
    transformed = transformed.transform(with_new_col, 3)

    columns = transformed.to_pydict()

    assert columns["a"] == [1, 2, 3]
    assert columns["string_col"] == ["string data"] * 3
    assert columns["new_col"] == [3] * 3
990+
991+
992+
def test_dataframe_repr_html(df) -> None:
    """Check that the notebook HTML rendering is a well-formed table.

    Header cells must be opened and closed with matching ``<th>``/``</th>``
    tags; data cells use ``<td>``/``</td>``.
    """
    output = df._repr_html_()

    ref_html = """<table border='1'>
        <tr><th>a</th><th>b</th><th>c</th></tr>
        <tr><td>1</td><td>4</td><td>8</td></tr>
        <tr><td>2</td><td>5</td><td>5</td></tr>
        <tr><td>3</td><td>6</td><td>8</td></tr>
        </table>
        """

    # Ignore whitespace just to make this test look cleaner
    assert output.replace(" ", "") == ref_html.replace(" ", "")

python/datafusion/tests/test_expr.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,26 @@ def traverse_logical_plan(plan):
169169
== '[Expr(Utf8("dfa")), Expr(Utf8("ad")), Expr(Utf8("dfre")), Expr(Utf8("vsa"))]'
170170
)
171171
assert not variant.negated()
172+
173+
174+
def test_expr_getitem() -> None:
    """Bracket access selects struct fields by name and array elements by 0 based index."""
    arrays = [[1, 2, 3], [4, 5], [6], []]
    people = [
        {"name": "Alice", "age": 15},
        {"name": "Bob", "age": 14},
        {"name": "Charlie", "age": 13},
        {"name": None, "age": 12},
    ]

    ctx = SessionContext()
    df = ctx.from_pydict(
        {"array_values": arrays, "struct_values": people},
        name="table1",
    )

    name_batches = df.select(col("struct_values")["name"].alias("name")).collect()
    names = [value.as_py() for batch in name_batches for value in batch["name"]]

    value_batches = df.select(col("array_values")[1].alias("value")).collect()
    array_values = [
        value.as_py() for batch in value_batches for value in batch["value"]
    ]

    assert names == ["Alice", "Bob", "Charlie", None]
    # Rows with fewer than two elements have nothing at index 1 and yield null.
    assert array_values == [2, 5, None, None]

src/dataframe.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use arrow::compute::can_cast_types;
2323
use arrow::error::ArrowError;
2424
use arrow::ffi::FFI_ArrowSchema;
2525
use arrow::ffi_stream::FFI_ArrowArrayStream;
26+
use arrow::util::display::{ArrayFormatter, FormatOptions};
2627
use datafusion::arrow::datatypes::Schema;
2728
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
2829
use datafusion::arrow::util::pretty;
@@ -95,6 +96,51 @@ impl PyDataFrame {
9596
}
9697
}
9798

99+
fn _repr_html_(&self, py: Python) -> PyResult<String> {
100+
let mut html_str = "<table border='1'>\n".to_string();
101+
102+
let df = self.df.as_ref().clone().limit(0, Some(10))?;
103+
let batches = wait_for_future(py, df.collect())?;
104+
105+
if batches.is_empty() {
106+
html_str.push_str("</table>\n");
107+
return Ok(html_str);
108+
}
109+
110+
let schema = batches[0].schema();
111+
112+
let mut header = Vec::new();
113+
for field in schema.fields() {
114+
header.push(format!("<th>{}</td>", field.name()));
115+
}
116+
let header_str = header.join("");
117+
html_str.push_str(&format!("<tr>{}</tr>\n", header_str));
118+
119+
for batch in batches {
120+
let formatters = batch
121+
.columns()
122+
.iter()
123+
.map(|c| ArrayFormatter::try_new(c.as_ref(), &FormatOptions::default()))
124+
.map(|c| {
125+
c.map_err(|e| PyValueError::new_err(format!("Error: {:?}", e.to_string())))
126+
})
127+
.collect::<Result<Vec<_>, _>>()?;
128+
129+
for row in 0..batch.num_rows() {
130+
let mut cells = Vec::new();
131+
for formatter in &formatters {
132+
cells.push(format!("<td>{}</td>", formatter.value(row)));
133+
}
134+
let row_str = cells.join("");
135+
html_str.push_str(&format!("<tr>{}</tr>\n", row_str));
136+
}
137+
}
138+
139+
html_str.push_str("</table>\n");
140+
141+
Ok(html_str)
142+
}
143+
98144
/// Calculate summary statistics for a DataFrame
99145
fn describe(&self, py: Python) -> PyResult<Self> {
100146
let df = self.df.as_ref().clone();

0 commit comments

Comments
 (0)