Skip to content

Commit c86a19a

Browse files
BSnellingdweindl
andauthored
Update PEtab SciML support to PEtab v2 (#3165)
* bank changes - values ok - gradients broken * sciml tests passing (one skipped) * fix petab test suite * refactor sweep - for readability mostly * delete unused code * fix updating parameters in notebook for unscaled * update test_sciml tests * Apply suggestions from code review Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com> * implement feedback from code review * revert out_val mapping * try sympify petab again --------- Co-authored-by: Daniel Weindl <dweindl@users.noreply.github.com>
1 parent 9fc9a41 commit c86a19a

File tree

13 files changed

+879
-366
lines changed

13 files changed

+879
-366
lines changed

.github/workflows/test_petab_sciml.yml.FIXME renamed to .github/workflows/test_petab_sciml.yml

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
name: PEtab SciML
2-
# on:
3-
# push:
4-
# branches:
5-
# - main
6-
# - 'release*'
7-
# pull_request:
8-
# branches:
9-
# - main
10-
# merge_group:
11-
# workflow_dispatch:
2+
on:
3+
push:
4+
branches:
5+
- main
6+
- 'release*'
7+
pull_request:
8+
branches:
9+
- main
10+
merge_group:
11+
workflow_dispatch:
1212

1313
jobs:
1414
build:
@@ -33,11 +33,10 @@ jobs:
3333
with:
3434
fetch-depth: 20
3535

36-
# todo, update after https://github.com/sebapersson/petab_sciml_testsuite/issues/14 is merged
3736
- name: Download PEtab SciML test suite
3837
run: |
3938
git clone --depth 1 --branch main \
40-
https://github.com/FFroehlich/petab_sciml_testsuite \
39+
https://github.com/PEtab-dev/petab_sciml_testsuite.git \
4140
tests/sciml/testsuite
4241
4342
- name: Install apt dependencies
@@ -74,7 +73,7 @@ jobs:
7473
run: |
7574
source ./venv/bin/activate \
7675
&& python3 -m pip uninstall -y petab \
77-
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@sciml
76+
&& python3 -m pip install git+https://github.com/petab-dev/libpetab-python.git@main
7877
7978
- name: Run PEtab SciML testsuite
8079
run: |

doc/examples/example_jax_petab/ExampleJaxPEtab.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@
228228
"import jax\n",
229229
"\n",
230230
"# Generate random noise to update the parameters\n",
231-
"noise = (\n",
231+
"noise = jax.numpy.exp(\n",
232232
" jax.random.normal(\n",
233233
" key=jax.random.PRNGKey(0), shape=jax_problem.parameters.shape\n",
234234
" )\n",
@@ -260,7 +260,7 @@
260260
"outputs": [],
261261
"source": [
262262
"# Update the parameters and create a new JAXProblem instance\n",
263-
"jax_problem = jax_problem.update_parameters(jax_problem.parameters + noise)\n",
263+
"jax_problem = jax_problem.update_parameters(jax_problem.parameters * noise)\n",
264264
"\n",
265265
"# Run simulations with the updated parameters\n",
266266
"llh, results = run_simulations(jax_problem)\n",

python/sdist/amici/_symbolic/de_model.py

Lines changed: 74 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ObservableTransformation,
2828
_default_simplify,
2929
amici_time_symbol,
30+
symbol_with_assumptions,
3031
toposort_symbols,
3132
unique_preserve_order,
3233
)
@@ -2499,26 +2500,67 @@ def _process_hybridization(self, hybridization: dict) -> None:
24992500
added_expressions = False
25002501
orig_obs = tuple([s.get_sym() for s in self._observables])
25012502
for net_id, net in hybridization.items():
2502-
if net["static"]:
2503+
if net["pre_initialization"]:
25032504
# do not integrate into ODEs, handle in amici.sim.jax.petab
25042505
continue
2505-
inputs = [
2506-
comp
2507-
for comp in self._components
2508-
if str(comp.get_sym()) in net["input_vars"]
2509-
]
2510-
# sort inputs by order in input_vars
2511-
inputs = sorted(
2512-
inputs,
2513-
key=lambda comp: net["input_vars"].index(str(comp.get_sym())),
2514-
)
2506+
comp_by_sym = {comp.get_id(): comp for comp in self._components}
2507+
sym_locals = {s: comp.get_sym() for s, comp in comp_by_sym.items()}
2508+
2509+
inputs = []
2510+
unresolved_vars = []
2511+
for input_var in net["input_vars"]:
2512+
if input_var in comp_by_sym:
2513+
inputs.append(comp_by_sym[input_var])
2514+
else:
2515+
try:
2516+
expr = sp.sympify(input_var, locals=sym_locals)
2517+
except (sp.SympifyError, Exception):
2518+
unresolved_vars.append(input_var)
2519+
continue
2520+
2521+
if {str(s) for s in expr.free_symbols} - set(comp_by_sym):
2522+
unresolved_vars.append(input_var)
2523+
continue
2524+
2525+
expr_sym = symbol_with_assumptions(
2526+
f"_nn_{net_id}_input_{len(inputs)}"
2527+
)
2528+
new_expr_comp = Expression(
2529+
symbol=expr_sym,
2530+
name=f"{net_id}_input_{len(inputs)}",
2531+
value=expr,
2532+
)
2533+
self.add_component(new_expr_comp)
2534+
added_expressions = True
2535+
inputs.append(new_expr_comp)
2536+
25152537
if len(inputs) != len(net["input_vars"]):
2516-
found_vars = {str(comp.get_sym()) for comp in inputs}
2517-
missing_vars = set(net["input_vars"]) - found_vars
2518-
raise ValueError(
2519-
f"Could not find all input variables for neural network {net_id}. "
2520-
f"Missing variables: {sorted(missing_vars)}"
2521-
)
2538+
missing_vars = set(unresolved_vars)
2539+
if missing_vars == {"array"}:
2540+
array_inputs = net.get("array_inputs", {})
2541+
petab_ids = list(array_inputs.keys())
2542+
for i, input_var in enumerate(net["input_vars"]):
2543+
if input_var == "array":
2544+
if not petab_ids:
2545+
raise ValueError(
2546+
f"Array input specified for {net_id} but no "
2547+
f"array_inputs info provided in hybridization."
2548+
)
2549+
petab_id = petab_ids.pop(0)
2550+
array_sym = symbol_with_assumptions(
2551+
f"_nn_array_{petab_id}"
2552+
)
2553+
array_comp = Expression(
2554+
symbol=array_sym,
2555+
name=f"{net_id}_array_{petab_id}",
2556+
value=sp.Integer(0),
2557+
)
2558+
inputs.insert(i, array_comp)
2559+
else:
2560+
raise ValueError(
2561+
f"Could not find all input variables for neural network {net_id}. "
2562+
f"Missing variables: {sorted(missing_vars)}"
2563+
)
25222564
for inp in inputs:
25232565
if isinstance(
25242566
inp,
@@ -2547,12 +2589,13 @@ def _process_hybridization(self, hybridization: dict) -> None:
25472589
f"Could not find all output variables for neural network {net_id}. "
25482590
f"Missing variables: {sorted(missing_vars)}"
25492591
)
2550-
25512592
for out_var, parts in outputs.items():
25522593
comp = parts["comp"]
25532594
# remove output from model components
25542595
if isinstance(comp, FreeParameter):
25552596
self._free_parameters.remove(comp)
2597+
elif isinstance(comp, FixedParameter):
2598+
self._fixed_parameters.remove(comp)
25562599
elif isinstance(comp, Expression):
25572600
self._expressions.remove(comp)
25582601
elif isinstance(comp, DifferentialState):
@@ -2586,7 +2629,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
25862629
added_expressions = True
25872630

25882631
observables = {
2589-
ob_var: {"comp": comp, "ind": net["observable_vars"][ob_var]}
2632+
ob_var: {"comp": comp, **net["observable_vars"][ob_var]}
25902633
for comp in self._components
25912634
if (ob_var := str(comp.get_sym())) in net["observable_vars"]
25922635
# # TODO: SYNTAX NEEDS to CHANGE
@@ -2609,9 +2652,19 @@ def _process_hybridization(self, hybridization: dict) -> None:
26092652
raise ValueError(
26102653
f"{comp.get_name()} ({type(comp)}) is not an observable."
26112654
)
2612-
out_val = sp.Function(net_id)(
2613-
*[input.get_sym() for input in inputs], parts["ind"]
2655+
nn_call = sp.Function(net_id)(
2656+
*[input.get_sym() for input in inputs], parts["index"]
26142657
)
2658+
formula = parts["formula"]
2659+
petab_id = parts["petab_id"]
2660+
if formula == petab_id:
2661+
out_val = nn_call
2662+
else:
2663+
from petab.math.sympify import sympify_petab
2664+
2665+
out_val = sympify_petab(formula).subs(
2666+
symbol_with_assumptions(petab_id), nn_call
2667+
)
26152668
# add to the model
26162669
self.add_component(
26172670
Observable(

python/sdist/amici/exporters/jax/jax.template.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ def __init__(self):
2121
self.jax_py_file = Path(__file__).resolve()
2222
self.nns = {TPL_NETS}
2323
self.parameters = TPL_ALL_P_VALUES
24+
TPL_ARRAY_INPUTS_INIT
25+
self._array_input_index = jnp.int32(0)
2426
super().__init__()
2527

2628
def _xdot(self, t, x, args):

python/sdist/amici/exporters/jax/jaxcodeprinter.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,44 @@ def _print_Mul(self, expr: sp.Expr) -> str:
4444
return f"safe_div({self.doprint(numer)}, {self.doprint(denom)})"
4545

4646
def _print_Function(self, expr: sp.Expr) -> str:
47-
if isinstance(expr.func, UndefinedFunction):
48-
return f"self.nns['{expr.func.__name__}'].forward(jnp.array([{', '.join(self.doprint(a) for a in expr.args[:-1])}]))[{expr.args[-1]}]"
49-
else:
47+
if not isinstance(expr.func, UndefinedFunction):
5048
return super()._print_Function(expr)
5149

50+
nn_name = expr.func.__name__
51+
input_args = expr.args[:-1]
52+
output_idx = expr.args[-1]
53+
forward_arg = self._build_nn_forward_arg(input_args)
54+
return f"self.nns['{nn_name}'].forward({forward_arg})[{output_idx}]"
55+
56+
def _build_nn_forward_arg(self, input_args) -> str:
57+
array_prefix = "_nn_array_"
58+
if not any(
59+
hasattr(a, "name") and a.name.startswith(array_prefix)
60+
for a in input_args
61+
):
62+
return f"jnp.array([{', '.join(self.doprint(a) for a in input_args)}])"
63+
64+
groups = self._group_nn_inputs(input_args, array_prefix)
65+
return groups[0] if len(groups) == 1 else f"[{', '.join(groups)}]"
66+
67+
def _group_nn_inputs(self, input_args, array_prefix: str) -> list[str]:
68+
groups = []
69+
scalar_group = []
70+
for a in input_args:
71+
if hasattr(a, "name") and a.name.startswith(array_prefix):
72+
if scalar_group:
73+
groups.append(f"jnp.array([{', '.join(scalar_group)}])")
74+
scalar_group = []
75+
petab_id = a.name[len(array_prefix) :]
76+
groups.append(
77+
f"self._array_inputs['{petab_id}'][self._array_input_index]"
78+
)
79+
else:
80+
scalar_group.append(self.doprint(a))
81+
if scalar_group:
82+
groups.append(f"jnp.array([{', '.join(scalar_group)}])")
83+
return groups
84+
5285
def _print_Max(self, expr: sp.Expr) -> str:
5386
"""
5487
Print the max function, replacing it with jnp.max.

python/sdist/amici/exporters/jax/nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def generate_equinox(
136136
filename.parent.mkdir(parents=True, exist_ok=True)
137137

138138
apply_template(
139-
Path(amiciModulePath) / "jax" / "nn.template.py",
139+
Path(amiciModulePath) / "exporters" / "jax" / "nn.template.py",
140140
filename,
141141
tpl_data,
142142
)
@@ -184,7 +184,7 @@ def _generate_layer(layer: "Layer", indent: int, ilayer: int) -> str: # noqa: F
184184
layer_map = {
185185
"Dropout1d": "eqx.nn.Dropout",
186186
"Dropout2d": "eqx.nn.Dropout",
187-
"Flatten": "amici.export.jax.Flatten",
187+
"Flatten": "amici.exporters.jax.Flatten",
188188
}
189189

190190
# mapping of keyword argument names in sciml yaml format to equinox/custom amici implementations
@@ -320,9 +320,9 @@ def _process_activation_call(node: "Node") -> str: # noqa: F821
320320
"hardtanh": "jax.nn.hard_tanh",
321321
"hardsigmoid": "jax.nn.hard_sigmoid",
322322
"hardswish": "jax.nn.hard_swish",
323-
"tanhshrink": "amici.export.jax.tanhshrink",
323+
"tanhshrink": "amici.exporters.jax.tanhshrink",
324324
"softsign": "jax.nn.soft_sign",
325-
"cat": "amici.export.jax.cat",
325+
"cat": "amici.exporters.jax.cat",
326326
}
327327

328328
# Validate hardtanh parameters

python/sdist/amici/exporters/jax/nn.template.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
# ruff: noqa: F401, F821, F841
22
import equinox as eqx
3+
import jax
34
import jax.random as jr
45

6+
import amici
7+
58

69
class TPL_MODEL_ID(eqx.Module):
710
layers: dict

python/sdist/amici/exporters/jax/ode_export.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import os
1717
from pathlib import Path
1818

19+
import numpy as np
1920
import sympy as sp
2021

2122
from amici._symbolic.de_model import DEModel
@@ -247,6 +248,8 @@ def _generate_jax_code(self) -> None:
247248
)
248249
subs = subs_heaviside | subs_observables
249250

251+
array_inputs_init = self._generate_array_inputs_init()
252+
250253
tpl_data = {
251254
# assign named variable using corresponding algebraic formula (function body)
252255
**_jax_variable_equations(
@@ -298,6 +301,7 @@ def _generate_jax_code(self) -> None:
298301
f'"{net}": {net}.net(jr.PRNGKey(0))'
299302
for net in self.hybridization.keys()
300303
),
304+
"ARRAY_INPUTS_INIT": array_inputs_init,
301305
}
302306

303307
apply_template(
@@ -306,6 +310,42 @@ def _generate_jax_code(self) -> None:
306310
tpl_data,
307311
)
308312

313+
def _generate_array_inputs_init(self) -> str:
314+
"""Generate code to initialize array inputs from HDF5 data.
315+
316+
Reads array data from HDF5 files at code-generation time and
317+
embeds them as jnp.array constants in the generated model code.
318+
319+
:return:
320+
Python code string for initializing self._array_inputs
321+
"""
322+
array_entries = []
323+
for _, net in self.hybridization.items():
324+
if net.get("pre_initialization", False):
325+
continue
326+
for petab_id, hdf5_path in net.get("array_inputs", {}).items():
327+
import h5py
328+
329+
with h5py.File(hdf5_path, "r") as f:
330+
group = f["inputs"][petab_id]
331+
keys = sorted(group.keys())
332+
# Stack with experiment dimension first
333+
data = np.stack([group[k][:] for k in keys])
334+
rows = []
335+
for row in data:
336+
rows.append(
337+
"[" + ", ".join(str(float(v)) for v in row) + "]"
338+
)
339+
data_str = ", ".join(rows)
340+
341+
array_entries.append(f"'{petab_id}': jnp.array([{data_str}])")
342+
343+
if not array_entries:
344+
return "self._array_inputs = {}"
345+
346+
entries_str = ", ".join(array_entries)
347+
return f"self._array_inputs = {{{entries_str}}}"
348+
309349
def _get_all_p_syms(self) -> list[sp.Symbol]:
310350
return list(self.model.sym("p")) + list(self.model.sym("k"))
311351

0 commit comments

Comments
 (0)