Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-47629: Generate metric aggregation plots in analysis_tools #313

Merged
merged 3 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions python/lsst/analysis/tools/actions/plot/histPlot.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ def makePlot(

# set up figure
fig = make_figure(dpi=300)
hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[3, 1])
hist_fig, side_fig = fig.subfigures(1, 2, wspace=0, width_ratios=[2.9, 1.1])
axs, ncols, nrows = self._makeAxes(hist_fig)

# loop over each panel; plot histograms
Expand All @@ -289,7 +289,7 @@ def makePlot(
if nth_panel == 0 and nrows * ncols - len(self.panels) > 0:
nth_col -= 1
# Set font size for legend based on number of panels being plotted.
legend_font_size = max(4, int(8 - len(self.panels[panel].hists) / 2 - nrows // 2)) # type: ignore
legend_font_size = max(4, int(7 - len(self.panels[panel].hists) / 2 - nrows // 2)) # type: ignore
nums, meds, mads, stats_dict = self._makePanel(
data,
panel,
Expand Down Expand Up @@ -608,12 +608,16 @@ def _addStatisticsPanel(
legend_labels = (
([""] * (len(handles) + 1))
+ [stats_dict["statLabels"][0]]
+ [f"{x:.3g}" for x in stats_dict["stat1"]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat1"]]
+ [stats_dict["statLabels"][1]]
+ [f"{x:.3g}" for x in stats_dict["stat2"]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat2"]]
+ [stats_dict["statLabels"][2]]
+ [f"{x:.3g}" for x in stats_dict["stat3"]]
+ [f"{x:.3g}" if abs(x) > 0.01 else f"{x:.2e}" for x in stats_dict["stat3"]]
)
# Replace "e+0" with "e" and "e-0" with "e-" to save space.
legend_labels = [label.replace("e+0", "e") for label in legend_labels]
jrmullaney marked this conversation as resolved.
Show resolved Hide resolved
legend_labels = [label.replace("e-0", "e-") for label in legend_labels]

# Set the y anchor for the legend such that it roughly lines up with
# the panels.
yAnchor = max(0, yAnchor0 - 0.01) + nth_col * (0.008 + len(nums) * 0.005) * legend_font_size
Expand Down
32 changes: 31 additions & 1 deletion python/lsst/analysis/tools/atools/calexpMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,15 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from __future__ import annotations

__all__ = ("CalexpSummaryMetrics",)
__all__ = (
"CalexpSummaryMetrics",
"CalexpMetricHists",
)

from lsst.pex.config import DictField

from ..actions.plot import HistPanel, HistPlot
from ..actions.vector import BandSelector, LoadVector
from ..interfaces import AnalysisTool


Expand Down Expand Up @@ -77,3 +84,26 @@ def setDefaults(self):

self.prep.keysToLoad = list(self._units.keys())
self.produce.metric.units = self._units


class CalexpMetricHists(AnalysisTool):
"""
Class to generate histograms of metrics extracted from a Metrics Table.
One plot per band.
"""

parameterizedBand: bool = False
metrics = DictField[str, str](doc="The metrics to plot and their respective labels.")

def setDefaults(self):
jrmullaney marked this conversation as resolved.
Show resolved Hide resolved
super().setDefaults()

# Band is passed as a kwarg from the calling task.
self.prep.selectors.bandSelector = BandSelector()
self.produce.plot = HistPlot()

def finalize(self):

for metric, label in self.metrics.items():
setattr(self.process.buildActions, metric, LoadVector(vectorKey=metric))
self.produce.plot.panels[metric] = HistPanel(hists={metric: "Number of calexps"}, label=label)
67 changes: 58 additions & 9 deletions python/lsst/analysis/tools/tasks/metricAnalysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,33 +26,82 @@
)


from lsst.pex.config import ListField
from lsst.pipe.base import connectionTypes as ct

from ..interfaces import AnalysisBaseConfig, AnalysisBaseConnections, AnalysisPipelineTask


class MetricAnalysisConnections(
AnalysisBaseConnections,
dimensions=("skymap",),
defaultTemplates={"metricBundleName": "objectTableCore_metrics"},
dimensions=(),
defaultTemplates={"metricTableName": ""},
):

data = ct.Input(
doc="A summary table of all metrics by tract.",
name="{metricBundleName}Table",
doc="A table containing metrics.",
name="{metricTableName}",
storageClass="ArrowAstropy",
dimensions=("skymap",),
deferLoad=True,
dimensions=(),
)

def __init__(self, *, config=None):

self.dimensions.update(frozenset(sorted(config.outputDataDimensions)))
super().__init__(config=config)
self.data = ct.Input(
doc=self.data.doc,
name=self.data.name,
storageClass=self.data.storageClass,
deferLoad=self.data.deferLoad,
dimensions=frozenset(sorted(config.inputDataDimensions)),
)


class MetricAnalysisConfig(AnalysisBaseConfig, pipelineConnections=MetricAnalysisConnections):
pass
class MetricAnalysisConfig(
AnalysisBaseConfig,
pipelineConnections=MetricAnalysisConnections,
):
inputDataDimensions = ListField[str](
doc="Dimensions of the input data table.",
default=(),
optional=False,
)
outputDataDimensions = ListField[str](
doc="Dimensions of the outputs.",
default=(),
optional=False,
)


class MetricAnalysisTask(AnalysisPipelineTask):
"""Turn metric bundles which are per tract into a
summary metric table.
"""Take a metric table and run an analysis tool on the
data it contains. This could include creating a plot
the metrics and/or calculating summary values of those
metrics, such as means, medians, etc. The analysis
is outlined within the analysis tool.
"""

ConfigClass = MetricAnalysisConfig
_DefaultName = "metricAnalysis"

def runQuantum(self, butlerQC, inputRefs, outputRefs):
# Doctstring inherited

inputs = butlerQC.get(inputRefs)
dataId = butlerQC.quantum.dataId
plotInfo = self.parsePlotInfo(inputs, dataId)

data = self.loadData(inputs.pop("data"))

# TODO: "bands" kwarg is a workaround for DM-47941.
outputs = self.run(
data=data,
plotInfo=plotInfo,
bands=dataId["band"],
band=dataId["band"],
**inputs,
)

butlerQC.put(outputs, outputRefs)
Loading