From ed032cab48006e6104ae7446553c3ddc6779c81c Mon Sep 17 00:00:00 2001 From: Lauren MacArthur Date: Sat, 3 Aug 2024 15:08:41 -0700 Subject: [PATCH] Add unittest for CullFromMaskedRegion selection --- tests/test_sourceSelector.py | 48 +++++++++++++++++++++++++++++++++--- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/tests/test_sourceSelector.py b/tests/test_sourceSelector.py index d15923577..a008f3d96 100644 --- a/tests/test_sourceSelector.py +++ b/tests/test_sourceSelector.py @@ -26,6 +26,7 @@ import astropy.units as u import warnings +import lsst.afw.image import lsst.afw.table import lsst.meas.algorithms import lsst.meas.base.tests @@ -56,28 +57,35 @@ def setUp(self): schema.addField("nChild", np.int32, "Number of children") schema.addField("detect_isPrimary", "Flag", "Is primary detection?") schema.addField("sky_source", "Flag", "Empty sky region.") + + self.xCol = "centroid_x" + self.yCol = "centroid_y" + schema.addField(self.xCol, float, "Centroid x value.") + schema.addField(self.yCol, float, "Centroid y value.") + self.catalog = lsst.afw.table.SourceCatalog(schema) self.catalog.reserve(10) self.config = self.Task.ConfigClass() + self.exposure = None def tearDown(self): del self.catalog def check(self, expected): task = self.Task(config=self.config) - results = task.run(self.catalog) + results = task.run(self.catalog, exposure=self.exposure) self.assertListEqual(results.selected.tolist(), expected) self.assertListEqual([src.getId() for src in results.sourceCat], [src.getId() for src, ok in zip(self.catalog, expected) if ok]) # Check with pandas.DataFrame version of catalog - results = task.run(self.catalog.asAstropy().to_pandas()) + results = task.run(self.catalog.asAstropy().to_pandas(), exposure=self.exposure) self.assertListEqual(results.selected.tolist(), expected) self.assertListEqual(list(results.sourceCat['id']), [src.getId() for src, ok in zip(self.catalog, expected) if ok]) # Check with astropy.table.Table version of catalog - results = task.run(self.catalog.asAstropy()) + results = task.run(self.catalog.asAstropy(), exposure=self.exposure) self.assertListEqual(results.selected.tolist(), expected) self.assertListEqual(list(results.sourceCat['id']), [src.getId() for src, ok in zip(self.catalog, expected) if ok]) @@ -369,6 +377,40 @@ def testFiniteRaDec(self): self.check([False, False, True, True, True]) + def testCullFromMaskedRegion(self): + # Test that objects whose centroids land on specified mask(s) are + # culled. + maskNames = ["NO_DATA", "BLAH"] + noDataPoints = [[0, 0], [3, 2]] + self.exposure = lsst.afw.image.ExposureF(5, 5) + mask = self.exposure.mask + for maskName in maskNames: + if maskName not in mask.getMaskPlaneDict(): + mask.addMaskPlane(maskName) + num = 5 + for _ in range(num): + self.catalog.addNew() + self.catalog[self.xCol][:] = 5.0 + self.catalog[self.yCol][:] = 5.0 + + # Set first two entries in catalog to land maskNames region. + for i, noDataPoint in enumerate(noDataPoints): + # Flip x & y for numpy array convention. + mask.array[noDataPoint[1]][noDataPoint[0]] = mask.getPlaneBitMask( + maskNames[min(i, len(maskNames) - 1)] + ) + self.catalog[self.xCol][i] = noDataPoint[0] + self.catalog[self.yCol][i] = noDataPoint[1] + + self.config.doCullFromMaskedRegion = True + self.config.cullFromMaskedRegion.xColName = self.xCol + self.config.cullFromMaskedRegion.yColName = self.yCol + self.config.cullFromMaskedRegion.badMaskNames = maskNames + self.check([False, False, True, True, True]) + # Reset config back to False and None for other tests. + self.config.doCullFromMaskedRegion = False + self.exposure = None + class TestBaseSourceSelector(lsst.utils.tests.TestCase): """Test the API of the Abstract Base Class with a trivial example."""