Review note: line structure restored; export_figure now creates missing parent directories (with a regression test).
Blob ids on the "index" lines of tests/test_viz.py and utils/viz.py predate these edits; regenerate them with `git diff` if needed.

diff --git a/TODO.md b/TODO.md
index 0834fb0..e42d239 100644
--- a/TODO.md
+++ b/TODO.md
@@ -45,16 +45,16 @@ After completing a milestone, create a pull request with your changes for review
 
 ## PR4: Data Visualization Module
 
-- [ ] Set up visualization framework
-- [ ] Implement histogram/density plots
-- [ ] Create scatter plot functionality
-- [ ] Add bar chart and pie chart generators
-- [ ] Implement box plots and violin plots
-- [ ] Create heatmap functionality
-- [ ] Add visualization customization options
-- [ ] Implement visualization export capability
-- [ ] Write tests for all visualization functions
-- [ ] Test visualization rendering with different data inputs
+- [x] Set up visualization framework
+- [x] Implement histogram/density plots
+- [x] Create scatter plot functionality
+- [x] Add bar chart and pie chart generators
+- [x] Implement box plots and violin plots
+- [x] Create heatmap functionality
+- [x] Add visualization customization options
+- [x] Implement visualization export capability
+- [x] Write tests for all visualization functions
+- [x] Test visualization rendering with different data inputs
 
 ## PR5: Model Training - Classification
 
diff --git a/tests/test_viz.py b/tests/test_viz.py
new file mode 100644
index 0000000..266b7dd
--- /dev/null
+++ b/tests/test_viz.py
@@ -0,0 +1,61 @@
+import pandas as pd
+from utils import viz
+
+
+def sample_df():
+    return pd.DataFrame({
+        'num1': [1, 2, 3, 4, 5],
+        'num2': [5, 4, 3, 2, 1],
+        'cat': ['a', 'b', 'a', 'b', 'a'],
+    })
+
+
+def test_histogram_and_density():
+    df = sample_df()
+    fig = viz.histogram(df, 'num1', bins=2, title='Hist')
+    assert fig.layout.title.text == 'Hist'
+    fig = viz.histogram(df, 'num1', density=True)
+    assert fig.data[0].histnorm == 'probability density'
+
+
+def test_scatter_plot():
+    df = sample_df()
+    fig = viz.scatter_plot(df, 'num1', 'num2', color='cat', title='Scatter')
+    assert fig.layout.title.text == 'Scatter'
+    assert fig.data[0].marker.color is not None
+
+
+def test_bar_and_pie_charts():
+    df = sample_df()
+    bar = viz.bar_chart(df, 'cat', 'num1')
+    pie = viz.pie_chart(df, names='cat', values='num1')
+    assert bar.data and pie.data
+
+
+def test_box_and_violin():
+    df = sample_df()
+    box = viz.box_plot(df, x='cat', y='num1')
+    violin = viz.violin_plot(df, x='cat', y='num1')
+    assert box.data and violin.data
+
+
+def test_heatmap():
+    df = sample_df()
+    fig = viz.heatmap(df, title='Heat')
+    assert fig.layout.title.text == 'Heat'
+
+
+def test_export_figure(tmp_path):
+    df = sample_df()
+    fig = viz.bar_chart(df, 'cat', 'num1')
+    out = tmp_path / 'chart.html'
+    viz.export_figure(fig, out)
+    assert out.exists() and out.stat().st_size > 0
+
+
+def test_export_figure_creates_parent_dirs(tmp_path):
+    df = sample_df()
+    fig = viz.bar_chart(df, 'cat', 'num1')
+    out = tmp_path / 'nested' / 'dir' / 'chart.html'
+    viz.export_figure(fig, str(out))
+    assert out.exists() and out.stat().st_size > 0
diff --git a/utils/__init__.py b/utils/__init__.py
index 9ac963c..aa655b5 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -3,5 +3,6 @@
 from . import config
 from . import data
 from . import eda
+from . import viz
 
-__all__ = ["config", "data", "eda"]
+__all__ = ["config", "data", "eda", "viz"]
diff --git a/utils/viz.py b/utils/viz.py
new file mode 100644
index 0000000..1087f4c
--- /dev/null
+++ b/utils/viz.py
@@ -0,0 +1,120 @@
+"""Data visualization utilities."""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Optional, Union
+
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
+
+
+def histogram(
+    df: pd.DataFrame,
+    column: str,
+    *,
+    bins: int = 20,
+    density: bool = False,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a histogram of ``column``.
+
+    ``bins`` is forwarded to Plotly as ``nbins``. With ``density=True`` the
+    trace uses ``histnorm="probability density"`` instead of raw counts.
+    """
+    histnorm = "probability density" if density else None
+    fig = px.histogram(df, x=column, nbins=bins, histnorm=histnorm, title=title)
+    return fig
+
+
+def scatter_plot(
+    df: pd.DataFrame,
+    x: str,
+    y: str,
+    *,
+    color: Optional[str] = None,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a scatter plot of ``y`` against ``x``.
+
+    ``color`` is forwarded to Plotly Express to colour the points.
+    """
+    fig = px.scatter(df, x=x, y=y, color=color, title=title)
+    return fig
+
+
+def bar_chart(
+    df: pd.DataFrame,
+    x: str,
+    y: str,
+    *,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a bar chart of ``y`` for each ``x``."""
+    fig = px.bar(df, x=x, y=y, title=title)
+    return fig
+
+
+def pie_chart(
+    df: pd.DataFrame,
+    names: str,
+    values: str,
+    *,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a pie chart labelled by ``names`` and sized by ``values``."""
+    fig = px.pie(df, names=names, values=values, title=title)
+    return fig
+
+
+def box_plot(
+    df: pd.DataFrame,
+    x: str,
+    y: str,
+    *,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a box plot of ``y`` grouped by ``x``."""
+    fig = px.box(df, x=x, y=y, title=title)
+    return fig
+
+
+def violin_plot(
+    df: pd.DataFrame,
+    x: str,
+    y: str,
+    *,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a violin plot of ``y`` grouped by ``x`` with an inner box plot."""
+    fig = px.violin(df, x=x, y=y, box=True, title=title)
+    return fig
+
+
+def heatmap(
+    df: pd.DataFrame,
+    *,
+    columns: Optional[list[str]] = None,
+    title: Optional[str] = None,
+) -> go.Figure:
+    """Return a correlation heatmap for the given columns.
+
+    When ``columns`` is omitted (or empty) every numeric column of ``df`` is
+    used.
+    """
+    cols = columns or df.select_dtypes(include="number").columns.tolist()
+    corr = df[cols].corr()
+    fig = px.imshow(corr, text_auto=True, title=title)
+    return fig
+
+
+def export_figure(fig: go.Figure, path: Union[str, Path]) -> None:
+    """Write ``fig`` to ``path`` as an HTML file.
+
+    Missing parent directories are created first so that exporting into a
+    fresh output folder does not fail with ``FileNotFoundError``.
+    """
+    target = Path(path)
+    target.parent.mkdir(parents=True, exist_ok=True)
+    fig.write_html(str(target))