Skip to content

Commit

Permalink
Create Aggregation class and simplify code. Add test for tree of derived functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Jan 29, 2025
1 parent 1b09dd4 commit 8eee022
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 94 deletions.
27 changes: 27 additions & 0 deletions src/_gettsim/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,33 @@
from _gettsim.config import USE_JAX


class AggregationSpec:
"""
A container for aggregation specifications.
Parameters
----------
aggregation_specs:
A dictionary with aggregation specifications.
target_name:
The name of the target column.
"""

def __init__(self, aggregation_specs: dict[str, str | dict], target_name: str):
self._aggregation_specs = aggregation_specs
self._target_name = target_name

@property
def aggregation_specs(self) -> dict[str, str | dict]:
"""The aggregation specifications."""
return self._aggregation_specs

@property
def target_name(self) -> str:
"""The name of the target column."""
return self._target_name


def grouped_count(group_id):
if USE_JAX:
return grouped_count_jax(group_id)
Expand Down
15 changes: 11 additions & 4 deletions src/_gettsim/functions/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
from types import ModuleType
from typing import Any, Literal, TypeAlias

from _gettsim.aggregation import AggregationSpec
from _gettsim.config import (
PATHS_TO_INTERNAL_FUNCTIONS,
QUALIFIED_NAME_SEPARATOR,
RESOURCE_DIR,
)
from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.gettsim_typing import NestedFunctionDict
from _gettsim.shared import (
tree_update,
)
from _gettsim.shared import tree_update


def load_functions_tree_for_date(date: datetime.date) -> NestedFunctionDict:
Expand Down Expand Up @@ -389,7 +388,15 @@ def _load_functions_to_derive(
]

_fail_if_more_than_one_dict_loaded(dicts_in_module, module_name)
return dicts_in_module[0] if dicts_in_module else {}

return (
{
name: AggregationSpec(aggregation_specs=spec, target_name=name)
for name, spec in dicts_in_module[0].items()
}
if dicts_in_module
else {}
)


def _fail_if_more_than_one_dict_loaded(dicts: list[dict], module_name: str) -> None:
Expand Down
131 changes: 41 additions & 90 deletions src/_gettsim/policy_environment_postprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import numpy
import optree
from optree import tree_flatten_with_path

