Skip to content

Commit f4f499e

Browse files
author
evgenii
committed
ENH: speed up very-wide DataFrame line plots via single-pass LineCollection
1 parent 0febdd9 commit f4f499e

File tree

1 file changed

+72
-66
lines changed

1 file changed

+72
-66
lines changed

pandas/plotting/_matplotlib/core.py

Lines changed: 72 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ABCMultiIndex,
5252
ABCPeriodIndex,
5353
ABCSeries,
54+
ABCTimedeltaIndex,
5455
)
5556
from pandas.core.dtypes.missing import isna
5657

@@ -99,7 +100,7 @@
99100
Series,
100101
)
101102

102-
import itertools
103+
from itertools import islice
103104

104105
from matplotlib.collections import LineCollection
105106

@@ -1538,6 +1539,7 @@ def _make_legend(self) -> None:
15381539

15391540
class LinePlot(MPLPlot):
15401541
_default_rot = 0
1542+
_wide_line_threshold: int = 200
15411543

15421544
@property
15431545
def orientation(self) -> PlottingOrientation:
@@ -1547,97 +1549,101 @@ def orientation(self) -> PlottingOrientation:
15471549
def _kind(self) -> Literal["line", "area", "hist", "kde", "box"]:
15481550
return "line"
15491551

1550-
def __init__(self, data, **kwargs) -> None:
1551-
MPLPlot.__init__(self, data, **kwargs)
1552+
def __init__(self, data, **kwargs):
1553+
super().__init__(data, **kwargs)
15521554
if self.stacked:
1553-
self.data = self.data.fillna(value=0)
1555+
self.data = self.data.fillna(0)
15541556

1555-
def _make_plot(self, fig: Figure) -> None:
1556-
threshold = 200 # switch when DataFrame has more than this many columns
1557-
can_use_lc = (
1558-
not self._is_ts_plot() # not a TS plot
1559-
and not self.stacked # stacking not requested
1560-
and not com.any_not_none(*self.errors.values()) # no error bars
1561-
and len(self.data.columns) > threshold
1562-
)
1563-
if can_use_lc:
1564-
ax = self._get_ax(0)
1565-
x = self._get_xticks()
1566-
segments = [
1567-
np.column_stack((x, self.data[col].values)) for col in self.data.columns
1568-
]
1569-
base_colors = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
1570-
colors = list(itertools.islice(itertools.cycle(base_colors), len(segments)))
1571-
lc = LineCollection(
1572-
segments,
1573-
colors=colors,
1574-
linewidths=self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"]),
1575-
)
1576-
ax.add_collection(lc)
1577-
ax.margins(0.05)
1578-
return # skip the per-column Line2D loop
1557+
def _make_plot(self, fig):
1558+
is_ts = self._is_ts_plot()
15791559

1580-
if self._is_ts_plot():
1560+
# establish X, iterator, plot function
1561+
if is_ts:
15811562
data = maybe_convert_index(self._get_ax(0), self.data)
1582-
1583-
x = data.index # dummy, not used
1584-
plotf = self._ts_plot
1585-
it = data.items()
1563+
x_vals, iterator, plotf = data.index, data.items(), self._ts_plot
15861564
else:
1587-
x = self._get_xticks()
1588-
# error: Incompatible types in assignment (expression has type
1589-
# "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
1590-
# type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
1591-
plotf = self._plot # type: ignore[assignment]
1592-
# error: Incompatible types in assignment (expression has type
1593-
# "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has
1594-
# type "Iterable[tuple[Hashable, Series]]")
1595-
it = self._iter_data(data=self.data) # type: ignore[assignment]
1565+
x_vals = self._get_xticks()
1566+
iterator, plotf = self._iter_data(self.data), self._plot
15961567

