Skip to content

Commit

Permalink
Review comment about unpacking tuple returns of functions,
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Jan 20, 2025
1 parent cae23a3 commit ff4e1b2
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 14 deletions.
11 changes: 6 additions & 5 deletions src/_gettsim/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,12 @@ def compute_taxes_and_transfers( # noqa: PLR0913
input_structure=input_structure,
check_minimal_specification=check_minimal_specification,
).nodes
# Select functions that are nodes of the DAG.
_, necessary_functions = _filter_tree_by_name_list(functions_not_overridden, nodes)
# Round and partial parameters into functions.

# Round and partial parameters into functions that are nodes in the DAG.
processed_functions = _round_and_partial_parameters_to_functions(
necessary_functions, environment.params, rounding
_filter_tree_by_name_list(functions_not_overridden, nodes)[1],
environment.params,
rounding,
)

# Input structure for final DAG.
Expand Down Expand Up @@ -189,7 +190,7 @@ def build_targets_tree(targets: NestedTargetDict | list[str] | str) -> NestedTar
if isinstance(targets, str):
targets = [targets]

flattened_targets, _ = tree_flatten(targets)
flattened_targets = tree_flatten(targets)[0]
all_leafs_none = all(el is None for el in flattened_targets)
all_leafs_str_or_list = all(isinstance(el, str | list) for el in flattened_targets)

Expand Down
9 changes: 3 additions & 6 deletions src/_gettsim/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def plot_dag(
targets=targets,
names_of_columns_in_data=names_of_columns_overriding_functions,
)
functions_not_overridden, _ = _filter_tree_by_name_list(
functions_not_overridden = _filter_tree_by_name_list(
tree=all_functions,
qualified_names_list=names_of_columns_overriding_functions,
)
)[0]

# Create parameter input structure.
input_structure = dags.dag_tree.create_input_structure_tree(
Expand All @@ -112,11 +112,8 @@ def plot_dag(
check_minimal_specification=check_minimal_specification,
)

_, necessary_functions = _filter_tree_by_name_list(
functions_not_overridden, dag.nodes
)
processed_functions = _round_and_partial_parameters_to_functions(
necessary_functions,
_filter_tree_by_name_list(functions_not_overridden, dag.nodes)[1],
environment.params,
rounding=False,
)
Expand Down
2 changes: 1 addition & 1 deletion src/_gettsim_tests/test_policy_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_convert_function_to_correct_type(tree, expected_module_name):
_convert_function_to_correct_type,
tree,
)
func_list, _ = tree_flatten(funcs_with_correct_type)
func_list = tree_flatten(funcs_with_correct_type)[0]
for func in func_list:
assert func.module_name == expected_module_name
assert isinstance(func, PolicyFunction)
4 changes: 2 additions & 2 deletions src/_gettsim_tests/test_rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,9 @@ def test_decorator_for_all_functions_with_rounding_spec():
# addressed.
time_dependent_functions = {}
for year in range(1990, 2023):
year_functions, _ = tree_flatten(
year_functions = tree_flatten(
load_functions_tree_for_date(datetime.date(year=year, month=1, day=1))
)
)[0]
function_name_to_name_in_dag_dict = {
func.function.__name__: func.name_in_dag for func in year_functions
}
Expand Down

0 comments on commit ff4e1b2

Please sign in to comment.