diff --git a/Makefile b/Makefile index f48a785a9..84170bb67 100644 --- a/Makefile +++ b/Makefile @@ -28,15 +28,16 @@ env-%: env-tests: poetry run pip install \ - pytest \ + jsondiff \ nbconvert \ nbformat \ - pytest-subtests \ - pytest-azurepipelines \ - ruff \ + opentelemetry-sdk \ pre-commit \ + pytest \ + pytest-azurepipelines \ pytest-cov \ - jsondiff + pytest-subtests \ + ruff \ env-tests-required: poetry install --only required \ diff --git a/examples/experimental/otel_exporter.ipynb b/examples/experimental/otel_exporter.ipynb index 94b2b0a0a..a3a145b3e 100644 --- a/examples/experimental/otel_exporter.ipynb +++ b/examples/experimental/otel_exporter.ipynb @@ -32,29 +32,18 @@ "metadata": {}, "outputs": [], "source": [ - "from typing import Callable\n", + "import logging\n", "\n", - "from opentelemetry import trace\n", - "from trulens.apps.custom import instrument\n", - "from trulens.experimental.otel_tracing.core.init import TRULENS_SERVICE_NAME\n", - "\n", - "\n", - "def decorator(func: Callable):\n", - " tracer = trace.get_tracer(TRULENS_SERVICE_NAME)\n", - "\n", - " def wrapper(*args, **kwargs):\n", - " print(\"start wrap\")\n", - "\n", - " with tracer.start_as_current_span(\"custom\"):\n", - " result = func(*args, **kwargs)\n", - " span = trace.get_current_span()\n", - " print(\"---span---\")\n", - " print(span.get_span_context())\n", - " span.set_attribute(\"result\", result)\n", - " span.set_status(trace.Status(trace.StatusCode.OK))\n", - " return result\n", - "\n", - " return wrapper" + "root = logging.getLogger()\n", + "root.setLevel(logging.DEBUG)\n", + "handler = logging.StreamHandler(sys.stdout)\n", + "handler.setLevel(logging.DEBUG)\n", + "handler.addFilter(logging.Filter(\"trulens\"))\n", + "formatter = logging.Formatter(\n", + " \"%(asctime)s - %(name)s - %(levelname)s - %(message)s\"\n", + ")\n", + "handler.setFormatter(formatter)\n", + "root.addHandler(handler)" ] }, { @@ -63,22 +52,46 @@ "metadata": {}, "outputs": [], "source": [ - "from 
examples.dev.dummy_app.dummy import Dummy\n", - "\n", + "from trulens.experimental.otel_tracing.core.instrument import instrument\n", "\n", - "class TestApp(Dummy):\n", - " def __init__(self):\n", - " super().__init__()\n", "\n", - " @decorator\n", - " @instrument\n", + "class TestApp:\n", + " @instrument()\n", " def respond_to_query(self, query: str) -> str:\n", " return f\"answer: {self.nested(query)}\"\n", "\n", - " @decorator\n", - " @instrument\n", + " @instrument(attributes={\"nested_attr1\": \"value1\"})\n", " def nested(self, query: str) -> str:\n", - " return f\"nested: {query}\"" + " return f\"nested: {self.nested2(query)}\"\n", + "\n", + " @instrument(\n", + " attributes=lambda ret, exception, *args, **kwargs: {\n", + " \"nested2_ret\": ret,\n", + " \"nested2_args[0]\": args[0],\n", + " }\n", + " )\n", + " def nested2(self, query: str) -> str:\n", + " nested_result = \"\"\n", + "\n", + " try:\n", + " nested_result = self.nested3(query)\n", + " except Exception:\n", + " pass\n", + "\n", + " return f\"nested2: {nested_result}\"\n", + "\n", + " @instrument(\n", + " attributes=lambda ret, exception, *args, **kwargs: {\n", + " \"nested3_ex\": exception.args if exception else None,\n", + " \"nested3_ret\": ret,\n", + " \"selector_name\": \"special\",\n", + " \"cows\": \"moo\",\n", + " }\n", + " )\n", + " def nested3(self, query: str) -> str:\n", + " if query == \"throw\":\n", + " raise ValueError(\"nested3 exception\")\n", + " return \"nested3\"" ] }, { @@ -87,11 +100,16 @@ "metadata": {}, "outputs": [], "source": [ + "import dotenv\n", "from trulens.core.session import TruSession\n", "from trulens.experimental.otel_tracing.core.init import init\n", "\n", + "dotenv.load_dotenv()\n", + "\n", "session = TruSession()\n", - "init(session)" + "session.experimental_enable_feature(\"otel_tracing\")\n", + "session.reset_database()\n", + "init(session, debug=True)" ] }, { @@ -100,9 +118,16 @@ "metadata": {}, "outputs": [], "source": [ + "from trulens.apps.custom import 
def instrument(
    *,
    attributes: Optional[
        Union[
            Dict[str, Any],
            Callable[
                [Optional[Any], Optional[Exception], Any, Any], Dict[str, Any]
            ],
        ]
    ] = None,
):
    """
    Decorator for marking functions to be instrumented in custom classes that are
    wrapped by TruCustomApp, with OpenTelemetry tracing.

    Each call of the decorated function runs inside a new span named after the
    function. After the function returns (or raises), the user attributes are
    validated and attached to that span.

    Args:
        attributes: Either a dictionary of attributes to attach to every span,
            or a callable invoked as
            `attributes(ret, exception, *args, **kwargs)` that returns such a
            dictionary. `ret` is the function's return value (`None` if it
            raised) and `exception` is the exception it raised (`None` if it
            returned normally). `None` or an empty dictionary attaches
            nothing.

    Returns:
        A decorator that wraps the target function.

    Raises:
        The wrapped function's own exception is re-raised unchanged once
        attribute handling has been attempted. If the function succeeded but
        computing, validating or setting the attributes failed, that error is
        raised instead (a `ValueError` for validation failures).
    """

    def _validate_selector_name(candidate: Dict[str, Any]) -> Dict[str, Any]:
        """Normalize the selector name to its non-namespaced key.

        Works on a copy so the caller's dictionary (possibly the static one
        given to the decorator) is never mutated.
        """
        result = candidate.copy()

        if (
            SpanAttributes.SELECTOR_NAME_KEY in result
            and SpanAttributes.SELECTOR_NAME in result
        ):
            raise ValueError(
                f"Both {SpanAttributes.SELECTOR_NAME_KEY} and {SpanAttributes.SELECTOR_NAME} cannot be set."
            )

        if SpanAttributes.SELECTOR_NAME in result:
            # Transfer the trulens namespaced to the non-trulens namespaced key.
            result[SpanAttributes.SELECTOR_NAME_KEY] = result.pop(
                SpanAttributes.SELECTOR_NAME
            )

        if SpanAttributes.SELECTOR_NAME_KEY in result:
            selector_name = result[SpanAttributes.SELECTOR_NAME_KEY]
            if not isinstance(selector_name, str):
                raise ValueError(
                    f"Selector name must be a string, not {type(selector_name)}"
                )

        return result

    def _validate_attributes(candidate: Dict[str, Any]) -> Dict[str, Any]:
        """Check that `candidate` is a dict with string keys and normalize it."""
        # TODO: validate OTEL attributes.
        # TODO: validate span type attributes.
        if not isinstance(candidate, dict) or any(
            not isinstance(key, str) for key in candidate
        ):
            raise ValueError(
                "Attributes must be a dictionary with string keys."
            )
        return _validate_selector_name(candidate)

    def inner_decorator(func: Callable):
        @wraps(func)
        def wrapper(*args, **kwargs):
            tracer = trace.get_tracer_provider().get_tracer(
                TRULENS_SERVICE_NAME
            )
            with tracer.start_as_current_span(name=func.__name__) as span:
                ret = None
                func_exception: Optional[Exception] = None
                attributes_exception: Optional[Exception] = None

                try:
                    ret = func(*args, **kwargs)
                except Exception as e:
                    # We want to get into the next clause to allow the users to still add attributes.
                    # It's on the user to deal with None as a return value.
                    func_exception = e

                try:
                    attributes_to_add = {}

                    # Set the user provided attributes.
                    if attributes:
                        if callable(attributes):
                            attributes_to_add = attributes(
                                ret, func_exception, *args, **kwargs
                            )
                        else:
                            attributes_to_add = attributes

                    logger.info("Attributes to add: %s", attributes_to_add)

                    final_attributes = _validate_attributes(attributes_to_add)

                    prefix = "trulens."
                    if (
                        SpanAttributes.SPAN_TYPE in final_attributes
                        and final_attributes[SpanAttributes.SPAN_TYPE]
                        != SpanAttributes.SpanType.UNKNOWN
                    ):
                        prefix += (
                            final_attributes[SpanAttributes.SPAN_TYPE] + "."
                        )

                    selector_name = final_attributes.get(
                        SpanAttributes.SELECTOR_NAME_KEY
                    )
                    has_selector = (
                        SpanAttributes.SELECTOR_NAME_KEY in final_attributes
                    )

                    for key, value in final_attributes.items():
                        span.set_attribute(prefix + key, value)

                        # Every attribute except the selector name itself is
                        # also recorded under the selector's own namespace.
                        if (
                            has_selector
                            and key != SpanAttributes.SELECTOR_NAME_KEY
                        ):
                            span.set_attribute(
                                f"trulens.{selector_name}.{key}",
                                value,
                            )

                except Exception as e:
                    attributes_exception = e
                    logger.error("Error setting attributes: %s", e)

                # The wrapped function's failure takes precedence over any
                # failure in attribute handling. Compare against None so an
                # exception object that happens to be falsy is still raised.
                if func_exception is not None:
                    raise func_exception
                if attributes_exception is not None:
                    raise attributes_exception

                return ret

        return wrapper

    return inner_decorator
+ """ class SpanType(str, Enum): """Span type attribute values. diff --git a/tests/test.py b/tests/test.py index 729afa099..19081cbd2 100644 --- a/tests/test.py +++ b/tests/test.py @@ -27,6 +27,7 @@ import unittest from unittest import TestCase +import pandas as pd import pydantic from pydantic import BaseModel from trulens.core._utils.pycompat import ReferenceType @@ -225,14 +226,16 @@ class WithJSONTestCase(TestCase): """TestCase mixin class that adds JSON comparisons and golden expectation handling.""" - def load_golden(self, golden_path: Union[str, Path]) -> serial_utils.JSON: + def load_golden( + self, + golden_path: Union[str, Path], + ) -> Union[serial_utils.JSON, pd.DataFrame]: """Load the golden file `path` and return its contents. Args: golden_path: The name of the golden file to load. The file must have an extension of either `.json` or `.yaml`. The extension determines the input format. - """ golden_path = Path(golden_path) @@ -240,6 +243,8 @@ def load_golden(self, golden_path: Union[str, Path]) -> serial_utils.JSON: loader = functools.partial(json.load) elif ".yaml" in golden_path.suffixes or ".yml" in golden_path.suffixes: loader = functools.partial(yaml.load, Loader=yaml.FullLoader) + elif ".csv" in golden_path.suffixes: + loader = functools.partial(pd.read_csv, index_col=0) else: raise ValueError(f"Unknown file extension {golden_path}.") @@ -250,7 +255,9 @@ def load_golden(self, golden_path: Union[str, Path]) -> serial_utils.JSON: return loader(f) def write_golden( - self, golden_path: Union[str, Path], data: serial_utils.JSON + self, + golden_path: Union[str, Path], + data: Union[serial_utils.JSON, pd.DataFrame], ) -> None: """If writing golden file is enabled, write the golden file `path` with `data` and raise exception indicating so. 
@@ -272,6 +279,10 @@ def write_golden( writer = functools.partial(json.dump, indent=2, sort_keys=True) elif golden_path.suffix == ".yaml": writer = functools.partial(yaml.dump, sort_keys=True) + elif golden_path.suffix == ".csv": + writer = lambda data, f: data.to_csv(f) + elif golden_path.suffix == ".parquet": + writer = lambda data, f: data.to_parquet(f) else: raise ValueError(f"Unknown file extension {golden_path.suffix}.") diff --git a/tests/unit/static/golden/test_otel_instrument__test_instrument_decorator.csv b/tests/unit/static/golden/test_otel_instrument__test_instrument_decorator.csv new file mode 100644 index 000000000..40837d509 --- /dev/null +++ b/tests/unit/static/golden/test_otel_instrument__test_instrument_decorator.csv @@ -0,0 +1,9 @@ +,record,event_id,record_attributes,record_type,resource_attributes,start_timestamp,timestamp,trace +0,"{'name': 'respond_to_query', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '', 'status': 'STATUS_CODE_UNSET'}",7870250274962447839,{},EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.387607,2024-12-22 11:20:26.392174,"{'trace_id': '113376089399064103615948241236196474059', 'parent_id': '', 'span_id': '7870250274962447839'}" +1,"{'name': 'nested', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '7870250274962447839', 'status': 'STATUS_CODE_UNSET'}",8819384298151247754,{'trulens.nested_attr1': 'value1'},EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.387652,2024-12-22 11:20:26.391272,"{'trace_id': '113376089399064103615948241236196474059', 'parent_id': '7870250274962447839', 'span_id': '8819384298151247754'}" +2,"{'name': 'nested2', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '8819384298151247754', 'status': 
'STATUS_CODE_UNSET'}",2622992513876904334,"{'trulens.nested2_ret': 'nested2: nested3', 'trulens.nested2_args[1]': 'test'}",EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.387679,2024-12-22 11:20:26.389939,"{'trace_id': '113376089399064103615948241236196474059', 'parent_id': '8819384298151247754', 'span_id': '2622992513876904334'}" +3,"{'name': 'nested3', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '2622992513876904334', 'status': 'STATUS_CODE_UNSET'}",11864485227397090485,"{'trulens.nested3_ret': 'nested3', 'trulens.special.nested3_ret': 'nested3', 'trulens.selector_name': 'special', 'trulens.cows': 'moo', 'trulens.special.cows': 'moo'}",EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.387705,2024-12-22 11:20:26.387762,"{'trace_id': '113376089399064103615948241236196474059', 'parent_id': '2622992513876904334', 'span_id': '11864485227397090485'}" +4,"{'name': 'respond_to_query', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '', 'status': 'STATUS_CODE_UNSET'}",10786111609955477438,{},EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.393563,2024-12-22 11:20:26.397446,"{'trace_id': '214293944471171141309178747794638512671', 'parent_id': '', 'span_id': '10786111609955477438'}" +5,"{'name': 'nested', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '10786111609955477438', 'status': 'STATUS_CODE_UNSET'}",7881765616183808794,{'trulens.nested_attr1': 'value1'},EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.393586,2024-12-22 
11:20:26.396613,"{'trace_id': '214293944471171141309178747794638512671', 'parent_id': '10786111609955477438', 'span_id': '7881765616183808794'}" +6,"{'name': 'nested2', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '7881765616183808794', 'status': 'STATUS_CODE_UNSET'}",4318803655649897130,"{'trulens.nested2_ret': 'nested2: ', 'trulens.nested2_args[1]': 'throw'}",EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.393603,2024-12-22 11:20:26.395227,"{'trace_id': '214293944471171141309178747794638512671', 'parent_id': '7881765616183808794', 'span_id': '4318803655649897130'}" +7,"{'name': 'nested3', 'kind': 'SPAN_KIND_TRULENS', 'parent_span_id': '4318803655649897130', 'status': 'STATUS_CODE_ERROR'}",11457830288984624191,"{'trulens.nested3_ex': ['nested3 exception'], 'trulens.special.nested3_ex': ['nested3 exception'], 'trulens.selector_name': 'special', 'trulens.cows': 'moo', 'trulens.special.cows': 'moo'}",EventRecordType.SPAN,"{'telemetry.sdk.language': 'python', 'telemetry.sdk.name': 'opentelemetry', 'telemetry.sdk.version': '1.28.2', 'service.name': 'trulens'}",2024-12-22 11:20:26.393630,2024-12-22 11:20:26.394348,"{'trace_id': '214293944471171141309178747794638512671', 'parent_id': '4318803655649897130', 'span_id': '11457830288984624191'}" diff --git a/tests/unit/test_otel_instrument.py b/tests/unit/test_otel_instrument.py new file mode 100644 index 000000000..cc6b37b21 --- /dev/null +++ b/tests/unit/test_otel_instrument.py @@ -0,0 +1,143 @@ +""" +Tests for OTEL instrument decorator. 
def _nested2_attributes(ret, exception, *args, **kwargs):
    """Span attributes for `_TestApp.nested2`: its result and its query."""
    # args[0] is the bound `self`, so args[1] is the `query` argument.
    return {
        "nested2_ret": ret,
        "nested2_args[1]": args[1],
    }


def _nested3_attributes(ret, exception, *args, **kwargs):
    """Span attributes for `_TestApp.nested3`, including a selector name."""
    failure_args = exception.args if exception else None
    return {
        "nested3_ex": failure_args,
        "nested3_ret": ret,
        "selector_name": "special",
        "cows": "moo",
    }


class _TestApp:
    """Toy app with a four-deep chain of instrumented calls.

    `respond_to_query` -> `nested` -> `nested2` -> `nested3`; the innermost
    call raises when the query is "throw" and `nested2` swallows that error.
    """

    @instrument()
    def respond_to_query(self, query: str) -> str:
        """Entry point; instrumented without any attributes."""
        inner = self.nested(query)
        return f"answer: {inner}"

    @instrument(attributes={"nested_attr1": "value1"})
    def nested(self, query: str) -> str:
        """Instrumented with a static attribute dictionary."""
        inner = self.nested2(query)
        return f"nested: {inner}"

    @instrument(attributes=_nested2_attributes)
    def nested2(self, query: str) -> str:
        """Instrumented with a callable; tolerates a failing `nested3`."""
        nested_result = ""

        try:
            nested_result = self.nested3(query)
        except Exception:
            # Deliberately ignored: the result stays empty when nested3 raises.
            pass

        return f"nested2: {nested_result}"

    @instrument(attributes=_nested3_attributes)
    def nested3(self, query: str) -> str:
        """Innermost call; raises `ValueError` for the query "throw"."""
        if query == "throw":
            raise ValueError("nested3 exception")
        return "nested3"
@staticmethod
def _convert_column_types(df: pd.DataFrame) -> None:
    """Restore, in place, the column types lost by a CSV round trip.

    Writing the events to CSV and reading them back turns every value into
    text (or a parsed number), so each column is converted back here.

    Args:
        df: Dataframe loaded from the golden CSV file; modified in place.
    """
    # Local import: this helper is the only user of `ast` in the module.
    import ast

    enum_prefix = "EventRecordType."

    def _to_record_type(text: str) -> EventRecordType:
        # The CSV may hold either "EventRecordType.X" or the bare value.
        if text.startswith(enum_prefix):
            text = text[len(enum_prefix) :]
        return EventRecordType(text)

    df["event_id"] = df["event_id"].apply(str)
    df["record_type"] = df["record_type"].apply(_to_record_type)
    df["start_timestamp"] = df["start_timestamp"].apply(pd.Timestamp)
    df["timestamp"] = df["timestamp"].apply(pd.Timestamp)
    for json_column in (
        "record",
        "record_attributes",
        "resource_attributes",
        "trace",
    ):
        # These cells are Python literal reprs of dictionaries.
        # `ast.literal_eval` parses literals only, so, unlike `eval`, a
        # corrupted or tampered golden file cannot execute code.
        df[json_column] = df[json_column].apply(ast.literal_eval)
def compare_dfs_accounting_for_ids_and_timestamps(
    test_case: TestCase,
    expected: pd.DataFrame,
    actual: pd.DataFrame,
    ignore_locators: Optional[Sequence[str]] = None,
) -> None:
    """
    Compare two Dataframes are equal, accounting for ids and timestamps. That
    is:
    1. The ids between the two Dataframes may be different, but they have to be
       consistent. That is, if one Dataframe reuses an id in two places, then
       the other must as well.
    2. The timestamps between the two Dataframes may be different, but they
       have to be in the same order.

    A value is treated as an id when its column name, or its key inside a
    dictionary cell, is a string ending in "_id". Dictionary cells are
    compared recursively.

    Args:
        test_case: unittest.TestCase instance to use for assertions
        expected: expected results
        actual: actual results
        ignore_locators: locators to ignore when comparing the Dataframes. A
            locator has the form "df.iloc[<row>][<column>]", followed by one
            "[<key>]" per level of dictionary nesting. `None` (the default)
            ignores nothing.

    Raises:
        AssertionError: raised through `test_case` when the Dataframes differ.
    """
    id_mapping: Dict[str, str] = {}
    timestamp_mapping: Dict[pd.Timestamp, pd.Timestamp] = {}
    test_case.assertEqual(len(expected), len(actual))
    test_case.assertListEqual(list(expected.columns), list(actual.columns))
    for i in range(len(expected)):
        # Materialize each row once rather than once per cell.
        expected_row = expected.iloc[i]
        actual_row = actual.iloc[i]
        for col in expected.columns:
            _compare_entity(
                test_case,
                expected_row[col],
                actual_row[col],
                id_mapping,
                timestamp_mapping,
                # Column labels are not necessarily strings (e.g. a default
                # integer index), so guard before calling `endswith`.
                is_id=isinstance(col, str) and col.endswith("_id"),
                locator=f"df.iloc[{i}][{col}]",
                ignore_locators=ignore_locators,
            )
    # Ensure that the id mapping is a bijection.
    test_case.assertEqual(
        len(set(id_mapping.values())),
        len(id_mapping),
        "Ids are not a bijection!",
    )
    # Ensure that the timestamp mapping is strictly increasing.
    mapped_in_expected_order = [
        timestamp_mapping[key] for key in sorted(timestamp_mapping)
    ]
    for earlier, later in zip(
        mapped_in_expected_order, mapped_in_expected_order[1:]
    ):
        test_case.assertLess(
            earlier,
            later,
            "Timestamps are not in the same order!",
        )


def _compare_entity(
    test_case: TestCase,
    expected: Any,
    actual: Any,
    id_mapping: Dict[str, str],
    timestamp_mapping: Dict[pd.Timestamp, pd.Timestamp],
    is_id: bool,
    locator: str,
    ignore_locators: Optional[Sequence[str]],
) -> None:
    """Compare one cell (or nested dictionary value) of the two Dataframes.

    Ids and timestamps are not compared literally. The first time an expected
    id or timestamp is seen, its actual counterpart is recorded in
    `id_mapping` / `timestamp_mapping`; later occurrences must map to the
    same counterpart. Dictionaries are compared key by key, recursively. Any
    other value must be equal.

    Args:
        test_case: unittest.TestCase instance to use for assertions
        expected: expected value
        actual: actual value
        id_mapping: expected id -> actual id seen so far; updated in place
        timestamp_mapping: expected timestamp -> actual timestamp seen so
            far; updated in place
        is_id: whether the value is an id
        locator: description of where the value sits, used in failure
            messages and matched against `ignore_locators`
        ignore_locators: locators whose values are skipped entirely
    """
    if ignore_locators and locator in ignore_locators:
        return
    test_case.assertEqual(
        type(expected), type(actual), f"Types of {locator} do not match!"
    )
    if is_id:
        test_case.assertEqual(
            type(expected), str, f"Type of id {locator} is not a string!"
        )
        # `setdefault` records the pairing on first sight and otherwise
        # returns the previously recorded counterpart.
        test_case.assertEqual(
            id_mapping.setdefault(expected, actual),
            actual,
            f"Ids of {locator} are not consistent!",
        )
    elif isinstance(expected, dict):
        test_case.assertEqual(
            expected.keys(),
            actual.keys(),
            f"Keys of {locator} do not match!",
        )
        for k, expected_value in expected.items():
            _compare_entity(
                test_case,
                expected_value,
                actual[k],
                id_mapping,
                timestamp_mapping,
                # Dictionary keys may be non-strings; those are never ids.
                is_id=isinstance(k, str) and k.endswith("_id"),
                locator=f"{locator}[{k}]",
                ignore_locators=ignore_locators,
            )
    elif isinstance(expected, pd.Timestamp):
        test_case.assertEqual(
            timestamp_mapping.setdefault(expected, actual),
            actual,
            f"Timestamps of {locator} are not consistent!",
        )
    else:
        test_case.assertEqual(expected, actual, f"{locator} does not match!")