Skip to content

Commit

Permalink
fixes to compile SQ_Soil_Temperature
Browse files Browse the repository at this point in the history
- prevent inclusion of duplicate functions
- prevent name collision of user defined variables had the same name as the generated structures
  • Loading branch information
bergm committed Apr 19, 2024
1 parent 10eb725 commit 578bdff
Showing 1 changed file with 62 additions and 35 deletions.
97 changes: 62 additions & 35 deletions src/pycropml/transpiler/generators/cppGenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,9 @@ def __init__(self, tree, model=None, name=None):
self.z = middleware(self.tree)
self.z.transform(self.tree)
self.name = name
if self.model:
self.cpp_unique_functions = set()
self.cpp_struct_names = {"s": "s", "s1": "s1", "r": "r", "a": "a", "ex": "ex"}
if self.model:
self.doc = DocGenerator(model, '//')
self.generator = CppTrans([model])
self.generator.model2Node()
Expand Down Expand Up @@ -446,12 +448,18 @@ def visit_function_definition(self, node):
#print(self.funcname)
z = self.add_features(node)
if not node.name.startswith("model_") and not node.name.startswith("init_"):
# self.templateArr(node.params)
self.visit_decl(node.return_type) if node.return_type else self.write("void")
if self.model:
self.write(f" {self.model.name}::{node.name}(")
func_name = f"{self.model.name}::{node.name}" if self.model else f"{node.name}"
func_signature = (func_name,
tuple(map(lambda t: tuple(t) if isinstance(t, list) else t, node.return_type) if isinstance(
node.return_type, list) else node.return_type),
tuple(map(lambda x: (
tuple(x.pseudo_type) if isinstance(x.pseudo_type, list) else x.pseudo_type, x.name), node.params)))
if func_signature in self.cpp_unique_functions:
return
else:
self.write(f" {node.name}(")
self.cpp_unique_functions.add(func_signature)
self.visit_decl(node.return_type) if node.return_type else self.write("void")
self.write(f" {func_name}(")
for i, pa in enumerate(node.params):
#print(pa.name, pa.feat)
# if pa.name in self.array_parameter(node.params)[0].values(): continue
Expand All @@ -468,15 +476,25 @@ def visit_function_definition(self, node):
self.write(f"{self.model.name}::{self.model.name}() {{}}")
self.newline(node)
if self.node_param and not node.name.startswith("init_"):
self.getter(self.model.name,self.node_param)
self.getter(self.model.name, self.node_param)
self.newline(1)
self.setter(self.model.name,self.node_param)
self.setter(self.model.name, self.node_param)
self.newline(1)
param_names = list(map(lambda x: x.name, self.params))
unique_code_struct_names = False
while not unique_code_struct_names:
unique_code_struct_names = True
for sn, code_sn in self.cpp_struct_names.items():
if code_sn in param_names:
self.cpp_struct_names[sn] = f"{code_sn}_"
unique_code_struct_names = False
if node.name.startswith("init_"):
self.write(f"void {self.model.name}::Init(")
else:
self.write(f"void {self.model.name}::Calculate_Model(")
self.write(f'{self.name}State &s, {self.name}State &s1, {self.name}Rate &r, {self.name}Auxiliary &a, {self.name}Exogenous &ex)')
self.write(f"{self.name}State &{self.cpp_struct_names['s']}, {self.name}State &{self.cpp_struct_names['s1']}, "
f"{self.name}Rate &{self.cpp_struct_names['r']}, {self.name}Auxiliary &{self.cpp_struct_names['a']}, "
f"{self.name}Exogenous &{self.cpp_struct_names['ex']})")
self.newline(node)
self.write('{')
self.newline(node)
Expand All @@ -501,7 +519,7 @@ def visit_function_definition(self, node):
self.write(f" {arg.name}")
if node.name.startswith("init_"):
if arg.name in self.exogenous:
self.write(f" = ex.get{arg.name}()")
self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
elif arg.pseudo_type[0] == "list":
self.write(f" = std::vector<{self.types[arg.pseudo_type[1]]}>()")
elif arg.pseudo_type[0] == "array":
Expand All @@ -513,15 +531,15 @@ def visit_function_definition(self, node):
else:
# make left hand side a reference to the result in case of lists and arrays
if arg.name in self.states and not arg.name.endswith("_t1"):
self.write(f" = s.get{arg.name}()")
self.write(f" = {self.cpp_struct_names['s']}.get{arg.name}()")
elif arg.name.endswith("_t1") and arg.name in self.states:
self.write(f" = s1.get{arg.name[:-3]}()")
self.write(f" = {self.cpp_struct_names['s1']}.get{arg.name[:-3]}()")
elif arg.name in self.rates:
self.write(f" = r.get{arg.name}()")
self.write(f" = {self.cpp_struct_names['r']}.get{arg.name}()")
elif arg.name in self.auxiliary:
self.write(f" = a.get{arg.name}()")
self.write(f" = {self.cpp_struct_names['a']}.get{arg.name}()")
elif arg.name in self.exogenous:
self.write(f" = ex.get{arg.name}()")
self.write(f" = {self.cpp_struct_names['ex']}.get{arg.name}()")
self.write(";")
self.indentation -= 1
self.body(node.block)
Expand Down Expand Up @@ -586,13 +604,13 @@ def visit_return(self, node):
if arg.feat in ("OUT", "INOUT"):
self.newline(node)
if arg.name in self.states:
self.write(f"s.set{arg.name}({arg.name});")
self.write(f"{self.cpp_struct_names['s']}.set{arg.name}({arg.name});")
if arg.name in self.rates:
self.write(f"r.set{arg.name}({arg.name});")
self.write(f"{self.cpp_struct_names['r']}.set{arg.name}({arg.name});")
if arg.name in self.auxiliary:
self.write(f"a.set{arg.name}({arg.name});")
self.write(f"{self.cpp_struct_names['a']}.set{arg.name}({arg.name});")
if arg.name in self.exogenous:
self.write(f"ex.set{arg.name}({arg.name});")
self.write(f"{self.cpp_struct_names['ex']}.set{arg.name}({arg.name});")
else:
self.newline(node)
self.indentation += 1
Expand All @@ -609,7 +627,7 @@ def visit_tuple(self,node):
self.write(")")

