Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions docs/source/usage/python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -645,16 +645,17 @@ This module provides elements and methods for the accelerator lattice.
:param pals_line: PALS Python Line with beamline elements
:param nslice: number of slices used for the application of collective effects

.. py:method:: select(kind=None, name=None)
.. py:method:: select(kind=None, name=None, s=None)

Filter elements by type and/or name.
If both are provided, OR-based logic is applied.
Filter elements by type, name, and/or integrated position.
If multiple criteria are provided, OR-based logic is applied.

Returns references to original elements, allowing modification and chaining.
Chained ``.select(...).select(...)`` selections are AND-filtered.

:param kind: Element type(s) to filter by. Can be a string (e.g., ``"Drift"``), regex pattern (e.g., ``r".*Quad"``), element type (e.g., ``elements.Drift``), or list/tuple of these.
:param name: Element name(s) to filter by. Can be a string, regex pattern, or ``list``/``tuple`` of these.
:param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended. Elements are selected if ANY part overlaps with the range. Examples: ``(1.0, 5.0)`` for range 1.0 <= s <= 5.0, ``(1.0, None)`` for s >= 1.0, ``(None, 5.0)`` for s <= 5.0.

**Examples:**

Expand All @@ -673,6 +674,12 @@ This module provides elements and methods for the accelerator lattice.
# Chain filters (AND logic)
drift_named_d1 = lattice.select(kind="Drift").select(name="drift1")

# Position filtering (overlap logic)
early_elements = lattice.select(s=(None, 2.0)) # Elements overlapping s <= 2.0

# Chaining: s always calculated from original lattice
drift_then_early = lattice.select(kind="Drift").select(s=(1.0, 3.0)) # Drift AND overlapping s=[1.0,3.0]

# Modify original elements through references
drift_elements[0].ds = 2.0 # modifies original lattice

Expand Down
97 changes: 97 additions & 0 deletions src/python/impactx/Kahan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
def _kahan_babushka_core(values, return_cumulative=False):
"""Core implementation of the second-order iterative Kahan-Babuska algorithm.
This is the unified core that implements Klein (2006) algorithm for both
regular summation and cumulative summation.
- https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements
- Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in
Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x
Args:
values: Iterable of numeric values to sum
return_cumulative: If True, returns list of cumulative sums; if False, returns final sum
Returns:
float or list: Final sum if return_cumulative=False, list of cumulative sums if True
"""
sum_val = 0.0
cs = 0.0 # first-order compensation for lost low-order bits
ccs = 0.0 # second-order compensation for further lost bits
c = 0.0 # temporary variable for first-order compensation

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'c' is unnecessary as it is
redefined
before this value is used.
This assignment to 'c' is unnecessary as it is
redefined
before this value is used.
cc = 0.0 # temporary variable for second-order compensation

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'cc' is unnecessary as it is
redefined
before this value is used.
This assignment to 'cc' is unnecessary as it is
redefined
before this value is used.

if return_cumulative:
cumulative_sums = [0.0] # Start with 0.0

for val in values:
# First-order Kahan-Babuška step
t = sum_val + val
if abs(sum_val) >= abs(val):
c = (sum_val - t) + val
else:
c = (val - t) + sum_val
sum_val = t

# Second-order compensation step
t = cs + c
if abs(cs) >= abs(c):
cc = (cs - t) + c
else:
cc = (c - t) + cs
cs = t
ccs += cc

if return_cumulative:
# Store the accurate cumulative sum
cumulative_sums.append(sum_val + cs + ccs)

if return_cumulative:
return cumulative_sums
else:
return sum_val + cs + ccs


def kahan_babushka_sum(values):
"""Calculate an accurate sum using the second-order iterative Kahan-Babuška algorithm.
This implementation follows Klein (2006) to provide high-precision summation
that avoids floating-point precision errors when summing many small values.
- https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements
- Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in
Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x
The algorithm uses second-order compensation for lost low-order bits during
floating-point addition, providing significantly better accuracy than naive
summation when dealing with large numbers of small values (e.g., many ds
values in a long lattice).
Args:
values: Iterable of numeric values to sum
Returns:
float: Accurate sum of all values
"""
return _kahan_babushka_core(values, return_cumulative=False)


