Skip to content

Commit

Permalink
New function that partitions trees.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Feb 1, 2025
1 parent acfa726 commit d441709
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 67 deletions.
13 changes: 6 additions & 7 deletions src/_gettsim/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@
)
from _gettsim.shared import (
KeyErrorMessage,
_filter_tree_by_name_list,
create_dict_from_list,
format_errors_and_warnings,
format_list_linewise,
get_by_path,
get_names_of_arguments_without_defaults,
get_path_from_qualified_name,
merge_nested_dicts,
partition_tree_by_reference_tree,
tree_to_dict_with_qualified_name,
tree_update,
)
Expand Down Expand Up @@ -96,15 +96,14 @@ def compute_taxes_and_transfers( # noqa: PLR0913
# Process data and load dictionaries with functions.
data = _process_and_check_data(data=data)

names_of_cols_in_data = list(tree_to_dict_with_qualified_name(data).keys())
all_functions = add_derived_functions_to_functions_tree(
environment=environment,
targets=targets,
data=data,
)
functions_not_overridden, functions_overridden = _filter_tree_by_name_list(
tree=all_functions,
qualified_names_list=names_of_cols_in_data,
functions_not_overridden, functions_overridden = partition_tree_by_reference_tree(
tree_to_split=all_functions,
other_tree=data,
)
data = _convert_data_to_correct_types(data, functions_overridden)

Expand Down Expand Up @@ -135,7 +134,7 @@ def compute_taxes_and_transfers( # noqa: PLR0913

# Round and partial parameters into functions that are nodes in the DAG.
processed_functions = _round_and_partial_parameters_to_functions(
_filter_tree_by_name_list(functions_not_overridden, nodes)[1],
partition_tree_by_reference_tree(functions_not_overridden, nodes)[1],
environment.params,
rounding,
)
Expand Down Expand Up @@ -557,7 +556,7 @@ def _create_input_data( # noqa: PLR0913
)

# Check that only necessary data is passed
unnecessary_data, input_data = _filter_tree_by_name_list(
unnecessary_data, input_data = partition_tree_by_reference_tree(
tree=data,
qualified_names_list=root_nodes,
)
Expand Down
73 changes: 38 additions & 35 deletions src/_gettsim/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any, TypeVar

import numpy
import optree
from dags.signature import rename_arguments
from optree import tree_flatten_with_path

Expand Down Expand Up @@ -98,52 +99,54 @@ def merge_nested_dicts(base_dict: dict, update_dict: dict) -> dict:
return result


def partition_tree_by_reference_tree(
    target_tree: NestedFunctionDict | NestedDataDict,
    reference_tree: NestedDataDict,
) -> tuple[NestedFunctionDict, NestedFunctionDict]:
    """Partition a tree by whether its leaf paths exist in a reference tree.

    Every leaf of ``target_tree`` is looked up, by its path, in
    ``reference_tree``. Leaves whose path can be resolved there go into one
    tree, all remaining leaves into the other. Both returned trees contain
    the leaves of ``target_tree``; ``reference_tree`` is only used to decide
    membership and its values are never copied into the result.

    Parameters
    ----------
    target_tree : NestedFunctionDict | NestedDataDict
        The tree to be partitioned.
    reference_tree : NestedDataDict
        The reference tree used to determine the partitioning.

    Returns
    -------
    tuple[NestedFunctionDict, NestedFunctionDict]
        A tuple containing, in this order:
        - The tree with the leaves of ``target_tree`` whose path is absent
          from the reference tree.
        - The tree with the leaves of ``target_tree`` whose path is present
          in the reference tree.
    """
    tree_with_present_leaves = {}
    tree_with_absent_leaves = {}

    # One accessor per leaf of the target tree; calling an accessor on a tree
    # walks that leaf's path through the tree.
    for accessor in optree.tree_accessors(target_tree):
        leaf = accessor(target_tree)
        try:
            # Only a membership probe: the value stored in the reference tree
            # is deliberately discarded.
            accessor(reference_tree)
        except KeyError:
            tree_with_absent_leaves = tree_update(
                tree_with_absent_leaves,
                accessor.path,
                leaf,
            )
        else:
            tree_with_present_leaves = tree_update(
                tree_with_present_leaves,
                accessor.path,
                leaf,
            )

    return tree_with_absent_leaves, tree_with_present_leaves


def format_errors_and_warnings(text, width=79):
Expand Down
6 changes: 3 additions & 3 deletions src/_gettsim/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
add_derived_functions_to_functions_tree,
)
from _gettsim.shared import (
_filter_tree_by_name_list,
format_list_linewise,
get_names_of_arguments_without_defaults,
partition_tree_by_reference_tree,
tree_to_dict_with_qualified_name,
)

Expand Down Expand Up @@ -92,7 +92,7 @@ def plot_dag(
targets=targets,
data=names_of_columns_overriding_functions,
)
functions_not_overridden = _filter_tree_by_name_list(
functions_not_overridden = partition_tree_by_reference_tree(
tree=all_functions,
qualified_names_list=names_of_columns_overriding_functions,
)[0]
Expand All @@ -113,7 +113,7 @@ def plot_dag(
)

processed_functions = _round_and_partial_parameters_to_functions(
_filter_tree_by_name_list(functions_not_overridden, dag.nodes)[1],
partition_tree_by_reference_tree(functions_not_overridden, dag.nodes)[1],
environment.params,
rounding=False,
)
Expand Down
45 changes: 23 additions & 22 deletions src/_gettsim_tests/test_shared.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import pytest

from _gettsim.shared import (
_filter_tree_by_name_list,
create_dict_from_list,
merge_nested_dicts,
tree_flatten_with_qualified_name,
partition_tree_by_reference_tree,
tree_to_dict_with_qualified_name,
tree_update,
)
Expand Down Expand Up @@ -63,43 +62,45 @@ def test_tree_flatten_with_qualified_name(tree, expected):


@pytest.mark.parametrize(
    "target_tree, reference_tree, expected",
    [
        # Some leaf paths of the target tree exist in the reference tree.
        (
            {"a": {"b": 1, "c": 1}, "b": 1},
            {"a": {"b": 1}, "b": 1},
            ({"a": {"c": 1}}, {"a": {"b": 1}, "b": 1}),
        ),
        # Empty reference tree: every leaf ends up in the "absent" partition.
        (
            {"a": {"c": 1}},
            {},
            ({"a": {"c": 1}}, {}),
        ),
    ],
)
def test_partition_tree_by_reference_tree(target_tree, reference_tree, expected):
    """Check both partitions against the expected (absent, present) pair."""
    expected_absent, expected_present = expected

    partitions = partition_tree_by_reference_tree(target_tree, reference_tree)
    absent, present = partitions

    assert absent == expected_absent
    assert present == expected_present

0 comments on commit d441709

Please sign in to comment.