Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 9 additions & 6 deletions src/datachain/query/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,15 +1170,18 @@ def apply_sql_clause(self, query) -> Select:

def _validate_columns(
left_columns: Iterable[ColumnElement], right_columns: Iterable[ColumnElement]
) -> set[str]:
left_names = {c.name for c in left_columns}
right_names = {c.name for c in right_columns}
) -> list[str]:
left_names = [c.name for c in left_columns]
right_names = [c.name for c in right_columns]

if left_names == right_names:
if sorted(left_names) == sorted(right_names):
Copy link
Contributor

@dmpetrov dmpetrov Oct 13, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is supposed to be much more strict - you cannot union if there is any mismatch.

If we try to be smart here, it opens a can of worms: someone will always be unhappy with the results.

More details:

  1. Number of columns. Must have.
  2. Types of columns (some exception: it's ok to make it nullable or convert int to float).
  3. Names of columns.

In SQL, they require only (1) and (2) while ignoring (3).

We might have issues with (2) - SQLAlchemy is not good at types. So, for us it's better to use (3) in addition to (1) and ignore types (2). The comparison should be a plain == on the column names, without any sorting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So you are saying that we should not sort in order to not mix columns that are named the same but actually have different types? If it's only a matter of removing sorting then I will add it here, but if it's more complex I would add another PR / issue for this since it's not actually related to this PR (before my change we were comparing sets of column names which is pretty much the same)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of tests fail after removing the sorting... Let's do this in a separate issue, as it seems non-trivial and, as mentioned above, it's not really related to this PR anyway.

return left_names

missing_right = left_names - right_names
missing_left = right_names - left_names
left_names_set = set(left_names)
right_names_set = set(right_names)

missing_right = left_names_set - right_names_set
missing_left = right_names_set - left_names_set

def _prepare_msg_part(missing_columns: set[str], side: str) -> str:
return f"{', '.join(sorted(missing_columns))} only present in {side}"
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4452,3 +4452,55 @@ class Signal2(DataModel):
assert chain.max("signals.signal.i3") == 15
assert chain.max("signals.signal.f3") == 7.5
assert chain.max("signals.signal.s3") == "eee"


def test_union_does_not_break_schema_order(test_session):
class Meta(BaseModel):
name: str
count: int

def add_file(key) -> File:
return File(path="")

def add_meta(file) -> Meta:
return Meta(name="meta", count=10)
Comment on lines +4560 to +4569
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Test covers the main schema order issue but does not check for edge cases like mismatched schemas or extra/missing columns.

Add tests for mismatched schemas and extra or missing columns to ensure the union operation handles these edge cases correctly.

Suggested implementation:

def test_union_does_not_break_schema_order(test_session):
    class Meta(BaseModel):
        name: str
        count: int

    def add_file(key) -> File:
        return File(path="")

    def add_meta(file) -> Meta:
        return Meta(name="meta", count=10)

def test_union_with_mismatched_schemas(test_session):
    """Placeholder check for combining records whose schemas differ.

    ``MetaA`` and ``MetaB`` share the ``name`` field but differ in their
    second field (``count`` vs ``value``). The "union" below is only a list
    holding one instance of each model; no real union operation is executed
    yet, so ``test_session`` is currently unused.
    """

    class MetaA(BaseModel):
        name: str
        count: int

    class MetaB(BaseModel):
        name: str
        value: float

    meta_a = MetaA(name="metaA", count=5)
    meta_b = MetaB(name="metaB", value=3.14)

    # Simulate union operation
    try:
        result = [meta_a, meta_b]  # Replace with actual union logic if available
    except Exception as e:
        # The exception type raised by a real union is not known from here,
        # so the broad catch is kept; only the message is inspected.
        assert "schema" in str(e).lower()
    else:
        # Kept outside the ``try`` so that a failing check surfaces as its own
        # AssertionError instead of being caught by the handler above.
        # Check that mismatched schemas are handled (e.g., raise error or skip)
        assert not (hasattr(result[0], "value") and hasattr(result[1], "count"))

def test_union_with_extra_columns(test_session):
    """Placeholder check for one side carrying an additional field.

    Builds a ``MetaBase`` (``name`` only) and a ``MetaExtra`` (``name`` plus
    ``extra``), puts them in a list standing in for a union, and verifies that
    only the second element exposes ``extra``. ``test_session`` is unused.
    """

    class MetaBase(BaseModel):
        name: str

    class MetaExtra(BaseModel):
        name: str
        extra: int

    # Simulate union operation
    # Replace with actual union logic if available
    combined = [MetaBase(name="base"), MetaExtra(name="extra", extra=42)]
    plain, extended = combined

    # Check that extra columns do not break the union
    assert hasattr(extended, "extra")
    assert not hasattr(plain, "extra")

def test_union_with_missing_columns(test_session):
    """Placeholder check for one side lacking a field the other has.

    Builds a ``MetaFull`` (``name`` and ``count``) and a ``MetaMissing``
    (``name`` only), puts them in a list standing in for a union, and verifies
    that only the first element exposes ``count``. ``test_session`` is unused.
    """

    class MetaFull(BaseModel):
        name: str
        count: int

    class MetaMissing(BaseModel):
        name: str

    # Simulate union operation
    # Replace with actual union logic if available
    combined = [MetaFull(name="full", count=10), MetaMissing(name="missing")]
    complete, partial = combined

    # Check that missing columns are handled gracefully
    assert hasattr(complete, "count")
    assert not hasattr(partial, "count")

If your codebase has a specific union operation or function, replace the list concatenation [meta_a, meta_b] etc. with the actual union logic to ensure the tests are meaningful. You may also want to check for specific exceptions or error messages if your union implementation raises them for schema mismatches.


keys = ["a", "b", "c", "d"]
values = [3, 3, 3, 3]

(
dc.read_values(key=keys, val=values, session=test_session)
.map(file=add_file)
.map(meta=add_meta)
.save("ds1")
)
(
dc.read_values(key=keys, val=values, session=test_session)
.map(file=add_file)
.map(meta=add_meta)
.save("ds2")
)

(
dc.read_dataset("ds1", session=test_session)
.union(dc.read_dataset("ds2", session=test_session))
.save("union")
)

dat = test_session.catalog.get_dataset("union")
assert list(dat.versions[0].schema.keys()) == [
"key",
"val",
"file__source",
"file__path",
"file__size",
"file__version",
"file__etag",
"file__is_latest",
"file__last_modified",
"file__location",
"meta__name",
"meta__count",
"sys__id",
"sys__rand",
]