Skip to content

Commit 155d8f4

Browse files
author
FEO\preus
committed
- expand gm to dirac interface, allow any covariance matrix
- expand tests - minor fixes - expand testing input parameters
1 parent 0111b70 commit 155d8f4

File tree

8 files changed

+345
-49
lines changed

8 files changed

+345
-49
lines changed

src/deterministic_gaussian_sampling/approximation/base_approximation.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,27 @@
11
import ctypes
2+
from dataclasses import dataclass
23
import numpy
34
from typing import Optional
45
from deterministic_gaussian_sampling.dll_handling import load_dll
56

7+
@dataclass
8+
class CovarianceData:
9+
cov: numpy.ndarray
10+
eigvals: numpy.ndarray
11+
Q: numpy.ndarray
12+
sqrt_eigvals: numpy.ndarray
13+
14+
def __init__(self, cov: numpy.ndarray):
15+
self.cov = cov
16+
self.eigvals, self.Q = numpy.linalg.eigh(cov)
17+
self.eigvals = numpy.maximum(self.eigvals, 1e-14)
18+
self.sqrt_eigvals = numpy.sqrt(self.eigvals)
19+
20+
self.cov = numpy.ascontiguousarray(self.cov, dtype=numpy.float64)
21+
self.Q = numpy.ascontiguousarray(self.Q, dtype=numpy.float64)
22+
self.eigvals = numpy.ascontiguousarray(self.eigvals, dtype=numpy.float64)
23+
self.sqrt_eigvals = numpy.ascontiguousarray(self.sqrt_eigvals, dtype=numpy.float64)
24+
625
class BaseApproximation:
726
cdll: Optional[ctypes.CDLL] = None
827

@@ -22,23 +41,61 @@ def _map_ctypes_numpy(self, ctype) -> type:
2241
raise TypeError("Unsupported ctype for mapping to numpy dtype")
2342

2443
def _check_numpy_ndarray(
25-
self, arr: Optional[numpy.ndarray], L: int, N: int
26-
) -> ctypes.Array:
44+
self, arr: Optional[numpy.ndarray], L: int, N: Optional[int] = None
45+
) -> Optional[numpy.ndarray]:
2746
if arr is None:
2847
return None
2948
if not isinstance(arr, numpy.ndarray):
3049
raise TypeError("Input must be a numpy array")
31-
if not numpy.issubdtype(arr.dtype, numpy.floating) and arr.dtype != float:
50+
if not numpy.issubdtype(arr.dtype, numpy.floating):
3251
raise TypeError(
3352
f"Input array must be of a floating type, but got {arr.dtype}."
3453
)
35-
if arr.shape != (L, N):
36-
row, cols = arr.shape
37-
raise ValueError(
38-
f"Input array must have size [{L}x{N}] but got [{row}x{cols}]"
39-
)
40-
return arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
54+
if L <= 0 or (N is not None and N <= 0):
55+
raise ValueError("L and N must be positive integers.")
56+
arrShape = arr.shape
57+
if arr.ndim == 1:
58+
if arrShape[0] != L:
59+
raise ValueError(
60+
f"Input array must have size [{L}] but got [{arrShape[0]}]"
61+
)
62+
if N is not None and N != 1:
63+
raise ValueError(
64+
f"Input array must have size [{L}x{N}] but got [{L}]"
65+
)
66+
elif arr.ndim == 2:
67+
if arrShape != (L, N):
68+
row, cols = arrShape
69+
raise ValueError(
70+
f"Input array must have size [{L}x{N}] but got [{row}x{cols}]"
71+
)
72+
else:
73+
raise ValueError("Input array must be 1D or 2D")
74+
75+
# ensure contiguous float64 memory
76+
return numpy.ascontiguousarray(arr, dtype=numpy.float64)
77+
78+
def _check_covariance_matrix(self, cov: numpy.ndarray, N: int, tol = 1e-6) -> CovarianceData:
79+
if cov.shape != (N, N):
80+
raise ValueError(f"Covariance matrix must be of shape [{N}x{N}] but got {cov.shape}")
81+
if not numpy.allclose(cov, cov.T, atol=tol):
82+
raise ValueError("Covariance matrix must be symmetric")
83+
if numpy.any(numpy.linalg.eigvalsh(cov) <= -tol):
84+
raise ValueError("Covariance matrix must be positive definite")
85+
86+
covChecked = self._check_numpy_ndarray(cov, N, N)
87+
return CovarianceData(covChecked)
88+
89+
def _check_weights(self, weight: Optional[numpy.ndarray], size: int, tol=1e-6) -> Optional[numpy.ndarray]:
90+
if weight is None:
91+
return None
92+
wChecked = self._check_numpy_ndarray(weight, size, None)
93+
if numpy.any(wChecked < 0):
94+
raise ValueError("Weights cannot be negative")
95+
if numpy.abs(1 - numpy.sum(wChecked)) > tol:
96+
raise ValueError(f"Sum of weights must be 1 within tolerance +/-{tol}, but got {numpy.sum(wChecked)}")
97+
return wChecked
4198

