Skip to content

Adopt the array api #885

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 107 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
107 commits
Select commit Hold shift + click to select a range
6dd9166
Initial attempt to adopt array API
mwcraig Feb 7, 2025
d3b6203
A couple more changes
mwcraig Feb 9, 2025
f0bc6b8
Add function for check if something is an array
mwcraig Feb 10, 2025
424ebdc
Remove numpy from core
mwcraig Feb 12, 2025
7f5050d
WIP rewrite of a boolean index
mwcraig Mar 4, 2025
6a6bfa2
WIP testing updates
mwcraig Mar 4, 2025
1d2d171
Change more almost_equal to allclose
mwcraig Mar 5, 2025
22bc04d
Use the random number generator from numpy
mwcraig Mar 5, 2025
245d641
Remove more in-place array operations
mwcraig Mar 5, 2025
55b90be
Set a reasonable tolerance value in float comparisons
mwcraig Mar 5, 2025
23de261
Continue to use numpy arrays in a few places
mwcraig Mar 6, 2025
dffadce
Refactor warning handling for numerical warnings
mwcraig Mar 6, 2025
12ff8b0
Avoid explicit use of numpy masked arrays
mwcraig Mar 6, 2025
63f81ba
Rewrite test to not modify array in-place
mwcraig Mar 6, 2025
4ae3fc6
Use assert_allclose instead of older alternatives
mwcraig Mar 7, 2025
7b2ed18
Initial attempt to adopt array API in combiner
mwcraig Mar 7, 2025
1cbb7bc
Write a mask-aware sum function that is array API compatible
mwcraig Mar 17, 2025
da90350
Eliminate most use of numpy in tests
mwcraig Mar 19, 2025
bddb2de
Fix a couple calls to numpy masked array
mwcraig Mar 19, 2025
8ac481a
One more workaround for immutable arrays
mwcraig Mar 19, 2025
220a96e
Fix up dependencies and test environment setup
mwcraig Mar 19, 2025
7378297
Update minimum python to 3.10
mwcraig Mar 19, 2025
4485839
Fix mask access
mwcraig Mar 20, 2025
30e9e55
Ignore warnings about negative values in square root
mwcraig Mar 20, 2025
a45d2cd
Update several minimum dependencies
mwcraig Mar 20, 2025
e4afebc
Fix linting errors
mwcraig Mar 20, 2025
3724f0a
Drop unnecessary import
mwcraig Mar 20, 2025
c0928c8
Drop unneeded test
mwcraig Mar 20, 2025
c04f842
Skip memory tests if jax is installed
mwcraig Mar 20, 2025
9fcc655
Explain why numpy is still used in image_collection
mwcraig Mar 21, 2025
1fe1180
Drop numpy import in combiner
mwcraig Mar 24, 2025
282eaee
Use array_api_extra to handle immutable arrays
mwcraig Mar 27, 2025
5f27cd4
Use a consistent namespace for arrays
mwcraig Mar 27, 2025
5f7f3e0
Clean up a couple more cases to use array_api_extra
mwcraig Mar 27, 2025
b98f8f5
Change where bottleneck is test on GitHub Actions
mwcraig Mar 27, 2025
6ce2309
Skip coverage of one function
mwcraig Mar 27, 2025
b66abbc
Remove unused argument and logic
mwcraig Mar 27, 2025
e67ac8b
Add a test
mwcraig Mar 27, 2025
cdb9c45
Use tox environment to handle testing of different array libraries
mwcraig Jun 18, 2025
ef30230
Convert combiner tests to use Array API
mwcraig Jun 18, 2025
fda5ad4
Remove unnecessary copy argument
mwcraig Jun 18, 2025
03d413a
Use the array_api_compt numpy namespace instead of numpy
mwcraig Jun 18, 2025
dfc98cf
Add test against dask and fix bugs uncovered by tests
mwcraig Jun 20, 2025
191da4e
Add dask test to CI
mwcraig Jun 20, 2025
8fc85a5
Fix some errors introduced when changing the tests for dask
mwcraig Jun 20, 2025
73cc0b0
Allow cupy for testing
mwcraig Jun 24, 2025
53fb79a
Suppress square root warning generated in some array libraries
mwcraig Jun 26, 2025
4bef390
Apply suggestions from code review
mwcraig Jun 26, 2025
efeac76
Undo suggested edit
mwcraig Jul 1, 2025
69a79d5
Shorten up a loop with a comprehension
mwcraig Jul 1, 2025
f7ea0a9
Add minimum pin for dependency
mwcraig Jul 1, 2025
d3a732b
Update black target versions
mwcraig Jul 1, 2025
309bbf4
Store array namespace when Combiner is created
mwcraig Jul 2, 2025
bd23ac4
Add optional namespace argument to several functions
mwcraig Jul 2, 2025
fcac48f
Use array API in all cosmic ray tests
mwcraig Jul 3, 2025
d0bcdad
cast number to float to avoid multiple namespaces
mwcraig Jul 3, 2025
baae84f
Change internal data and mask to private properties
mwcraig Jul 5, 2025
655e844
Add properties for accessing the data and mask to be used in combination
mwcraig Jul 5, 2025
17c0dca
Apply suggestions from code review
mwcraig Jul 5, 2025
a4e6d4a
Choose performance over style
mwcraig Jul 6, 2025
5843955
Changes to testing for cupy
mwcraig Jul 12, 2025
bbc1249
Make sure to use CCDData.data instead of CCDData in comparisons
mwcraig Jul 12, 2025
6af1e1e
Add more robust handling of open file test
mwcraig Jul 12, 2025
5ec5759
Add several workarounds for non-compliance of CCData with Array API
mwcraig Jul 12, 2025
41e17e6
Minor changes
mwcraig Jul 16, 2025
7664bc0
Add and use wrapper classes to ensure array API use
mwcraig Jul 16, 2025
fb6e7d6
Fix for immutable array types
mwcraig Jul 16, 2025
c4c1277
Ensure array namespace is used throughout core tests
mwcraig Jul 16, 2025
048f079
Include more uncertainty types in tests
mwcraig Jul 16, 2025
359126a
Use array API correctly in combiner tests
mwcraig Jul 17, 2025
bb42d5b
Add array namespace input to combine function
mwcraig Jul 17, 2025
ba83726
Add an array namespace conversion in one more place
mwcraig Jul 17, 2025
698c6fa
Add argument for desired array namespace
mwcraig Jul 17, 2025
7c7ce31
Specify namespace in a couple of tests where data is read from disk
mwcraig Jul 17, 2025
9a5ccd8
Fix broken links in docstrings
mwcraig Jul 18, 2025
4503ddf
Switch more tests away from np.testing or explain why not switching
mwcraig Jul 18, 2025
9afb5a6
Minimal array API docs
mwcraig Jul 21, 2025
5289ddc
Improve test coverage
mwcraig Jul 29, 2025
802d9a4
Replace .array with .asarray
mwcraig Jul 29, 2025
126d174
Apply suggestions from code review
mwcraig Jul 29, 2025
ad9b9b5
Remove reference to list that is not used
mwcraig Jul 29, 2025
fd9f03e
Point to documentation for list of supported libraries
mwcraig Jul 29, 2025
749761f
Avoid more numpy arrays in tests
mwcraig Jul 29, 2025
f5bea45
Fix formatting issues
mwcraig Jul 30, 2025
bb8b81b
Apply suggestions from code review
mwcraig Jul 30, 2025
53dc6bb
Apply suggestions from code review
mwcraig Jul 31, 2025
12f5ab5
Do masking operation in-place if possible
mwcraig Jul 31, 2025
f657dfa
Avoid converting data to numpy array
mwcraig Aug 5, 2025
6d8edb7
Change mask setting in tests to avoid a numpy conversion
mwcraig Aug 5, 2025
b17566b
Use CCDData compatibility wrapper in trim_image
mwcraig Aug 5, 2025
fdec811
Cast number to array namespace
mwcraig Aug 5, 2025
7531eee
Remove another instance of conversion to a numpy array
mwcraig Aug 5, 2025
c196b93
More fixes for flat_correct
mwcraig Aug 5, 2025
f9ad753
One more flat fix
mwcraig Aug 5, 2025
a0f738a
Do not use pytest.approx because it forces a numpy conversion
mwcraig Aug 5, 2025
65ef951
Add cast to array namespace
mwcraig Aug 5, 2025
52d6d45
Add cast to array namespace in gain_correct
mwcraig Aug 5, 2025
b1f02f7
Fix handling of gain value/unit
mwcraig Aug 5, 2025
9471298
Avoid mask setter in more places
mwcraig Aug 5, 2025
94c04ab
Fix typo
mwcraig Aug 5, 2025
fc95389
Wrap the CCDData object inside transform_image
mwcraig Aug 5, 2025
056fcb2
Wrap before cop because copy sets mask
mwcraig Aug 5, 2025
0b2e798
Fix typo
mwcraig Aug 5, 2025
77ba319
Ensure uncertainty in test uses array namespace
mwcraig Aug 5, 2025
0e16465
Set mask properly in test
mwcraig Aug 5, 2025
59690f7
Make error from data, not ccd object
mwcraig Aug 5, 2025
5c81f63
Fix several array API issues in lacosmic
mwcraig Aug 5, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 16 additions & 23 deletions .github/workflows/ci_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,43 +27,36 @@ jobs:
strategy:
matrix:
include:
- name: 'ubuntu-py38-oldestdeps'
- name: 'ubuntu-py311-oldestdeps'
os: ubuntu-latest
python: '3.8'
python: '3.11'
# Test the oldest supported dependencies on the oldest supported Python
tox_env: 'py38-test-oldestdeps'