def kahan_babushka_cumsum(values):
"""Calculate an accurate cumulative sum using the second-order iterative Kahan-Babuska algorithm.
This implementation follows Klein (2006) to provide high-precision summation
that avoids floating-point precision errors when summing many small values.
- https://en.wikipedia.org/wiki/Kahan_summation_algorithm#Further_enhancements
- Klein (2006). "A generalized Kahan–Babuška-Summation-Algorithm". in
Computing. 76 (3–4). Springer-Verlag: 279–293. doi:10.1007/s00607-005-0139-x
The algorithm uses second-order compensation for lost low-order bits during
floating-point addition, providing significantly better accuracy than naive
summation when dealing with large numbers of small values (e.g., many ds
values in a long lattice).
Args:
values: Iterable of numeric values to cumulatively sum
Returns:
list: List of cumulative sums with initial 0.0 prepended
"""
return _kahan_babushka_core(values, return_cumulative=True)
1 change: 1 addition & 0 deletions src/python/impactx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# import core bindings to C++
from . import impactx_pybind
from .impactx_pybind import * # noqa
from .Kahan import kahan_babushka_cumsum, kahan_babushka_sum # noqa
from .madx_to_impactx import read_beam # noqa

__version__ = impactx_pybind.__version__
Expand Down
131 changes: 113 additions & 18 deletions src/python/impactx/extensions/KnownElementsList.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import os
import re

from impactx import elements
from impactx import elements, kahan_babushka_sum


def load_file(self, filename, nslice=1):
Expand Down Expand Up @@ -124,6 +124,7 @@ def select(
*,
kind=None,
name=None,
s=None,
):
"""Apply filtering to this filtered list.

Expand All @@ -146,10 +147,14 @@ def select(
Examples: "quad1", r"quad\d+", ["quad1", "quad2"], [r"quad\d+", "bend1"]
:type name: str or list[str] or tuple[str, ...] or None, optional

:param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended.
Elements are selected if ANY part overlaps with the range. Examples: (1.0, 5.0) for range 1.0 <= s <= 5.0, (1.0, None) for s >= 1.0, (None, 5.0) for s <= 5.0
:type s: tuple[float | None, float | None] or list[float | None] or None, optional

:return: FilteredElementsList containing references to original elements
:rtype: FilteredElementsList

:raises TypeError: If kind/name parameters have wrong types
:raises TypeError: If kind/name/s parameters have wrong types

**Examples:**

Expand All @@ -169,18 +174,26 @@ def select(
strong_quads = quad_elements.select(
name=r"quad\d+"
) # Filter quads by regex pattern

# Position-based filtering (always calculated from original lattice)
early_elements = lattice.select(s=(None, 2.0)) # Elements with s <= 2.0
drift_then_early = lattice.select(kind="Drift").select(
s=(1.0, None)
) # Drift elements with s >= 1.0
"""
# Apply filtering directly to the indices we already have
if kind is not None or name is not None:
if kind is not None or name is not None or s is not None:
# Validate parameters
_validate_select_parameters(kind, name)
_validate_select_parameters(kind, name, s)

matching_indices = []

for i in self._indices:
element = self._original_list[i]
if _check_element_match(element, kind, name):
matching_indices.append(i)
for original_idx in self._indices:
element = self._original_list[original_idx]
if _check_element_match(
element, kind, name, s, original_idx, self._original_list
):
matching_indices.append(original_idx)

return FilteredElementsList(self._original_list, matching_indices)

Expand Down Expand Up @@ -249,12 +262,13 @@ def _matches_string(text: str, string_pattern: str) -> bool:
return text == string_pattern


def _validate_select_parameters(kind, name):
def _validate_select_parameters(kind, name, s):
"""Validate parameters for select methods.

Args:
kind: Element type(s) to filter by
name: Element name(s) to filter by
s: Position range to filter by

Raises:
TypeError: If parameters have wrong types
Expand All @@ -279,6 +293,20 @@ def _validate_select_parameters(kind, name):
"'name' parameter must be a string or list/tuple of strings"
)

if s is not None:
if isinstance(s, (list, tuple)):
if len(s) != 2:
raise TypeError(
"'s' parameter must have exactly 2 elements (lower, upper)"
)
for bound in s:
if bound is not None and not isinstance(bound, (int, float)):
raise TypeError("'s' parameter bounds must be numbers or None")
else:
raise TypeError(
"'s' parameter must be a tuple/list with 2 elements (lower, upper)"
)


def _matches_kind_pattern(element, kind_pattern):
"""Check if an element matches a kind pattern.
Expand Down Expand Up @@ -316,13 +344,57 @@ def _matches_name_pattern(element, name_pattern):
)


def _check_element_match(element, kind, name):
"""Check if an element matches the given kind and name criteria.
def _matches_s_position(element, s_range, element_s):
"""Check if an element's integrated position matches the s range criteria.

An element matches if ANY part of it overlaps with the specified range.
The element spans from element_s to element_s + element.ds.

Args:
element: The element to check
s_range: Tuple/list of (lower, upper) bounds. None represents open-ended.
element_s: The cumulative position of the element (calculated externally)

