Skip to content
230 changes: 230 additions & 0 deletions core/src/main/java/io/substrait/expression/MaskExpression.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package io.substrait.expression;

import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;

/**
* A mask expression that selectively removes fields from complex types (struct, list, map).
*
* <p>This corresponds to the {@code Expression.MaskExpression} message in the Substrait protobuf
* specification. It is used in {@code ReadRel} to describe column projection — the subset of a
* relation's schema that should actually be read.
*
* @see <a href="https://substrait.io/expressions/field_references/">Substrait Field References</a>
*/
@Value.Enclosing
public interface MaskExpression {

// ---------------------------------------------------------------------------
// Top-level MaskExpression value
// ---------------------------------------------------------------------------

/**
 * Concrete mask expression value: the top-level struct selection together with the
 * singular-struct option.
 */
@Value.Immutable
interface MaskExpr {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.MaskExpr.builder()}
   */
  static ImmutableMaskExpression.MaskExpr.Builder builder() {
    final ImmutableMaskExpression.MaskExpr.Builder created =
        ImmutableMaskExpression.MaskExpr.builder();
    return created;
  }

  /**
   * Top-level struct selection.
   *
   * @return the selection describing which fields to include
   */
  StructSelect getSelect();

  /**
   * Whether a struct with only a single selected field keeps its struct wrapper. When {@code
   * true}, such a struct is <em>not</em> unwrapped into its child type.
   *
   * @return {@code false} unless set explicitly
   */
  @Value.Default
  default boolean getMaintainSingularStruct() {
    final boolean keepSingularStruct = false;
    return keepSingularStruct;
  }
}
Comment on lines +41 to +46
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also merge this into the top-level MaskExpression interface, which would be similar to what we are doing in Expression.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I merged MaskExpr into the MaskExpression interface.

I also synced this change with the other changes I've made in this PR.


// ---------------------------------------------------------------------------
// Select – a union of StructSelect | ListSelect | MapSelect
// ---------------------------------------------------------------------------

/**
 * A selection on a complex type, modelled as a union of {@link StructSelect}, {@link ListSelect}
 * and {@link MapSelect}.
 *
 * <p>Exactly one of the three alternatives must be set. The invariant is verified by {@link
 * #check()}, which is annotated with {@code @Value.Check} so that the generated immutable
 * implementation runs it when an instance is created.
 */
@Value.Immutable
interface Select {
  /**
   * The struct alternative of the union.
   *
   * @return the struct selection, or empty when this is a list or map selection
   */
  Optional<StructSelect> getStruct();

  /**
   * The list alternative of the union.
   *
   * @return the list selection, or empty when this is a struct or map selection
   */
  Optional<ListSelect> getList();

  /**
   * The map alternative of the union.
   *
   * @return the map selection, or empty when this is a struct or list selection
   */
  Optional<MapSelect> getMap();

  /**
   * Validates the union invariant: exactly one of struct, list or map is present.
   *
   * @throws IllegalStateException if none or more than one alternative is set
   */
  @Value.Check
  default void check() {
    int alternatives = 0;
    if (getStruct().isPresent()) {
      alternatives++;
    }
    if (getList().isPresent()) {
      alternatives++;
    }
    if (getMap().isPresent()) {
      alternatives++;
    }
    if (alternatives != 1) {
      throw new IllegalStateException(
          "Select requires exactly one of struct, list or map to be set, but "
              + alternatives
              + " were set");
    }
  }

  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.Select.builder()}
   */
  static ImmutableMaskExpression.Select.Builder builder() {
    return ImmutableMaskExpression.Select.builder();
  }

  /**
   * Creates a selection holding only the struct alternative.
   *
   * @param structSelect the struct selection to wrap
   * @return a {@code Select} whose struct alternative is set
   */
  static Select ofStruct(StructSelect structSelect) {
    return builder().struct(structSelect).build();
  }

  /**
   * Creates a selection holding only the list alternative.
   *
   * @param listSelect the list selection to wrap
   * @return a {@code Select} whose list alternative is set
   */
  static Select ofList(ListSelect listSelect) {
    return builder().list(listSelect).build();
  }

  /**
   * Creates a selection holding only the map alternative.
   *
   * @param mapSelect the map selection to wrap
   * @return a {@code Select} whose map alternative is set
   */
  static Select ofMap(MapSelect mapSelect) {
    return builder().map(mapSelect).build();
  }
}
Comment on lines +53 to +56
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to be mirroring the generated proto Java classes. What if we would turn this into an empty base interface without the ofStruct, ofList, ofMap methods. Then StructSelect, ListSelect and MapSelect can extend that interface and you can use a visitor pattern to process the different types. This would fit more with the existing code like e.g. Expression with its ExpressionVisitor.