- name: 'macos-py310-astroscrappy11'
# Keep this test until astroscrappy 1.1.0 is the oldest supported
# version.
os: macos-latest
python: '3.10'
tox_env: 'py310-test-astroscrappy11'
tox_env: 'py311-test-oldestdeps'

- name: 'ubuntu-py312-bottleneck'
# Do not include bottleneck in this coverage test. By not including
# it we get a better measure of how we are covered when using the
# array API, which bottleneck short-circuits.
- name: 'ubuntu-py312-coverage'
os: ubuntu-latest
python: '3.12'
tox_env: 'py312-test-alldeps-bottleneck-cov'

- name: 'ubuntu-py310'
os: ubuntu-latest
python: '3.10'
tox_env: 'py310-test-alldeps-numpy124'
tox_env: 'py312-test-alldeps-cov'

- name: 'ubuntu-py311'
# Test non-numpy array libraries
- name: 'ubuntu-py313-jax'
os: ubuntu-latest
python: '3.11'
tox_env: 'py311-test-alldeps-numpy124'
python: '3.13'
tox_env: 'py313-jax'

- name: 'ubuntu-py312'
# Move bottleneck test a test without coverage
- name: 'ubuntu-py312-bottleneck'
os: ubuntu-latest
python: '3.12'
tox_env: 'py312-test-alldeps-numpy126'