4299
@staticmethod
43100
def _register_cdll(cdll: ctypes.CDLL):
44-
BaseApproximation.cdll = cdll
101+
BaseApproximation.cdll = cdll

src/deterministic_gaussian_sampling/approximation/dirac_to_dirac.py

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,17 +38,21 @@ def approximate_double(
3838
cdll = self.__class__.cdll
3939
if cdll is None:
4040
raise OSError("C++-Library was not loaded. Unable to continue!!!")
41+
yChecked = self._check_numpy_ndarray(y, M, N)
42+
xChecked = self._check_numpy_ndarray(x, L, N)
43+
wXChecked = self._check_weights(wX, L)
44+
wYChecked = self._check_weights(wY, M)
4145
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
4246
success: ctypes.c_bool = cdll.dirac_to_dirac_approx_short_double_approximate(
4347
self.d2d_short_double,
44-
self._check_numpy_ndarray(y, M, N),
48+
yChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
4549
ctypes.c_size_t(M),
4650
ctypes.c_size_t(L),
4751
ctypes.c_size_t(N),
4852
ctypes.c_size_t(100),
49-
self._check_numpy_ndarray(x, L, N),
50-
self._check_numpy_ndarray(wX, L, 1),
51-
self._check_numpy_ndarray(wY, M, 1),
53+
xChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
54+
None if wXChecked is None else wXChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
55+
None if wYChecked is None else wYChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
5256
ctypes.byref(minimizer_result),
5357
(
5458
None
@@ -69,24 +73,28 @@ def approximate_function_double(
6973
L: int,
7074
N: int,
7175
x: numpy.ndarray,
72-
wX: Optional[numpy.ndarray] = None,
73-
wY: Optional[numpy.ndarray] = None,
76+
wX: python_variant.wXCallbackPythonType,
77+
wXD: python_variant.wXDCallbackPythonType,
7478
options: Optional[python_variant.ApproximateOptionsPy] = None,
7579
) -> python_variant.ApproximationResultPy:
7680
cdll = self.__class__.cdll
7781
if cdll is None:
7882
raise OSError("C++-Library was not loaded. Unable to continue!!!")
83+
yChecked = self._check_numpy_ndarray(y, M, N)
84+
xChecked = self._check_numpy_ndarray(x, L, N)
7985
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
86+
wX = python_variant.wx_callback_python_wrapper(wX)
87+
wXD = python_variant.wxd_callback_python_wrapper(wXD)
8088
success = cdll.dirac_to_dirac_approx_short_function_double_approximate(
8189
self.d2d_func_double,
82-
self._check_numpy_ndarray(y, M, N),
90+
yChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
8391
ctypes.c_size_t(M),
8492
ctypes.c_size_t(L),
8593
ctypes.c_size_t(N),
8694
ctypes.c_size_t(100),
87-
self._check_numpy_ndarray(x, L, N),
88-
wX, # TODO: replace with function wrapper
89-
wY, # TODO: replace with function wrapper
95+
xChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
96+
wX,
97+
wXD,
9098
ctypes.byref(minimizer_result),
9199
(
92100
None
@@ -114,17 +122,21 @@ def approximate_thread_double(
114122
cdll = self.__class__.cdll
115123
if cdll is None:
116124
raise OSError("C++-Library was not loaded. Unable to continue!!!")
125+
yChecked = self._check_numpy_ndarray(y, M, N)
126+
xChecked = self._check_numpy_ndarray(x, L, N)
127+
wXChecked = self._check_weights(wX, L)
128+
wYChecked = self._check_weights(wY, M)
117129
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
118130
success = cdll.dirac_to_dirac_approx_short_thread_double_approximate(
119131
self.d2d_thread_double,
120-
self._check_numpy_ndarray(y, M, N),
132+
yChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
121133
ctypes.c_size_t(M),
122134
ctypes.c_size_t(L),
123135
ctypes.c_size_t(N),
124136
ctypes.c_size_t(100),
125-
self._check_numpy_ndarray(x, L, N),
126-
self._check_numpy_ndarray(wX, L, 1),
127-
self._check_numpy_ndarray(wY, M, 1),
137+
xChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
138+
None if wXChecked is None else wXChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
139+
None if wYChecked is None else wYChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
128140
ctypes.byref(minimizer_result),
129141
(
130142
None

src/deterministic_gaussian_sampling/approximation/gaussian_to_dirac.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def __del__(self):
2828

2929
def approximate_double(
3030
self,
31-
covDiag: numpy.ndarray,
31+
cov: numpy.ndarray,
3232
L: int,
3333
N: int,
3434
x: numpy.ndarray,
@@ -38,25 +38,30 @@ def approximate_double(
3838
cdll = self.__class__.cdll
3939
if cdll is None:
4040
raise OSError("C++-Library was not loaded. Unable to continue!!!")
41+
covData = self._check_covariance_matrix(cov, N)
42+
xChecked = self._check_numpy_ndarray(x, L, N)
43+
wXChecked = self._check_weights(wX, L)
4144
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
4245
success = cdll.gm_to_dirac_short_double_approximate(
4346
self.gm_to_dirac_double,
44-
self._check_numpy_ndarray(covDiag, covDiag.shape[0], covDiag.shape[0]),
47+
covData.sqrt_eigvals.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
4548
ctypes.c_size_t(L),
4649
ctypes.c_size_t(N),
4750
ctypes.c_size_t(100),
48-
self._check_numpy_ndarray(x, L, N),
49-
self._check_numpy_ndarray(wX, L, 1),
51+
xChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
52+
None if wXChecked is None else wXChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
5053
ctypes.byref(minimizer_result),
5154
(
5255
None
5356
if options is None
5457
else ctypes_wrapper.ApproximateOptionsCTypes.from_py_type(options)
5558
),
5659
)
57-
return python_variant.ApproximationResultPy.from_ctypes(
60+
result = python_variant.ApproximationResultPy.from_ctypes(
5861
success, minimizer_result, x, L, N
5962
)
63+
result.x = result.x @ covData.Q.T
64+
return result
6065

6166
def approximate_snd_double(
6267
self,
@@ -69,14 +74,16 @@ def approximate_snd_double(
6974
cdll = self.__class__.cdll
7075
if cdll is None:
7176
raise OSError("C++-Library was not loaded. Unable to continue!!!")
77+
xChecked = self._check_numpy_ndarray(x, L, N)
78+
wXChecked = self._check_weights(wX, L)
7279
minimizer_result = ctypes_wrapper.GslMinimizerResultCTypes()
7380
success = cdll.gm_to_dirac_short_standard_normal_deviation_double_approximate(
7481
self.gm_to_dirac_snd_double,
7582
ctypes.c_size_t(L),
7683
ctypes.c_size_t(N),
7784
ctypes.c_size_t(100),
78-
self._check_numpy_ndarray(x, L, N),
79-
self._check_numpy_ndarray(wX, L, 1),
85+
xChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
86+
None if wXChecked is None else wXChecked.ctypes.data_as(ctypes.POINTER(ctypes.c_double)),
8087
ctypes.byref(minimizer_result),
8188
(
8289
None
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
from .ctypes_wrapper import GslMinimizerResultCTypes, ApproximateOptionsCTypes, wXCallbackCTypes
2-
from .python_variant import ApproximationResultPy, GslMinimizerResultPy, ApproximateOptionsPy
2+
from .python_variant import ApproximationResultPy, GslMinimizerResultPy, ApproximateOptionsPy, wx_callback_python_wrapper, wxd_callback_python_wrapper, wXCallbackPythonType, wXDCallbackPythonType
33

4-
__all__ = ["GslMinimizerResultCTypes", "ApproximateOptionsCTypes", "wXCallbackCTypes", "ApproximationResultPy", "GslMinimizerResultPy", "ApproximateOptionsPy"]
4+
__all__ = ["GslMinimizerResultCTypes", "ApproximateOptionsCTypes", "wXCallbackCTypes", "ApproximationResultPy", "GslMinimizerResultPy", "ApproximateOptionsPy", "wx_callback_python_wrapper", "wxd_callback_python_wrapper", "wXCallbackPythonType", "wXDCallbackPythonType"]

src/deterministic_gaussian_sampling/type_wrapper/ctypes_wrapper.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,7 @@ class GslMinimizerResultCTypes(ctypes.Structure):
2626

2727
@staticmethod
2828
def from_py_type(pyT: "deterministic_gaussian_sampling.type_wrapper.python_variant.GslMinimizerResultPy") -> "GslMinimizerResultCTypes":
29-
# local import to avoid circular import at module import time
30-
import deterministic_gaussian_sampling.type_wrapper.python_variant as python_variant # noqa: F401
29+
import deterministic_gaussian_sampling.type_wrapper.python_variant as python_variant
3130
return GslMinimizerResultCTypes(
3231
ctypes.c_double(pyT.initalStepSize),
3332
ctypes.c_double(pyT.stepTolerance),
@@ -47,7 +46,6 @@ def from_py_type(pyT: "deterministic_gaussian_sampling.type_wrapper.python_varia
4746
)
4847

4948
def to_py_type(self):
50-
# local import to avoid circular import at module import time
5149
import deterministic_gaussian_sampling.type_wrapper.python_variant as python_variant
5250
return python_variant.GslMinimizerResultPy(
5351
float(self.initalStepSize),
@@ -82,8 +80,7 @@ class ApproximateOptionsCTypes(ctypes.Structure):
8280

8381
@staticmethod
8482
def from_py_type(pyT: "deterministic_gaussian_sampling.type_wrapper.python_variant.ApproximateOptionsPy") -> "ApproximateOptionsCTypes":
85-
# local import
86-
import deterministic_gaussian_sampling.type_wrapper.python_variant as python_variant # noqa: F401
83+
import deterministic_gaussian_sampling.type_wrapper.python_variant as python_variant
8784
return ApproximateOptionsCTypes(
8885
ctypes.c_double(pyT.xtolAbs),
8986
ctypes.c_double(pyT.xtolRel),

src/deterministic_gaussian_sampling/type_wrapper/python_variant.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
# python_variant.py
22
from __future__ import annotations
3+
from typing import Callable, Sequence, Union
34
from dataclasses import dataclass
45
import ctypes
56
import numpy as np
67

8+
from deterministic_gaussian_sampling.type_wrapper import ctypes_wrapper
9+
710
@dataclass
811
class ApproximationResultPy:
912
success: bool
@@ -126,3 +129,50 @@ def __str__(self):
126129
return "\n".join(lines)
127130

128131
__repr__ = __str__
132+
133+
ArrayLike = Union[Sequence[float], np.ndarray]
134+
# Function type: (x, L, N) -> res
135+
wXCallbackPythonType = Callable[[ArrayLike, int, int], ArrayLike]
136+
wXDCallbackPythonType = wXCallbackPythonType
137+
138+
def wx_callback_python_wrapper(func: wXCallbackPythonType) -> "ctypes_wrapper.wXCallbackCTypes":
139+
"""
140+
Wrap a Python function of signature (x: ndarray[L,N], L, N) -> res
141+
to a ctypes callback for C.
142+
"""
143+
def c_callback(x_ptr, res_ptr, L, N):
144+
L = int(L)
145+
N = int(N)
146+
size = L * N
147+
148+
# Convert C pointer to 1D array and reshape to (N,L) for row-major -> column-major
149+
x_raw = np.ctypeslib.as_array(x_ptr, shape=(size,))
150+
x = x_raw.reshape((L, N))
151+
152+
# Call user Python function
153+
res_val = func(x, L, N)
154+
res_val = np.asarray(res_val, dtype=np.float64)
155+
156+
# Prepare the result to write back to C in row-major
157+
if res_val.ndim == 1:
158+
# vector case (L,)
159+
if res_val.size != L:
160+
raise ValueError(f"Returned array has size {res_val.size}, expected {L}")
161+
res_arr = np.ctypeslib.as_array(res_ptr, shape=(L,))
162+
res_arr[:] = res_val
163+
elif res_val.ndim == 2:
164+
if res_val.shape != (L, N):
165+
raise ValueError(f"Returned array has shape {res_val.shape}, expected ({L},{N})")
166+
# Convert to row-major (flattened)
167+
res_row_major = res_val.flatten()
168+
res_arr = np.ctypeslib.as_array(res_ptr, shape=(size,))
169+
res_arr[:] = res_row_major
170+
else:
171+
raise ValueError(f"Returned array has invalid ndim {res_val.ndim}")
172+
173+
cb = ctypes_wrapper.wXCallbackCTypes(c_callback)
174+
cb._keepalive = func # keep reference alive
175+
return cb
176+
177+
# Alias for your derivative callback if needed
178+
wxd_callback_python_wrapper = wx_callback_python_wrapper

0 commit comments

Comments
 (0)