def visit_datetime(self, node):
self.write("'%s/%s/%s'"%(node.value[0].value,node.value[1].value,node.value[2].value))
self.write(f"'{node.value[0].value}/{node.value[1].value}/{node.value[0].value}'")

def visit_str(self, node):
self.safe_double(node)
Expand Down Expand Up @@ -1022,17 +1040,24 @@ def public_hpp(self, node, typ, mc=None, h=None, init=False, iscompo=False):
self.write(f"void Init({mc}State &s, {mc}State &s1, {mc}Rate &r, {mc}Auxiliary &a, {mc}Exogenous &ex);")

if h: # function externs
for i in h:
self.newline(1)
self.visit_decl(list(i.values())[0][0])
self.write(f" {list(i.keys())[0]}(")
x = list(i.values())[0][1]
for pa in x:
self.visit_decl(pa.pseudo_type)
self.write(f" {pa.name}")
if pa != x[len(x)-1]:
self.write(", ")
self.write(");")
unique_functions = set()
for fs in h:
for func_name, (func_return_type, func_params) in fs.items():
key = (func_name,
tuple(map(lambda t: tuple(t) if isinstance(t, list) else t, func_return_type) if isinstance(
func_return_type, list) else func_return_type),
tuple(map(lambda x: (
tuple(x.pseudo_type) if isinstance(x.pseudo_type, list) else x.pseudo_type, x.name),
func_params)))
if key not in unique_functions:
unique_functions.add(key)
self.newline(1)
self.visit_decl(func_return_type)
self.write(f" {func_name}(")
for i, pa in enumerate(func_params):
self.visit_decl(pa.pseudo_type)
self.write(f" {pa.name}{', ' if i != len(func_params)-1 else ''}")
self.write(");")
for arg in node:
self.newline(node)
if (iscompo and arg.name in self.getRealInputs()) or not iscompo:
Expand Down Expand Up @@ -1266,11 +1291,13 @@ def header_mu_cpp(models, rep, name):
file_func = mf.filename
path_func = Path(os.path.join(m.path, "crop2ml", file_func))
func_tree=parser(Path(path_func))
newtree = AstTransformer(func_tree,path_func)
newtree = AstTransformer(func_tree, path_func)
#print(newtree)
dictAst = newtree.transformer()
nodeAst= transform_to_syntax_tree(dictAst)
z ={nodeAst.body[0].name: [nodeAst.body[0].return_type,nodeAst.body[0].params]}
dict_ast = newtree.transformer()
node_ast= transform_to_syntax_tree(dict_ast)
z = {}
for f in filter(lambda x: x.type == "function_definition", node_ast.body):
z[f.name] = [f.return_type, f.params]
h.append(z)
if m.initialization:
init = True
Expand Down

0 comments on commit 578bdff

Please sign in to comment.