diff --git a/src/datachain/query/dataset.py b/src/datachain/query/dataset.py index 8991e73f7..f942b1695 100644 --- a/src/datachain/query/dataset.py +++ b/src/datachain/query/dataset.py @@ -1238,30 +1238,27 @@ def _describe(cols: set[str], side: str) -> str: return cls(f"Cannot perform union. {'. '.join(parts)}") -def _validate_columns( - left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement] -) -> set[str]: - left_names = {c.name for c in left_columns} - right_names = {c.name for c in right_columns} - - if left_names == right_names: - return left_names - - raise UnionSchemaMismatchError.from_column_sets( - left_names - right_names, - right_names - left_names, - ) - - def _order_columns( left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement] ) -> list[list[ColumnElement]]: - column_order = _validate_columns(left_columns, right_columns) + left_names = [c.name for c in left_columns] + right_names = [c.name for c in right_columns] + + # validate + if sorted(left_names) != sorted(right_names): + left_names_set = set(left_names) + right_names_set = set(right_names) + raise UnionSchemaMismatchError.from_column_sets( + left_names_set - right_names_set, + right_names_set - left_names_set, + ) + + # Order columns to match left_names order column_dicts = [ {c.name: c for c in columns} for columns in [left_columns, right_columns] ] - return [[d[n] for n in column_order] for d in column_dicts] + return [[d[n] for n in left_names] for d in column_dicts] def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]: diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 2998a2898..1527a57ec 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -4555,3 +4555,53 @@ class Signal2(DataModel): assert chain.max("signals.signal.i3") == 15 assert chain.max("signals.signal.f3") == 7.5 assert chain.max("signals.signal.s3") == "eee" + + +def test_union_does_not_break_schema_order(test_session): + class Meta(BaseModel): + name: str + count: int + + def add_file(key) -> File: + return File(path="") + + def add_meta(file) -> Meta: + return Meta(name="meta", count=10) + + keys = ["a", "b", "c", "d"] + values = [3, 3, 3, 3] + + ( + dc.read_values(key=keys, val=values, session=test_session) + .map(file=add_file) + .map(meta=add_meta) + .save("ds1") + ) + ( + dc.read_values(key=keys, val=values, session=test_session) + .map(file=add_file) + .map(meta=add_meta) + .save("ds2") + ) + + ( + dc.read_dataset("ds1", session=test_session) + .union(dc.read_dataset("ds2", session=test_session)) + .save("union") + ) + + dat = test_session.catalog.get_dataset("union") + assert list(dat.versions[0].schema.keys()) == [ + "key", + "val", + "file__source", + "file__path", + "file__size", + "file__version", + "file__etag", + "file__is_latest", + "file__last_modified", + "file__location", + "meta__name", + "meta__count", + ]