diff --git a/explorer/charts.py b/explorer/charts.py index 6d3c9a3c..a10f82c4 100644 --- a/explorer/charts.py +++ b/explorer/charts.py @@ -5,7 +5,7 @@ BAR_WIDTH = 0.2 -def get_chart(result: QueryResult, chart_type: str) -> Optional[str]: +def get_chart(result: QueryResult, chart_type: str, num_rows: int) -> Optional[str]: import matplotlib.pyplot as plt """ Return a line or bar chart in SVG format if the result table adheres to the expected format. @@ -21,26 +21,27 @@ def get_chart(result: QueryResult, chart_type: str) -> Optional[str]: return if len(result.data) < 1: return None + data = result.data[:num_rows] numeric_columns = [ - c for c in range(1, len(result.data[0])) - if all([isinstance(col[c], (int, float)) or col[c] is None for col in result.data]) + c for c in range(1, len(data[0])) + if all([isinstance(col[c], (int, float)) or col[c] is None for col in data]) ] # Don't create charts for > 10 series. This is a lightweight visualization. if len(numeric_columns) < 1 or len(numeric_columns) > 10: return None - labels = [row[0] for row in result.data] + labels = [row[0] for row in data] fig, ax = plt.subplots(figsize=(10, 3.8)) bars = [] bar_positions = [] for idx, col_num in enumerate(numeric_columns): if chart_type == "bar": - values = [row[col_num] for row in result.data] + values = [row[col_num] for row in data] bar_container = ax.bar([x + idx * BAR_WIDTH for x in range(len(labels))], values, BAR_WIDTH, label=result.headers[col_num]) bars.append(bar_container) bar_positions.append([(rect.get_x(), rect.get_height()) for rect in bar_container]) if chart_type == "line": - ax.plot(labels, [row[col_num] for row in result.data], label=result.headers[col_num]) + ax.plot(labels, [row[col_num] for row in data], label=result.headers[col_num]) ax.set_xlabel(result.headers[0]) diff --git a/explorer/templates/explorer/schema.html b/explorer/templates/explorer/schema.html index 904fe841..c3f16f9f 100644 --- a/explorer/templates/explorer/schema.html +++ b/explorer/templates/explorer/schema.html @@ -6,7 +6,6 @@