2727 ObservableTransformation ,
2828 _default_simplify ,
2929 amici_time_symbol ,
30+ symbol_with_assumptions ,
3031 toposort_symbols ,
3132 unique_preserve_order ,
3233)
@@ -2499,26 +2500,67 @@ def _process_hybridization(self, hybridization: dict) -> None:
24992500 added_expressions = False
25002501 orig_obs = tuple ([s .get_sym () for s in self ._observables ])
25012502 for net_id , net in hybridization .items ():
2502- if net ["static " ]:
2503+ if net ["pre_initialization " ]:
25032504 # do not integrate into ODEs, handle in amici.sim.jax.petab
25042505 continue
2505- inputs = [
2506- comp
2507- for comp in self ._components
2508- if str (comp .get_sym ()) in net ["input_vars" ]
2509- ]
2510- # sort inputs by order in input_vars
2511- inputs = sorted (
2512- inputs ,
2513- key = lambda comp : net ["input_vars" ].index (str (comp .get_sym ())),
2514- )
2506+ comp_by_sym = {comp .get_id (): comp for comp in self ._components }
2507+ sym_locals = {s : comp .get_sym () for s , comp in comp_by_sym .items ()}
2508+
2509+ inputs = []
2510+ unresolved_vars = []
2511+ for input_var in net ["input_vars" ]:
2512+ if input_var in comp_by_sym :
2513+ inputs .append (comp_by_sym [input_var ])
2514+ else :
2515+ try :
2516+ expr = sp .sympify (input_var , locals = sym_locals )
2517+ except (sp .SympifyError , Exception ):
2518+ unresolved_vars .append (input_var )
2519+ continue
2520+
2521+ if {str (s ) for s in expr .free_symbols } - set (comp_by_sym ):
2522+ unresolved_vars .append (input_var )
2523+ continue
2524+
2525+ expr_sym = symbol_with_assumptions (
2526+ f"_nn_{ net_id } _input_{ len (inputs )} "
2527+ )
2528+ new_expr_comp = Expression (
2529+ symbol = expr_sym ,
2530+ name = f"{ net_id } _input_{ len (inputs )} " ,
2531+ value = expr ,
2532+ )
2533+ self .add_component (new_expr_comp )
2534+ added_expressions = True
2535+ inputs .append (new_expr_comp )
2536+
25152537 if len (inputs ) != len (net ["input_vars" ]):
2516- found_vars = {str (comp .get_sym ()) for comp in inputs }
2517- missing_vars = set (net ["input_vars" ]) - found_vars
2518- raise ValueError (
2519- f"Could not find all input variables for neural network { net_id } . "
2520- f"Missing variables: { sorted (missing_vars )} "
2521- )
2538+ missing_vars = set (unresolved_vars )
2539+ if missing_vars == {"array" }:
2540+ array_inputs = net .get ("array_inputs" , {})
2541+ petab_ids = list (array_inputs .keys ())
2542+ for i , input_var in enumerate (net ["input_vars" ]):
2543+ if input_var == "array" :
2544+ if not petab_ids :
2545+ raise ValueError (
2546+ f"Array input specified for { net_id } but no "
2547+ f"array_inputs info provided in hybridization."
2548+ )
2549+ petab_id = petab_ids .pop (0 )
2550+ array_sym = symbol_with_assumptions (
2551+ f"_nn_array_{ petab_id } "
2552+ )
2553+ array_comp = Expression (
2554+ symbol = array_sym ,
2555+ name = f"{ net_id } _array_{ petab_id } " ,
2556+ value = sp .Integer (0 ),
2557+ )
2558+ inputs .insert (i , array_comp )
2559+ else :
2560+ raise ValueError (
2561+ f"Could not find all input variables for neural network { net_id } . "
2562+ f"Missing variables: { sorted (missing_vars )} "
2563+ )
25222564 for inp in inputs :
25232565 if isinstance (
25242566 inp ,
@@ -2547,12 +2589,13 @@ def _process_hybridization(self, hybridization: dict) -> None:
25472589 f"Could not find all output variables for neural network { net_id } . "
25482590 f"Missing variables: { sorted (missing_vars )} "
25492591 )
2550-
25512592 for out_var , parts in outputs .items ():
25522593 comp = parts ["comp" ]
25532594 # remove output from model components
25542595 if isinstance (comp , FreeParameter ):
25552596 self ._free_parameters .remove (comp )
2597+ elif isinstance (comp , FixedParameter ):
2598+ self ._fixed_parameters .remove (comp )
25562599 elif isinstance (comp , Expression ):
25572600 self ._expressions .remove (comp )
25582601 elif isinstance (comp , DifferentialState ):
@@ -2586,7 +2629,7 @@ def _process_hybridization(self, hybridization: dict) -> None:
25862629 added_expressions = True
25872630
25882631 observables = {
2589- ob_var : {"comp" : comp , "ind" : net ["observable_vars" ][ob_var ]}
2632+ ob_var : {"comp" : comp , ** net ["observable_vars" ][ob_var ]}
25902633 for comp in self ._components
25912634 if (ob_var := str (comp .get_sym ())) in net ["observable_vars" ]
25922635 # # TODO: SYNTAX NEEDS to CHANGE
@@ -2609,9 +2652,19 @@ def _process_hybridization(self, hybridization: dict) -> None:
26092652 raise ValueError (
26102653 f"{ comp .get_name ()} ({ type (comp )} ) is not an observable."
26112654 )
2612- out_val = sp .Function (net_id )(
2613- * [input .get_sym () for input in inputs ], parts ["ind " ]
2655+ nn_call = sp .Function (net_id )(
2656+ * [input .get_sym () for input in inputs ], parts ["index " ]
26142657 )
2658+ formula = parts ["formula" ]
2659+ petab_id = parts ["petab_id" ]
2660+ if formula == petab_id :
2661+ out_val = nn_call
2662+ else :
2663+ from petab .math .sympify import sympify_petab
2664+
2665+ out_val = sympify_petab (formula ).subs (
2666+ symbol_with_assumptions (petab_id ), nn_call
2667+ )
26152668 # add to the model
26162669 self .add_component (
26172670 Observable (
0 commit comments