Returns:
bool: True if element's position overlaps with the range
"""
if s_range is None:
return True

# Convert to tuple if it's a list
if isinstance(s_range, list):
s_range = tuple(s_range)

if not isinstance(s_range, tuple) or len(s_range) != 2:
raise TypeError(
"'s' parameter must be a tuple/list with 2 elements (lower, upper)"
)

lower, upper = s_range

# Element spans from element_s to element_s + element.ds
element_start = element_s
element_end = element_s + element.ds

# Check if any part of the element overlaps with the range
if lower is not None and element_end < lower:
return False
if upper is not None and element_start > upper:
return False

return True


def _check_element_match(element, kind, name, s, element_index, lattice):
"""Check if an element matches the given kind, name, and s criteria.

Args:
element: The element to check
kind: Kind criteria (str, type, list, tuple, or None)
name: Name criteria (str, list, tuple, or None)
s: Position criteria (tuple/list with 2 elements, or None)
element_index: Index of the element in the lattice
lattice: The full lattice to calculate cumulative positions

Returns:
bool: True if element matches any criteria (OR logic)
Expand Down Expand Up @@ -355,6 +427,15 @@ def _check_element_match(element, kind, name):
match = True
break

# Check for 's' parameter (only if neither kind nor name matched - OR logic)
if s is not None and not match:
# Calculate cumulative position up to this element using accurate summation
ds_values = [lattice[i].ds for i in range(element_index)]
cumulative_s = kahan_babushka_sum(ds_values)

if _matches_s_position(element, s, cumulative_s):
match = True

return match


Expand All @@ -363,16 +444,17 @@ def select(
*,
kind=None,
name=None,
s=None,
) -> FilteredElementsList:
"""Filter elements by type and name with OR-based logic.
"""Filter elements by type, name, and position with OR-based logic.

This method supports filtering elements by their type and/or name using keyword arguments.
This method supports filtering elements by their type, name, and/or integrated position using keyword arguments.
Returns references to original elements, allowing modification and chaining.

**Filtering Logic:**

- **Within a single filter**: OR logic (e.g., ``kind=["Drift", "Quad"]`` matches Drift OR Quad)
- **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1"`` matches Quad OR named "quad1")
- **Between different filters**: OR logic (e.g., ``kind="Quad", name="quad1", s=(1.0, 5.0)`` matches Quad OR named "quad1" OR in position range)
- **Chaining filters**: AND logic (e.g., ``lattice.select(kind="Drift").select(name="drift1")`` matches Drift AND named "drift1")

:param kind: Element type(s) to filter by. Can be a single string/type or a list/tuple
Expand All @@ -385,10 +467,14 @@ def select(
Examples: "quad1", r"quad\d+", ["quad1", "quad2"], [r"quad\d+", "bend1"]
:type name: str or list[str] or tuple[str, ...] or None, optional

:param s: Position range to filter by. Tuple/list with (lower, upper) bounds where None represents open-ended.
Elements are selected if ANY part overlaps with the range. Examples: (1.0, 5.0) for range 1.0 <= s <= 5.0, (1.0, None) for s >= 1.0, (None, 5.0) for s <= 5.0
:type s: tuple[float | None, float | None] or list[float | None] or None, optional

:return: FilteredElementsList containing references to original elements
:rtype: FilteredElementsList

:raises TypeError: If kind/name parameters have wrong types
:raises TypeError: If kind/name/s parameters have wrong types

**Examples:**

Expand Down Expand Up @@ -423,6 +509,15 @@ def select(
lattice.select(name=r"quad\d+") # Get elements matching pattern
lattice.select(name=[r"quad\d+", "bend1"]) # Mix regex and strings

Position-based filtering:

.. code-block:: python

lattice.select(s=(1.0, 5.0)) # Elements that overlap with range 1.0 <= s <= 5.0
lattice.select(s=(1.0, None)) # Elements that overlap with s >= 1.0
lattice.select(s=(None, 5.0)) # Elements that overlap with s <= 5.0
lattice.select(kind="Drift", s=(0.0, 2.0)) # Drift elements OR overlapping range 0.0 <= s <= 2.0

Chaining filters (AND logic between chained calls):

.. code-block:: python
Expand Down Expand Up @@ -455,14 +550,14 @@ def select(
"""

# Handle keyword arguments for filtering
if kind is not None or name is not None:
if kind is not None or name is not None or s is not None:
# Validate parameters
_validate_select_parameters(kind, name)
_validate_select_parameters(kind, name, s)

matching_indices = []

for i, element in enumerate(self):
if _check_element_match(element, kind, name):
if _check_element_match(element, kind, name, s, i, self):
matching_indices.append(i)

return FilteredElementsList(self, matching_indices)
Expand Down
Loading
Loading