1568+
n_cols = len(self.data.columns)
15971569
stacking_id = self._get_stacking_id()
15981570
is_errorbar = com.any_not_none(*self.errors.values())
1571+
default_lw = self.kwds.get("linewidth", mpl.rcParams["lines.linewidth"])
1572+
colours_seq = self._get_colors()
1573+
colours = list(islice(colours_seq, n_cols))
1574+
1575+
# fast aggregate-LC path (numeric, wide, simple)
1576+
idx = self.data.index
1577+
numeric_idx = is_any_real_numeric_dtype(idx.dtype)
1578+
1579+
complex_idx = isinstance(
1580+
idx, (ABCTimedeltaIndex, ABCPeriodIndex, ABCDatetimeIndex)
1581+
) or getattr(idx, "_is_all_dates", False)
1582+
1583+
lc_fast_path = (
1584+
n_cols > self._wide_line_threshold
1585+
and numeric_idx
1586+
and not complex_idx
1587+
and not is_ts
1588+
and not self.stacked
1589+
and not is_errorbar
1590+
and self.style is None
1591+
and "marker" not in self.kwds
1592+
and "linestyle" not in self.kwds
1593+
)
15991594

1600-
colors = self._get_colors()
1601-
for i, (label, y) in enumerate(it):
1602-
ax = self._get_ax(i)
1595+
if lc_fast_path:
1596+
ax = self._get_ax(0)
1597+
1598+
segs = np.empty((n_cols, len(x_vals), 2), dtype=float)
1599+
segs[:, :, 0] = np.asarray(x_vals, dtype=float)
1600+
segs[:, :, 1] = self.data.values.T
1601+
1602+
lc = LineCollection(segs, colors=colours, linewidths=default_lw)
1603+
ax.add_collection(lc)
1604+
ax.relim()
1605+
ax.autoscale()
1606+
1607+
if self.legend:
1608+
for i, col in enumerate(self.data.columns):
1609+
h = mpl.lines.Line2D(
1610+
[], [], color=colours[i], label=pprint_thing(col)
1611+
)
1612+
self._append_legend_handles_labels(h, h.get_label())
1613+
1614+
self._post_plot_logic(ax, self.data)
1615+
return
1616+
1617+
# unified per-column loop (complex or narrow frames)
1618+
for i, (label, y) in enumerate(iterator):
1619+
ax: Axes = self._get_ax(i)
16031620
kwds = self.kwds.copy()
16041621
if self.color is not None:
16051622
kwds["color"] = self.color
1606-
style, kwds = self._apply_style_colors(
1607-
colors,
1608-
kwds,
1609-
i,
1610-
# error: Argument 4 to "_apply_style_colors" of "MPLPlot" has
1611-
# incompatible type "Hashable"; expected "str"
1612-
label, # type: ignore[arg-type]
1613-
)
16141623

1615-
errors = self._get_errorbars(label=label, index=i)
1616-
kwds = dict(kwds, **errors)
1617-
1618-
label = pprint_thing(label)
1619-
label = self._mark_right_label(label, index=i)
1620-
kwds["label"] = label
1624+
style, kwds = self._apply_style_colors(colours_seq, kwds, i, label)
1625+
kwds.update(self._get_errorbars(label, i))
1626+
kwds["label"] = self._mark_right_label(pprint_thing(label), index=i)
16211627

16221628
newlines = plotf(
16231629
ax,
1624-
x,
1630+
x_vals,
16251631
y,
16261632
style=style,
16271633
column_num=i,
16281634
stacking_id=stacking_id,
16291635
is_errorbar=is_errorbar,
16301636
**kwds,
16311637
)
1632-
self._append_legend_handles_labels(newlines[0], label)
1638+
self._append_legend_handles_labels(newlines[0], kwds["label"])
16331639

1634-
if self._is_ts_plot():
1635-
# reset of xlim should be used for ts data
1636-
# TODO: GH28021, should find a way to change view limit on xaxis
1637-
lines = get_all_lines(ax)
1638-
left, right = get_xlim(lines)
1640+
if is_ts: # keep TS rescale
1641+
left, right = get_xlim(get_all_lines(ax))
16391642
ax.set_xlim(left, right)
16401643

1644+
# shared clean-up
1645+
self._post_plot_logic(self._get_ax(0), self.data)
1646+
16411647
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
16421648
@classmethod
16431649
def _plot( # type: ignore[override]

0 commit comments

Comments
 (0)