Suggested change
interface Select {
Optional<StructSelect> getStruct();
Optional<ListSelect> getList();
Optional<MapSelect> getMap();
static ImmutableMaskExpression.Select.Builder builder() {
return ImmutableMaskExpression.Select.builder();
}
static Select ofStruct(StructSelect structSelect) {
return builder().struct(structSelect).build();
}
static Select ofList(ListSelect listSelect) {
return builder().list(listSelect).build();
}
static Select ofMap(MapSelect mapSelect) {
return builder().map(mapSelect).build();
}
}
interface Select {
/**
* Accepts a visitor to traverse this mask expression.
*
* @param <R> the return type of the visitor
* @param <C> the context type
* @param <E> the exception type that may be thrown
* @param visitor the visitor to accept
* @param context the visitation context
* @return the result of the visitation
* @throws E if an error occurs during visitation
*/
<R, C extends VisitationContext, E extends Throwable> R accept(
MaskExpressionVisitor<R, C, E> visitor, C context) throws E;
}
interface StructSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

interface ListSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

interface MapSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

I refactored Select into a base interface with an accept() method for visiting. StructSelect, ListSelect, and MapSelect now extend it directly.

I also added a MaskExpressionVisitor interface, following the same project code pattern, as you suggested.


// ---------------------------------------------------------------------------
// Struct selection
// ---------------------------------------------------------------------------

/** Struct-level selection: names the subset of a struct type's fields that is kept. */
@Value.Immutable
interface StructSelect {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.StructSelect.builder()}
   */
  static ImmutableMaskExpression.StructSelect.Builder builder() {
    final ImmutableMaskExpression.StructSelect.Builder created =
        ImmutableMaskExpression.StructSelect.builder();
    return created;
  }

  /**
   * The selected struct fields.
   *
   * @return one {@link StructItem} per selected field
   */
  List<StructItem> getStructItems();
}

/** One selected struct field, optionally refined by a nested selection on that field. */
@Value.Immutable
interface StructItem {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.StructItem.builder()}
   */
  static ImmutableMaskExpression.StructItem.Builder builder() {
    return ImmutableMaskExpression.StructItem.builder();
  }

  /**
   * Selects a field and applies a nested selection to it.
   *
   * @param field zero-based field index within the struct
   * @param child selection applied to the field's (complex) type
   * @return the corresponding item
   */
  static StructItem of(int field, Select child) {
    final ImmutableMaskExpression.StructItem.Builder pending = builder().field(field);
    return pending.child(child).build();
  }

  /**
   * Selects a field as a whole, without a nested selection.
   *
   * @param field zero-based field index within the struct
   * @return the corresponding item
   */
  static StructItem of(int field) {
    final ImmutableMaskExpression.StructItem.Builder pending = builder();
    pending.field(field);
    return pending.build();
  }

  /**
   * Zero-based field index within the struct.
   *
   * @return the index of the selected field
   */
  int getField();

  /**
   * Optional child selection for nested complex types.
   *
   * @return the nested selection, or empty when the field is taken as a whole
   */
  Optional<Select> getChild();
}