- name: 'macos-py312'
- name: 'macos-py312-dask'
os: macos-latest
python: '3.12'
tox_env: 'py312-test-alldeps'
tox_env: 'py312-alldeps-dask'

- name: 'windows-py312'
os: windows-latest
Expand Down
257 changes: 257 additions & 0 deletions ccdproc/_ccddata_wrapper_for_array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# This file is a rough draft of the changes that will be needed
# in astropy.nddata to adopt the array API. This does not cover all
# of the changes that will be needed, but it is a start.

import array_api_compat
import numpy as np
from astropy import units as u
from astropy.nddata import (
CCDData,
StdDevUncertainty,
)
from astropy.nddata.compat import NDDataArray
from astropy.units import UnitsError


class _NDDataArray(NDDataArray):
@NDDataArray.mask.setter
def mask(self, value):
xp = array_api_compat.array_namespace(self.data)
# Check that value is not either type of null mask.
if (value is not None) and (value is not np.ma.nomask):
mask = xp.asarray(value, dtype=bool)
if mask.shape != self.data.shape:
raise ValueError(
f"dimensions of mask {mask.shape} and data "
f"{self.data.shape} do not match"
)
else:
self._mask = mask
else:
# internal representation should be one numpy understands
self._mask = np.ma.nomask


class _CCDDataWrapperForArrayAPI(CCDData):
"""
Thin wrapper around CCDData to allow arithmetic operations with
arbitray array API backends.
"""

