Skip to content

Commit

Permalink
Implement first set of review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
MImmesberger committed Dec 13, 2024
1 parent 549be6d commit cae23a3
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions src/_gettsim/functions/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,13 @@
from typing import Any, Literal, TypeAlias

from _gettsim.config import PATHS_TO_INTERNAL_FUNCTIONS, RESOURCE_DIR
from _gettsim.functions.policy_function import PolicyFunction
from _gettsim.gettsim_typing import NestedFunctionDict
from _gettsim.shared import (
get_path_from_qualified_name,
tree_update,
)

from .policy_function import PolicyFunction


def load_functions_tree_for_date(date: datetime.date) -> NestedFunctionDict:
"""
Expand Down Expand Up @@ -219,6 +218,8 @@ def _create_policy_function_from_decorated_callable(
"""

# Only needed until the directory structure is cleaned up
# TODO(@MImmesberger): Remove the removeprefix calls once the directory
# structure is cleaned up
clean_module_name = (
module_name.removeprefix("_gettsim__")
.removeprefix("taxes__")
Expand Down Expand Up @@ -269,20 +270,25 @@ def _load_aggregation_dict(
.removeprefix("transfers__")
)
tree_keys = get_path_from_qualified_name(clean_module_name)
dicts_in_module = _load_dicts_in_module(path, package_root, f"{variant}_")
_fail_if_more_than_one_dict_loaded(dicts_in_module)
tree = tree_update(tree, tree_keys, *dicts_in_module)
derived_function_specs = load_functions_to_derive(
path, package_root, f"{variant}_"
)
tree = tree_update(tree, tree_keys, *derived_function_specs)

return tree


def _load_dicts_in_module(
def load_functions_to_derive(
path: Path,
package_root: Path,
prefix_filter: str,
) -> list[dict]:
"""
Load dictionaries defined in a module.
Load the dictionary that specifies which functions to derive from the module.
Returns one aggregation dictionary where keys are the names of the functions to
derive and values are dictionaries with the aggregation specifications ('aggr',
'source_col', 'p_id_to_aggregate_by').
Parameters
----------
Expand All @@ -297,17 +303,19 @@ def _load_dicts_in_module(
Loaded dictionaries.
"""
module = _load_module(path, package_root)

return [
module_name = _convert_path_to_module_name(path, package_root)
dicts_in_module = [
member
for name, member in inspect.getmembers(module)
if isinstance(member, dict) and name.startswith(prefix_filter)
]

_fail_if_more_than_one_dict_loaded(dicts_in_module, module_name)
return dicts_in_module


def _fail_if_more_than_one_dict_loaded(dicts: list[dict]) -> None:
def _fail_if_more_than_one_dict_loaded(dicts: list[dict], module_name: str) -> None:
if len(dicts) > 1:
raise ValueError(
"More than one dictionary found in the module. "
"Only one dictionary is allowed."
"More than one dictionary found in the module:\n\n" f"{module_name}\n\n"
)

0 comments on commit cae23a3

Please sign in to comment.