// ---------------------------------------------------------------------------
// List selection
// ---------------------------------------------------------------------------

/** List-level selection: picks elements of a list type by index or by slice. */
@Value.Immutable
interface ListSelect {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.ListSelect.builder()}
   */
  static ImmutableMaskExpression.ListSelect.Builder builder() {
    final ImmutableMaskExpression.ListSelect.Builder created =
        ImmutableMaskExpression.ListSelect.builder();
    return created;
  }

  /**
   * The individual element / slice selections.
   *
   * @return the selection items of this list selection
   */
  List<ListSelectItem> getSelection();

  /**
   * Optional child selection applied to each selected element.
   *
   * @return the nested selection, or empty when elements are taken as a whole
   */
  Optional<Select> getChild();
}

/**
 * A single selection within a list, modelled as a union of {@link ListElement} and {@link
 * ListSlice}.
 *
 * <p>Exactly one of the two alternatives must be set. The invariant is verified by {@link
 * #check()}, which is annotated with {@code @Value.Check} so that the generated immutable
 * implementation runs it when an instance is created.
 */
@Value.Immutable
interface ListSelectItem {
  /**
   * The single-element alternative of the union.
   *
   * @return the element selection, or empty when this item is a slice
   */
  Optional<ListElement> getItem();

  /**
   * The slice alternative of the union.
   *
   * @return the slice selection, or empty when this item is a single element
   */
  Optional<ListSlice> getSlice();

  /**
   * Validates the union invariant: either an element or a slice is present, never both or
   * neither.
   *
   * @throws IllegalStateException if both or neither alternative is set
   */
  @Value.Check
  default void check() {
    // Both present or both absent are the two invalid states.
    if (getItem().isPresent() == getSlice().isPresent()) {
      throw new IllegalStateException(
          "ListSelectItem requires exactly one of item or slice to be set");
    }
  }

  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.ListSelectItem.builder()}
   */
  static ImmutableMaskExpression.ListSelectItem.Builder builder() {
    return ImmutableMaskExpression.ListSelectItem.builder();
  }

  /**
   * Creates an item selecting a single list element.
   *
   * @param element the element selection to wrap
   * @return a {@code ListSelectItem} whose item alternative is set
   */
  static ListSelectItem ofItem(ListElement element) {
    return builder().item(element).build();
  }

  /**
   * Creates an item selecting a slice of the list.
   *
   * @param slice the slice selection to wrap
   * @return a {@code ListSelectItem} whose slice alternative is set
   */
  static ListSelectItem ofSlice(ListSlice slice) {
    return builder().slice(slice).build();
  }
}

/** Picks one element of a list, addressed by its zero-based index. */
@Value.Immutable
interface ListElement {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.ListElement.builder()}
   */
  static ImmutableMaskExpression.ListElement.Builder builder() {
    return ImmutableMaskExpression.ListElement.builder();
  }

  /**
   * Creates an element selection for the given index.
   *
   * @param field index of the element to select
   * @return the corresponding element selection
   */
  static ListElement of(int field) {
    final ImmutableMaskExpression.ListElement.Builder pending = builder();
    pending.field(field);
    return pending.build();
  }

  /**
   * Index of the selected element.
   *
   * @return the element index
   */
  int getField();
}

/** Picks a contiguous range of list elements, described by a start and an end position. */
@Value.Immutable
interface ListSlice {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.ListSlice.builder()}
   */
  static ImmutableMaskExpression.ListSlice.Builder builder() {
    return ImmutableMaskExpression.ListSlice.builder();
  }

  /**
   * Creates a slice from the given bounds.
   *
   * @param start value stored as the slice start
   * @param end value stored as the slice end
   * @return the corresponding slice
   */
  static ListSlice of(int start, int end) {
    final ImmutableMaskExpression.ListSlice.Builder pending = builder().start(start);
    return pending.end(end).build();
  }

  /**
   * Start of the range.
   *
   * @return the start position
   */
  int getStart();

  /**
   * End of the range.
   *
   * @return the end position
   */
  int getEnd();
}