from _gettsim.aggregation import (
AggregationSpec,
all_by_p_id,
any_by_p_id,
count_by_p_id,
Expand All @@ -25,7 +25,6 @@
sum_by_p_id,
)
from _gettsim.config import (
QUALIFIED_NAME_SEPARATOR,
SUPPORTED_GROUPINGS,
TYPES_INPUT_VARIABLES,
)
Expand All @@ -35,7 +34,6 @@
from _gettsim.shared import (
format_list_linewise,
get_names_of_arguments_without_defaults,
get_path_from_qualified_name,
merge_nested_dicts,
remove_group_suffix,
rename_arguments_and_add_annotations,
Expand Down Expand Up @@ -166,11 +164,9 @@ def _create_aggregate_by_group_functions(
functions_tree: NestedFunctionDict,
targets: NestedTargetDict,
data: NestedDataDict,
aggregate_by_group_specs: dict[str, Any],
aggregation_dicts_provided_by_env: dict[str, Any],
) -> dict[str, DerivedFunction]:
"""Create aggregation functions."""

aggregation_dicts_provided_by_env = _get_aggregation_dicts(aggregate_by_group_specs)
automatically_created_aggregation_dicts = (
_create_derived_aggregation_specifications(
functions_tree=functions_tree,
Expand All @@ -182,31 +178,32 @@ def _create_aggregate_by_group_functions(
# Add automated aggregation specs.
# Note: For duplicate keys, explicitly set specs are treated with higher priority
# than automated specs.
full_aggregate_by_group_spec = merge_nested_dicts(
all_aggregate_by_group_specs = merge_nested_dicts(
automatically_created_aggregation_dicts,
aggregation_dicts_provided_by_env,
)

derived_functions = {}
for module_name, agg_dicts_of_module in full_aggregate_by_group_spec.items():
for func_name, agg_spec in agg_dicts_of_module.items():
_check_agg_specs_validity(
agg_specs=agg_spec, agg_col=func_name, module=module_name
)
derived_func = _create_one_aggregate_by_group_func(
new_function_name=func_name,
agg_specs=agg_spec,
functions_tree=functions_tree,
)
module_path = get_path_from_qualified_name(module_name)
function_path = [*module_path, func_name]
# TODO(@MImmesberger): Let derived functions inherit namespace from source
# function or source column.
qualified_name = QUALIFIED_NAME_SEPARATOR.join(function_path)
derived_func.set_qualified_name(qualified_name)
derived_functions = tree_update(
derived_functions, function_path, derived_func
)
_all_paths, _all_aggregation_specs, _ = optree.tree_flatten_with_path(
all_aggregate_by_group_specs
)
for path, aggregation_spec in zip(_all_paths, _all_aggregation_specs):
module_name = ".".join(path)
_check_agg_specs_validity(
agg_specs=aggregation_spec,
agg_col=aggregation_spec.target_name,
module=module_name,
)
derived_func = _create_one_aggregate_by_group_func(
new_function_name=aggregation_spec.target_name,
agg_specs=aggregation_spec,
functions_tree=functions_tree,
)
derived_functions = tree_update(
derived_functions,
path,
derived_func,
)

return derived_functions

Expand Down Expand Up @@ -262,10 +259,13 @@ def _create_derived_aggregation_specifications(
# targets that already exist in the source tree.
continue
else:
agg_specs_single_function = {
"aggr": "sum",
"source_col": remove_group_suffix(leaf_name),
}
agg_specs_single_function = AggregationSpec(
    {
        "aggr": "sum",
        "source_col": remove_group_suffix(leaf_name),
    },
    # Pass the variable, not the string literal "leaf_name": the spec's
    # `target_name` is later used as the name of the derived function.
    target_name=leaf_name,
)

all_agg_specs = tree_update(
tree=all_agg_specs,
Expand Down Expand Up @@ -474,74 +474,25 @@ def aggregate_by_group_func(source_col, group_id):

def _create_aggregate_by_p_id_functions(
    functions_tree: NestedFunctionDict,
    aggregate_by_p_id_specs: dict[str, Any],
    aggregation_dicts_provided_by_env: dict[str, Any],
) -> NestedFunctionDict:
    """Create function dict with functions that link variables across persons.

    Parameters
    ----------
    functions_tree:
        Passed through unchanged to `_create_one_aggregate_by_p_id_func` for
        every derived function.
    aggregate_by_p_id_specs:
        Nested aggregation specs. They are reduced with
        `_get_aggregation_dicts` to a mapping of qualified module name to
        ``{function name: aggregation dict}``.
    aggregation_dicts_provided_by_env:
        Tree that is flattened with `optree.tree_flatten_with_path`. Every leaf
        must expose a `target_name` attribute (presumably an
        `AggregationSpec` -- confirm against the caller).

    Returns
    -------
    The tree of derived functions assembled with `tree_update`.

    """
    # NOTE(review): this body contains two build loops, one per spec argument,
    # that both write into `derived_functions`. Confirm that both are intended;
    # what happens when they target the same path depends on `tree_update`.

    aggregation_dicts = _get_aggregation_dicts(aggregate_by_p_id_specs)

    derived_functions = {}

    # Loop 1: specs keyed by qualified module name, then by function name.
    for module_name, module_aggregation_dicts in aggregation_dicts.items():
        for func_name, aggregation_dict in module_aggregation_dicts.items():
            derived_func = _create_one_aggregate_by_p_id_func(
                new_function_name=func_name,
                agg_specs=aggregation_dict,
                functions_tree=functions_tree,
            )
            # Tree path of the new function: module path plus function name.
            module_path = get_path_from_qualified_name(module_name)
            function_path = [*module_path, func_name]
            # TODO(@MImmesberger): Let derived functions inherit namespace from source
            # function or source column.
            qualified_name = QUALIFIED_NAME_SEPARATOR.join(function_path)
            derived_func.set_qualified_name(qualified_name)
            derived_functions = tree_update(
                derived_functions, function_path, derived_func
            )
    # Loop 2: one derived function per leaf of the environment-provided tree,
    # stored at the leaf's own path.
    # NOTE(review): unlike loop 1, no `set_qualified_name` call is made here.
    _all_paths, _all_aggregation_specs, _ = optree.tree_flatten_with_path(
        aggregation_dicts_provided_by_env
    )
    for path, aggregation_spec in zip(_all_paths, _all_aggregation_specs):
        derived_func = _create_one_aggregate_by_p_id_func(
            new_function_name=aggregation_spec.target_name,
            agg_specs=aggregation_spec,
            functions_tree=functions_tree,
        )
        derived_functions = tree_update(derived_functions, path, derived_func)

    return derived_functions


def _get_aggregation_dicts(aggregate_by_p_id_specs: dict[str, Any]) -> dict[str, Any]:
    """Collapse a nested tree of aggregation specs to one level per module.

    The result maps the qualified module name (path elements joined by "__") to
    the aggregation targets of that module, each with its aggregation dict.
    Values stored under "source_col" or "p_id_to_aggregate_by" that do not
    contain "__" are prefixed with the module name or "groupings", respectively.

    Example:

        {"module1": {"module2": {"func": {
            "source_col": "col",
            "p_id_to_aggregate_by": "groupings__xx_id",
        }}}}

    Result:

        {"module1__module2": {"func": {
            "source_col": "module1__module2__col",
            "p_id_to_aggregate_by": "groupings__xx_id",
        }}}

    """
    flat_paths, flat_leaves, _ = tree_flatten_with_path(aggregate_by_p_id_specs)

    result = {}
    for tree_path, raw_value in zip(flat_paths, flat_leaves):
        # The last two path elements are the aggregation target and the key
        # inside its aggregation dict; everything before them is the module.
        qualified_module = "__".join(tree_path[:-2])
        target = tree_path[-2]
        spec_key = tree_path[-1]

        # Keys whose values get a namespace prefix unless the value already
        # contains the separator.
        namespace_by_key = {
            "p_id_to_aggregate_by": "groupings",
            "source_col": qualified_module,
        }
        if spec_key in namespace_by_key and "__" not in raw_value:
            qualified_value = f"{namespace_by_key[spec_key]}__{raw_value}"
        else:
            qualified_value = raw_value

        result = tree_update(
            result,
            [qualified_module, target, spec_key],
            qualified_value,
        )

    return result


def _create_one_aggregate_by_p_id_func(
new_function_name: str,
agg_specs: dict[str, str],
Expand Down
Loading

0 comments on commit 8eee022

Please sign in to comment.