def _arithmetic_wrapper(self, operation, operand, result_unit, **kwargs):
""" "
Use NDDataArray for arthmetic because that does not force conversion
to Quantity (and hence numpy array). If there are units on the operands
then NDArithmeticMixin will convert to Quantity.
"""
# Take the units off to make sure the arithmetic operation
# does not try to convert to Quantity.
if hasattr(self, "unit"):
self_unit = self.unit
self._unit = None
else:
self_unit = None

if hasattr(operand, "unit"):
operand_unit = operand.unit
operand._unit = None
else:
operand_unit = None

# Also take the units off of the uncertainty
if self_unit is not None and hasattr(self.uncertainty, "unit"):
self.uncertainty._unit = None

if (
operand_unit is not None
and hasattr(operand, "uncertainty")
and hasattr(operand.uncertainty, "unit")
):
operand.uncertainty._unit = None

_result = _NDDataArray._prepare_then_do_arithmetic(
operation, self, operand, **kwargs
)
if self_unit:
self._unit = self_unit
if operand_unit:
operand._unit = operand_unit
# Also take the units off of the uncertainty
if hasattr(self, "uncertainty") and self.uncertainty is not None:
self.uncertainty._unit = self_unit

if hasattr(operand, "uncertainty") and operand.uncertainty is not None:
operand.uncertainty._unit = operand_unit

# We need to handle the mask separately if we want to return a
# genuine CCDDatta object and CCDData does not understand the
# array API.
result_mask = None
if _result.mask is not None:
result_mask = _result._mask
_result._mask = None
result = CCDData(_result, unit=result_unit)
result._mask = result_mask
return result

def subtract(self, operand, xp=None, **kwargs):
"""
Determine the right operation to use and figure out
the units of the result.
"""
xp = xp or array_api_compat.array_namespace(self.data)
if not self.unit.is_equivalent(operand.unit):
raise UnitsError("Units must be equivalent for subtraction.")
result_unit = self.unit
handle_mask = kwargs.pop("handle_mask", xp.logical_or)
return self._arithmetic_wrapper(
xp.subtract, operand, result_unit, handle_mask=handle_mask, **kwargs
)

def add(self, operand, xp=None, **kwargs):
"""
Determine the right operation to use and figure out
the units of the result.
"""
xp = xp or array_api_compat.array_namespace(self.data)
if not self.unit.is_equivalent(operand.unit):
raise UnitsError("Units must be equivalent for addition.")
result_unit = self.unit
handle_mask = kwargs.pop("handle_mask", xp.logical_or)
return self._arithmetic_wrapper(
xp.add, operand, result_unit, handle_mask=handle_mask, **kwargs
)