// ---------------------------------------------------------------------------
// Map selection
// ---------------------------------------------------------------------------

/** Map-level selection: picks map entries either by an exact key or by a key expression. */
@Value.Immutable
interface MapSelect {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.MapSelect.builder()}
   */
  static ImmutableMaskExpression.MapSelect.Builder builder() {
    return ImmutableMaskExpression.MapSelect.builder();
  }

  /**
   * Creates a map selection driven by a key expression.
   *
   * @param expression the key expression to select by
   * @return a {@code MapSelect} with only the expression set
   */
  static MapSelect ofExpression(MapKeyExpression expression) {
    final ImmutableMaskExpression.MapSelect.Builder pending = builder();
    pending.expression(expression);
    return pending.build();
  }

  /**
   * Creates a map selection driven by an exact key.
   *
   * @param key the key to select by
   * @return a {@code MapSelect} with only the key set
   */
  static MapSelect ofKey(MapKey key) {
    final ImmutableMaskExpression.MapSelect.Builder pending = builder();
    pending.key(key);
    return pending.build();
  }

  /**
   * Exact-key form of the selection.
   *
   * @return the key, or empty when not set
   */
  Optional<MapKey> getKey();

  /**
   * Key-expression form of the selection.
   *
   * @return the key expression, or empty when not set
   */
  Optional<MapKeyExpression> getExpression();

  /**
   * Optional child selection applied to each selected map value.
   *
   * @return the nested selection, or empty when values are taken as a whole
   */
  Optional<Select> getChild();
}

/** Picks a map entry whose key matches exactly. */
@Value.Immutable
interface MapKey {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.MapKey.builder()}
   */
  static ImmutableMaskExpression.MapKey.Builder builder() {
    return ImmutableMaskExpression.MapKey.builder();
  }

  /**
   * Creates an exact-key selection.
   *
   * @param mapKey the key to match
   * @return the corresponding key selection
   */
  static MapKey of(String mapKey) {
    final ImmutableMaskExpression.MapKey.Builder pending = builder();
    pending.mapKey(mapKey);
    return pending.build();
  }

  /**
   * The key to match.
   *
   * @return the map key
   */
  String getMapKey();
}

/** Picks map entries whose keys match a wildcard key expression. */
@Value.Immutable
interface MapKeyExpression {
  /**
   * Entry point for constructing instances.
   *
   * @return the builder handed out by {@code ImmutableMaskExpression.MapKeyExpression.builder()}
   */
  static ImmutableMaskExpression.MapKeyExpression.Builder builder() {
    return ImmutableMaskExpression.MapKeyExpression.builder();
  }

  /**
   * Creates a key-expression selection.
   *
   * @param mapKeyExpression the key expression to match
   * @return the corresponding key-expression selection
   */
  static MapKeyExpression of(String mapKeyExpression) {
    final ImmutableMaskExpression.MapKeyExpression.Builder pending = builder();
    pending.mapKeyExpression(mapKeyExpression);
    return pending.build();
  }

  /**
   * The key expression to match.
   *
   * @return the map key expression
   */
  String getMapKeyExpression();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package io.substrait.expression;

import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;

/**
 * Applies a {@link MaskExpression} projection to a {@link Type.Struct}, returning a pruned struct.
 *
 * <p>Struct selections keep only the referenced fields, in the order the mask lists them. List
 * and map selections keep the container type and prune only the element / value type through
 * their optional child selection; the list items, slices and map keys themselves do not affect
 * the resulting type.
 *
 * <p>NOTE(review): {@code MaskExpr#getMaintainSingularStruct()} is not consulted here, so
 * single-field structs are never unwrapped by this class — confirm that this is intended.
 */
public final class MaskExpressionTypeProjector {

  private MaskExpressionTypeProjector() {}

