Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 14 additions & 17 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1238,30 +1238,27 @@ def _describe(cols: set[str], side: str) -> str:
return cls(f"Cannot perform union. {'. '.join(parts)}")


def _validate_columns(
left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
) -> set[str]:
left_names = {c.name for c in left_columns}
right_names = {c.name for c in right_columns}

if left_names == right_names:
return left_names

raise UnionSchemaMismatchError.from_column_sets(
left_names - right_names,
right_names - left_names,
)


def _order_columns(
left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
) -> list[list[ColumnElement]]:
column_order = _validate_columns(left_columns, right_columns)
left_names = [c.name for c in left_columns]
right_names = [c.name for c in right_columns]

# validate
if sorted(left_names) != sorted(right_names):
left_names_set = set(left_names)
right_names_set = set(right_names)
raise UnionSchemaMismatchError.from_column_sets(
left_names_set - right_names_set,
right_names_set - left_names_set,
)

# Order columns to match left_names order
column_dicts = [
{c.name: c for c in columns} for columns in [left_columns, right_columns]
]

return [[d[n] for n in column_order] for d in column_dicts]
return [[d[n] for n in left_names] for d in column_dicts]


def _drop_system_columns(columns: Iterable[ColumnElement]) -> list[ColumnElement]:
Expand Down
50 changes: 50 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4555,3 +4555,53 @@ class Signal2(DataModel):
assert chain.max("signals.signal.i3") == 15
assert chain.max("signals.signal.f3") == 7.5
assert chain.max("signals.signal.s3") == "eee"


def test_union_does_not_break_schema_order(test_session):
class Meta(BaseModel):
name: str
count: int

def add_file(key) -> File:
return File(path="")

def add_meta(file) -> Meta:
return Meta(name="meta", count=10)
Comment on lines +4560 to +4569
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test covers the main schema order issue but does not check for edge cases like mismatched schemas or extra/missing columns.

Add tests for mismatched schemas and extra or missing columns to ensure the union operation handles these edge cases correctly.

Suggested implementation:

def test_union_does_not_break_schema_order(test_session):
    class Meta(BaseModel):
        name: str
        count: int

    def add_file(key) -> File:
        return File(path="")

    def add_meta(file) -> Meta:
        return Meta(name="meta", count=10)

def test_union_with_mismatched_schemas(test_session):
    class MetaA(BaseModel):
        name: str
        count: int

    class MetaB(BaseModel):
        name: str
        value: float

    meta_a = MetaA(name="metaA", count=5)
    meta_b = MetaB(name="metaB", value=3.14)

    # Simulate union operation
    try:
        result = [meta_a, meta_b]  # Replace with actual union logic if available
        # Check that mismatched schemas are handled (e.g., raise error or skip)
        assert not (hasattr(result[0], "value") and hasattr(result[1], "count"))
    except Exception as e:
        assert "schema" in str(e).lower()

def test_union_with_extra_columns(test_session):
    class MetaBase(BaseModel):
        name: str

    class MetaExtra(BaseModel):
        name: str
        extra: int

    meta_base = MetaBase(name="base")
    meta_extra = MetaExtra(name="extra", extra=42)

    # Simulate union operation
    result = [meta_base, meta_extra]  # Replace with actual union logic if available
    # Check that extra columns do not break the union
    assert hasattr(result[1], "extra")
    assert not hasattr(result[0], "extra")

def test_union_with_missing_columns(test_session):
    class MetaFull(BaseModel):
        name: str
        count: int

    class MetaMissing(BaseModel):
        name: str

    meta_full = MetaFull(name="full", count=10)
    meta_missing = MetaMissing(name="missing")

    # Simulate union operation
    result = [meta_full, meta_missing]  # Replace with actual union logic if available
    # Check that missing columns are handled gracefully
    assert hasattr(result[0], "count")
    assert not hasattr(result[1], "count")

If your codebase has a specific union operation or function, replace the list concatenation [meta_a, meta_b] etc. with the actual union logic to ensure the tests are meaningful. You may also want to check for specific exceptions or error messages if your union implementation raises them for schema mismatches.


keys = ["a", "b", "c", "d"]
values = [3, 3, 3, 3]

(
dc.read_values(key=keys, val=values, session=test_session)
.map(file=add_file)
.map(meta=add_meta)
.save("ds1")
)
(
dc.read_values(key=keys, val=values, session=test_session)
.map(file=add_file)
.map(meta=add_meta)
.save("ds2")
)

(
dc.read_dataset("ds1", session=test_session)
.union(dc.read_dataset("ds2", session=test_session))
.save("union")
)

dat = test_session.catalog.get_dataset("union")
assert list(dat.versions[0].schema.keys()) == [
"key",
"val",
"file__source",
"file__path",
"file__size",
"file__version",
"file__etag",
"file__is_latest",
"file__last_modified",
"file__location",
"meta__name",
"meta__count",
]