51
51
ABCMultiIndex ,
52
52
ABCPeriodIndex ,
53
53
ABCSeries ,
54
+ ABCTimedeltaIndex ,
54
55
)
55
56
from pandas .core .dtypes .missing import isna
56
57
99
100
Series ,
100
101
)
101
102
102
- import itertools
103
+ from itertools import islice
103
104
104
105
from matplotlib .collections import LineCollection
105
106
@@ -1538,6 +1539,7 @@ def _make_legend(self) -> None:
1538
1539
1539
1540
class LinePlot (MPLPlot ):
1540
1541
_default_rot = 0
1542
+ _wide_line_threshold : int = 200
1541
1543
1542
1544
@property
1543
1545
def orientation (self ) -> PlottingOrientation :
@@ -1547,97 +1549,101 @@ def orientation(self) -> PlottingOrientation:
1547
1549
def _kind (self ) -> Literal ["line" , "area" , "hist" , "kde" , "box" ]:
1548
1550
return "line"
1549
1551
1550
- def __init__ (self , data , ** kwargs ) -> None :
1551
- MPLPlot .__init__ (self , data , ** kwargs )
1552
+ def __init__ (self , data , ** kwargs ):
1553
+ super () .__init__ (data , ** kwargs )
1552
1554
if self .stacked :
1553
- self .data = self .data .fillna (value = 0 )
1555
+ self .data = self .data .fillna (0 )
1554
1556
1555
- def _make_plot (self , fig : Figure ) -> None :
1556
- threshold = 200 # switch when DataFrame has more than this many columns
1557
- can_use_lc = (
1558
- not self ._is_ts_plot () # not a TS plot
1559
- and not self .stacked # stacking not requested
1560
- and not com .any_not_none (* self .errors .values ()) # no error bars
1561
- and len (self .data .columns ) > threshold
1562
- )
1563
- if can_use_lc :
1564
- ax = self ._get_ax (0 )
1565
- x = self ._get_xticks ()
1566
- segments = [
1567
- np .column_stack ((x , self .data [col ].values )) for col in self .data .columns
1568
- ]
1569
- base_colors = mpl .rcParams ["axes.prop_cycle" ].by_key ()["color" ]
1570
- colors = list (itertools .islice (itertools .cycle (base_colors ), len (segments )))
1571
- lc = LineCollection (
1572
- segments ,
1573
- colors = colors ,
1574
- linewidths = self .kwds .get ("linewidth" , mpl .rcParams ["lines.linewidth" ]),
1575
- )
1576
- ax .add_collection (lc )
1577
- ax .margins (0.05 )
1578
- return # skip the per-column Line2D loop
1557
+ def _make_plot (self , fig ):
1558
+ is_ts = self ._is_ts_plot ()
1579
1559
1580
- if self ._is_ts_plot ():
1560
+ # establish X, iterator, plot function
1561
+ if is_ts :
1581
1562
data = maybe_convert_index (self ._get_ax (0 ), self .data )
1582
-
1583
- x = data .index # dummy, not used
1584
- plotf = self ._ts_plot
1585
- it = data .items ()
1563
+ x_vals , iterator , plotf = data .index , data .items (), self ._ts_plot
1586
1564
else :
1587
- x = self ._get_xticks ()
1588
- # error: Incompatible types in assignment (expression has type
1589
- # "Callable[[Any, Any, Any, Any, Any, Any, KwArg(Any)], Any]", variable has
1590
- # type "Callable[[Any, Any, Any, Any, KwArg(Any)], Any]")
1591
- plotf = self ._plot # type: ignore[assignment]
1592
- # error: Incompatible types in assignment (expression has type
1593
- # "Iterator[tuple[Hashable, ndarray[Any, Any]]]", variable has
1594
- # type "Iterable[tuple[Hashable, Series]]")
1595
- it = self ._iter_data (data = self .data ) # type: ignore[assignment]
1565
+ x_vals = self ._get_xticks ()
1566
+ iterator , plotf = self ._iter_data (self .data ), self ._plot
1596
1567
1568
+ n_cols = len (self .data .columns )
1597
1569
stacking_id = self ._get_stacking_id ()
1598
1570
is_errorbar = com .any_not_none (* self .errors .values ())
1571
+ default_lw = self .kwds .get ("linewidth" , mpl .rcParams ["lines.linewidth" ])
1572
+ colours_seq = self ._get_colors ()
1573
+ colours = list (islice (colours_seq , n_cols ))
1574
+
1575
+ # fast aggregate-LC path (numeric, wide, simple)
1576
+ idx = self .data .index
1577
+ numeric_idx = is_any_real_numeric_dtype (idx .dtype )
1578
+
1579
+ complex_idx = isinstance (
1580
+ idx , (ABCTimedeltaIndex , ABCPeriodIndex , ABCDatetimeIndex )
1581
+ ) or getattr (idx , "_is_all_dates" , False )
1582
+
1583
+ lc_fast_path = (
1584
+ n_cols > self ._wide_line_threshold
1585
+ and numeric_idx
1586
+ and not complex_idx
1587
+ and not is_ts
1588
+ and not self .stacked
1589
+ and not is_errorbar
1590
+ and self .style is None
1591
+ and "marker" not in self .kwds
1592
+ and "linestyle" not in self .kwds
1593
+ )
1599
1594
1600
- colors = self ._get_colors ()
1601
- for i , (label , y ) in enumerate (it ):
1602
- ax = self ._get_ax (i )
1595
+ if lc_fast_path :
1596
+ ax = self ._get_ax (0 )
1597
+
1598
+ segs = np .empty ((n_cols , len (x_vals ), 2 ), dtype = float )
1599
+ segs [:, :, 0 ] = np .asarray (x_vals , dtype = float )
1600
+ segs [:, :, 1 ] = self .data .values .T
1601
+
1602
+ lc = LineCollection (segs , colors = colours , linewidths = default_lw )
1603
+ ax .add_collection (lc )
1604
+ ax .relim ()
1605
+ ax .autoscale ()
1606
+
1607
+ if self .legend :
1608
+ for i , col in enumerate (self .data .columns ):
1609
+ h = mpl .lines .Line2D (
1610
+ [], [], color = colours [i ], label = pprint_thing (col )
1611
+ )
1612
+ self ._append_legend_handles_labels (h , h .get_label ())
1613
+
1614
+ self ._post_plot_logic (ax , self .data )
1615
+ return
1616
+
1617
+ # unified per-column loop (complex or narrow frames)
1618
+ for i , (label , y ) in enumerate (iterator ):
1619
+ ax : Axes = self ._get_ax (i )
1603
1620
kwds = self .kwds .copy ()
1604
1621
if self .color is not None :
1605
1622
kwds ["color" ] = self .color
1606
- style , kwds = self ._apply_style_colors (
1607
- colors ,
1608
- kwds ,
1609
- i ,
1610
- # error: Argument 4 to "_apply_style_colors" of "MPLPlot" has
1611
- # incompatible type "Hashable"; expected "str"
1612
- label , # type: ignore[arg-type]
1613
- )
1614
1623
1615
- errors = self ._get_errorbars (label = label , index = i )
1616
- kwds = dict (kwds , ** errors )
1617
-
1618
- label = pprint_thing (label )
1619
- label = self ._mark_right_label (label , index = i )
1620
- kwds ["label" ] = label
1624
+ style , kwds = self ._apply_style_colors (colours_seq , kwds , i , label )
1625
+ kwds .update (self ._get_errorbars (label , i ))
1626
+ kwds ["label" ] = self ._mark_right_label (pprint_thing (label ), index = i )
1621
1627
1622
1628
newlines = plotf (
1623
1629
ax ,
1624
- x ,
1630
+ x_vals ,
1625
1631
y ,
1626
1632
style = style ,
1627
1633
column_num = i ,
1628
1634
stacking_id = stacking_id ,
1629
1635
is_errorbar = is_errorbar ,
1630
1636
** kwds ,
1631
1637
)
1632
- self ._append_legend_handles_labels (newlines [0 ], label )
1638
+ self ._append_legend_handles_labels (newlines [0 ], kwds [ " label" ] )
1633
1639
1634
- if self ._is_ts_plot ():
1635
- # reset of xlim should be used for ts data
1636
- # TODO: GH28021, should find a way to change view limit on xaxis
1637
- lines = get_all_lines (ax )
1638
- left , right = get_xlim (lines )
1640
+ if is_ts : # keep TS rescale
1641
+ left , right = get_xlim (get_all_lines (ax ))
1639
1642
ax .set_xlim (left , right )
1640
1643
1644
+ # shared clean-up
1645
+ self ._post_plot_logic (self ._get_ax (0 ), self .data )
1646
+
1641
1647
# error: Signature of "_plot" incompatible with supertype "MPLPlot"
1642
1648
@classmethod
1643
1649
def _plot ( # type: ignore[override]
0 commit comments