  /**
   * Prunes {@code baseStruct} according to the top-level struct selection of {@code projection}.
   *
   * @param projection the mask expression to apply
   * @param baseStruct the struct type the mask is applied to
   * @return a struct with the nullability of {@code baseStruct} that contains one (possibly
   *     pruned) field type per selected struct item
   * @throws IndexOutOfBoundsException if a struct item references a field index that does not
   *     exist in the struct it is applied to
   * @throws ClassCastException if a struct item's child selection does not match the kind of the
   *     selected field (for example a list selection on a struct field)
   */
  public static Type.Struct project(MaskExpression.MaskExpr projection, Type.Struct baseStruct) {
    return projectStruct(projection.getSelect(), baseStruct);
  }

  /** Builds a struct from the selected fields of {@code baseStruct}, pruning each recursively. */
  private static Type.Struct projectStruct(
      MaskExpression.StructSelect structSelect, Type.Struct baseStruct) {
    List<Type> fields = baseStruct.fields();
    List<MaskExpression.StructItem> items = structSelect.getStructItems();

    return TypeCreator.of(baseStruct.nullable())
        .struct(items.stream().map(item -> projectItem(item, fields.get(item.getField()))));
  }

  /**
   * Prunes the type of a single selected struct field. The child selection is applied strictly:
   * the field type is cast to the kind the selection expects.
   */
  private static Type projectItem(MaskExpression.StructItem item, Type fieldType) {
    if (!item.getChild().isPresent()) {
      return fieldType;
    }

    MaskExpression.Select select = item.getChild().get();

    if (select.getStruct().isPresent()) {
      return projectStruct(select.getStruct().get(), (Type.Struct) fieldType);
    }
    if (select.getList().isPresent()) {
      return projectList(select.getList().get(), (Type.ListType) fieldType);
    }
    if (select.getMap().isPresent()) {
      return projectMap(select.getMap().get(), (Type.Map) fieldType);
    }

    return fieldType;
  }

  /**
   * Applies a child selection of a list or map to the element / value type. Struct, list and map
   * children are all followed, so arbitrarily nested containers are pruned. A selection whose
   * kind does not match {@code type} leaves the type unchanged.
   */
  private static Type projectChild(MaskExpression.Select select, Type type) {
    if (select.getStruct().isPresent() && type instanceof Type.Struct) {
      return projectStruct(select.getStruct().get(), (Type.Struct) type);
    }
    if (select.getList().isPresent() && type instanceof Type.ListType) {
      return projectList(select.getList().get(), (Type.ListType) type);
    }
    if (select.getMap().isPresent() && type instanceof Type.Map) {
      return projectMap(select.getMap().get(), (Type.Map) type);
    }
    return type;
  }

  /** Prunes the element type of a list through the list selection's optional child. */
  private static Type.ListType projectList(
      MaskExpression.ListSelect listSelect, Type.ListType listType) {
    if (!listSelect.getChild().isPresent()) {
      return listType;
    }

    Type elementType = listType.elementType();
    Type prunedElement = projectChild(listSelect.getChild().get(), elementType);

    // Nothing was pruned: hand back the original list type instead of rebuilding it.
    if (prunedElement == elementType) {
      return listType;
    }
    return TypeCreator.of(listType.nullable()).list(prunedElement);
  }

  /** Prunes the value type of a map through the map selection's optional child. */
  private static Type.Map projectMap(MaskExpression.MapSelect mapSelect, Type.Map mapType) {
    if (!mapSelect.getChild().isPresent()) {
      return mapType;
    }

    Type valueType = mapType.value();
    Type prunedValue = projectChild(mapSelect.getChild().get(), valueType);

    // Nothing was pruned: hand back the original map type instead of rebuilding it.
    if (prunedValue == valueType) {
      return mapType;
    }
    return TypeCreator.of(mapType.nullable()).map(mapType.key(), prunedValue);
  }
}
Loading
Loading