def multiply(self, operand, xp=None, **kwargs):
"""
Determine the right operation to use and figure out
the units of the result.
"""
xp = xp or array_api_compat.array_namespace(self.data)
# The "1 *" below is because quantities do arithmetic properly
# but units do not necessarily.
if not hasattr(operand, "unit"):
operand_unit = 1 * u.dimensionless_unscaled
else:
operand_unit = operand.unit
result_unit = (1 * self.unit) * (1 * operand_unit)
handle_mask = kwargs.pop("handle_mask", xp.logical_or)
return self._arithmetic_wrapper(
xp.multiply, operand, result_unit, handle_mask=handle_mask, **kwargs
)

def divide(self, operand, xp=None, **kwargs):
"""
Determine the right operation to use and figure out
the units of the result.
"""
xp = xp or array_api_compat.array_namespace(self.data)
if not hasattr(operand, "unit"):
operand_unit = 1 * u.dimensionless_unscaled
else:
operand_unit = operand.unit
result_unit = (1 * self.unit) / (1 * operand_unit)
handle_mask = kwargs.pop("handle_mask", xp.logical_or)
return self._arithmetic_wrapper(
xp.divide, operand, result_unit, handle_mask=handle_mask, **kwargs
)

@NDDataArray.mask.setter
def mask(self, value):
xp = array_api_compat.array_namespace(self.data)
# Check that value is not either type of null mask.
if (value is not None) and (value is not np.ma.nomask):
mask = xp.asarray(value, dtype=bool)
if mask.shape != self.data.shape:
raise ValueError(
f"dimensions of mask {mask.shape} and data "
f"{self.data.shape} do not match"
)
else:
self._mask = mask
else:
# internal representation should be one numpy understands
self._mask = np.ma.nomask


class _StdDevUncertaintyWrapper(StdDevUncertainty):
"""
Override propagate methods to make sure they use the array API.
"""

def _propagate_add(self, other_uncert, result_data, correlation):
xp = array_api_compat.array_namespace(self.array, other_uncert.array)
return super()._propagate_add_sub(
other_uncert,
result_data,
correlation,
subtract=False,
to_variance=xp.square,
from_variance=xp.sqrt,
)

def _propagate_subtract(self, other_uncert, result_data, correlation):
xp = array_api_compat.array_namespace(self.array, other_uncert.array)
return super()._propagate_add_sub(
other_uncert,
result_data,
correlation,
subtract=True,
to_variance=xp.square,
from_variance=xp.sqrt,
)

def _propagate_multiply(self, other_uncert, result_data, correlation):
xp = array_api_compat.array_namespace(self.array, other_uncert.array)
return super()._propagate_multiply_divide(
other_uncert,
result_data,
correlation,
divide=False,
to_variance=xp.square,
from_variance=xp.sqrt,
)

def _propagate_divide(self, other_uncert, result_data, correlation):
xp = array_api_compat.array_namespace(self.array, other_uncert.array)
return super()._propagate_multiply_divide(
other_uncert,
result_data,
correlation,
divide=True,
to_variance=xp.square,
from_variance=xp.sqrt,
)


def _wrap_ccddata_for_array_api(ccd):
"""
Wrap a CCDData object for use with array API backends.
"""
if isinstance(ccd, _CCDDataWrapperForArrayAPI):
return ccd

_ccd = _CCDDataWrapperForArrayAPI(ccd)
if isinstance(_ccd.uncertainty, StdDevUncertainty):
_ccd.uncertainty = _StdDevUncertaintyWrapper(_ccd.uncertainty)
return _ccd


def _unwrap_ccddata_for_array_api(ccd):
"""
Unwrap a CCDData object from array API backends to the original CCDData.
"""

if isinstance(ccd.uncertainty, _StdDevUncertaintyWrapper):
ccd.uncertainty = StdDevUncertainty(ccd.uncertainty.array)

if isinstance(ccd, CCDData):
return ccd

if not isinstance(ccd, _CCDDataWrapperForArrayAPI):
raise TypeError(
"Input must be a CCDData or _CCDDataWrapperForArrayAPI instance."
)

# Convert back to CCDData
return CCDData(ccd)
Loading
Loading