Skip to content
229 changes: 229 additions & 0 deletions core/src/main/java/io/substrait/expression/MaskExpression.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
package io.substrait.expression;

import io.substrait.util.VisitationContext;
import java.util.List;
import java.util.Optional;
import org.immutables.value.Value;

/**
* A mask expression that selectively removes fields from complex types (struct, list, map).
*
* <p>This corresponds to the {@code Expression.MaskExpression} message in the Substrait protobuf
* specification. It is used in {@code ReadRel} to describe column projection — the subset of a
* relation's schema that should actually be read.
*
* @see <a href="https://substrait.io/expressions/field_references/">Substrait Field References</a>
*/
@Value.Enclosing
public interface MaskExpression {

// ---------------------------------------------------------------------------
// Top-level MaskExpression value
// ---------------------------------------------------------------------------

/**
 * Concrete mask expression value: the top-level struct selection together with its options.
 */
@Value.Immutable
interface MaskExpr {
  /**
   * Returns the top-level struct selection describing which fields to include.
   *
   * @return the root struct selection
   */
  StructSelect getSelect();

  /**
   * Tells whether a struct with only a single selected field keeps its struct wrapper. When
   * {@code true}, such a struct is <em>not</em> unwrapped into its child type.
   *
   * @return {@code false} unless a value was set explicitly on the builder
   */
  @Value.Default
  default boolean getMaintainSingularStruct() {
    final boolean keepSingularStruct = false;
    return keepSingularStruct;
  }

  /**
   * Creates an empty builder for {@link MaskExpr} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.MaskExpr.Builder builder() {
    final ImmutableMaskExpression.MaskExpr.Builder maskBuilder =
        ImmutableMaskExpression.MaskExpr.builder();
    return maskBuilder;
  }
}
Comment on lines +41 to +46
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also merge this into the top-level MaskExpression interface, which would be similar to what we are doing in Expression.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. I merged MaskExpr into the MaskExpression interface.

I also synced this change with the other changes in this PR.


// ---------------------------------------------------------------------------
// Select – a union of StructSelect | ListSelect | MapSelect
// ---------------------------------------------------------------------------

/**
* A selection on a complex type – one of StructSelect, ListSelect, or MapSelect.
*
* <p>This is the common base type of the three selection kinds; it declares only the visitor entry
* point.
*/
interface Select {
/**
* Accepts a visitor to process this selection.
*
* @param <R> the return type of the visitor
* @param <C> the visitation context type
* @param <E> the exception type that the visitor may throw
* @param visitor the visitor to accept
* @param context the visitation context handed to the visitor
* @return the result produced by the visitor
* @throws E if the visitor fails
*/
<R, C extends VisitationContext, E extends Throwable> R accept(
MaskExpressionVisitor<R, C, E> visitor, C context) throws E;
}
Comment on lines +53 to +56
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems to mirror the generated proto Java classes. What if we turned this into an empty base interface without the ofStruct, ofList, and ofMap methods? Then StructSelect, ListSelect, and MapSelect could extend that interface, and you could use the visitor pattern to process the different types. This would fit better with existing code such as Expression with its ExpressionVisitor.

Suggested change
interface Select {
Optional<StructSelect> getStruct();
Optional<ListSelect> getList();
Optional<MapSelect> getMap();
static ImmutableMaskExpression.Select.Builder builder() {
return ImmutableMaskExpression.Select.builder();
}
static Select ofStruct(StructSelect structSelect) {
return builder().struct(structSelect).build();
}
static Select ofList(ListSelect listSelect) {
return builder().list(listSelect).build();
}
static Select ofMap(MapSelect mapSelect) {
return builder().map(mapSelect).build();
}
}
interface Select {
/**
* Accepts a visitor to traverse this mask expression.
*
* @param <R> the return type of the visitor
* @param <C> the context type
* @param <E> the exception type that may be thrown
* @param visitor the visitor to accept
* @param context the visitation context
* @return the result of the visitation
* @throws E if an error occurs during visitation
*/
<R, C extends VisitationContext, E extends Throwable> R accept(
MaskExpressionVisitor<R, C, E> visitor, C context) throws E;
}
interface StructSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

interface ListSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

interface MapSelect extends Select {
...
    @Override
    public <R, C extends VisitationContext, E extends Throwable> R accept(
        MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
      return visitor.visit(this, context);
    }
}

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the suggestion!

I refactored Select into a base interface with an accept() method for visiting. StructSelect, ListSelect, and MapSelect now extend it directly.

I also added a MaskExpressionVisitor interface, following the same project code pattern as you suggested.


// ---------------------------------------------------------------------------
// Struct selection
// ---------------------------------------------------------------------------

/**
 * Selection of a subset of the fields of a struct type.
 */
@Value.Immutable
interface StructSelect extends Select {
  /**
   * Returns the selected struct items.
   *
   * @return the items making up this selection
   */
  List<StructItem> getStructItems();

  /**
   * Hands this selection to the visitor's {@code StructSelect} overload.
   *
   * @param visitor the visitor to dispatch to
   * @param context the visitation context passed through to the visitor
   * @return whatever the visitor returns
   * @throws E if the visitor throws
   */
  @Override
  default <R, C extends VisitationContext, E extends Throwable> R accept(
      MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
    final R visited = visitor.visit(this, context);
    return visited;
  }

  /**
   * Creates an empty builder for {@link StructSelect} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.StructSelect.Builder builder() {
    final ImmutableMaskExpression.StructSelect.Builder selectBuilder =
        ImmutableMaskExpression.StructSelect.builder();
    return selectBuilder;
  }
}

/**
 * One selected field of a struct, optionally carrying a nested selection for that field.
 */
@Value.Immutable
interface StructItem {
  /**
   * Returns the zero-based index of the selected field within the struct.
   *
   * @return the field index
   */
  int getField();

  /**
   * Returns the nested selection for the field, if one was given.
   *
   * @return the child selection, or empty when the field is selected as a whole
   */
  Optional<Select> getChild();

  /**
   * Creates an empty builder for {@link StructItem} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.StructItem.Builder builder() {
    return ImmutableMaskExpression.StructItem.builder();
  }

  /**
   * Creates an item that selects a field without any child selection.
   *
   * @param field zero-based field index
   * @return the new item
   */
  static StructItem of(int field) {
    final ImmutableMaskExpression.StructItem.Builder itemBuilder = builder();
    itemBuilder.field(field);
    return itemBuilder.build();
  }

  /**
   * Creates an item that selects a field and applies a child selection to it.
   *
   * @param field zero-based field index
   * @param child selection applied to the field
   * @return the new item
   */
  static StructItem of(int field, Select child) {
    final ImmutableMaskExpression.StructItem.Builder itemBuilder = builder();
    itemBuilder.field(field);
    itemBuilder.child(child);
    return itemBuilder.build();
  }
}

// ---------------------------------------------------------------------------
// List selection
// ---------------------------------------------------------------------------

/**
 * Selection of list elements, by index or by slice.
 */
@Value.Immutable
interface ListSelect extends Select {
  /**
   * Returns the individual element/slice selections.
   *
   * @return the selection items
   */
  List<ListSelectItem> getSelection();

  /**
   * Returns the selection applied to each selected element, if one was given.
   *
   * @return the child selection, or empty
   */
  Optional<Select> getChild();

  /**
   * Hands this selection to the visitor's {@code ListSelect} overload.
   *
   * @param visitor the visitor to dispatch to
   * @param context the visitation context passed through to the visitor
   * @return whatever the visitor returns
   * @throws E if the visitor throws
   */
  @Override
  default <R, C extends VisitationContext, E extends Throwable> R accept(
      MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
    final R visited = visitor.visit(this, context);
    return visited;
  }

  /**
   * Creates an empty builder for {@link ListSelect} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.ListSelect.Builder builder() {
    final ImmutableMaskExpression.ListSelect.Builder selectBuilder =
        ImmutableMaskExpression.ListSelect.builder();
    return selectBuilder;
  }
}

/**
 * One entry of a list selection – either a single element or a slice.
 *
 * <p>Both attributes are optional; the factory methods below each populate exactly one of them.
 */
@Value.Immutable
interface ListSelectItem {
  /**
   * Returns the single-element selection, if this entry is one.
   *
   * @return the element selection, or empty
   */
  Optional<ListElement> getItem();

  /**
   * Returns the slice selection, if this entry is one.
   *
   * @return the slice selection, or empty
   */
  Optional<ListSlice> getSlice();

  /**
   * Creates an empty builder for {@link ListSelectItem} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.ListSelectItem.Builder builder() {
    return ImmutableMaskExpression.ListSelectItem.builder();
  }

  /**
   * Creates an entry that selects a single element.
   *
   * @param element the element selection
   * @return the new entry with only the item attribute set
   */
  static ListSelectItem ofItem(ListElement element) {
    final ImmutableMaskExpression.ListSelectItem.Builder entryBuilder = builder();
    entryBuilder.item(element);
    return entryBuilder.build();
  }

  /**
   * Creates an entry that selects a slice.
   *
   * @param slice the slice selection
   * @return the new entry with only the slice attribute set
   */
  static ListSelectItem ofSlice(ListSlice slice) {
    final ImmutableMaskExpression.ListSelectItem.Builder entryBuilder = builder();
    entryBuilder.slice(slice);
    return entryBuilder.build();
  }
}

/**
 * Selection of a single list element by its zero-based index.
 */
@Value.Immutable
interface ListElement {
  /**
   * Returns the index of the selected element.
   *
   * @return the element index
   */
  int getField();

  /**
   * Creates an empty builder for {@link ListElement} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.ListElement.Builder builder() {
    return ImmutableMaskExpression.ListElement.builder();
  }

  /**
   * Creates an element selection for the given index.
   *
   * @param field index of the element to select
   * @return the new element selection
   */
  static ListElement of(int field) {
    final ImmutableMaskExpression.ListElement.Builder elementBuilder = builder();
    elementBuilder.field(field);
    return elementBuilder.build();
  }
}

/**
 * Selection of a contiguous range of list elements.
 */
@Value.Immutable
interface ListSlice {
  /**
   * Returns the start position of the slice.
   *
   * @return the start value
   */
  int getStart();

  /**
   * Returns the end position of the slice.
   *
   * @return the end value
   */
  int getEnd();

  /**
   * Creates an empty builder for {@link ListSlice} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.ListSlice.Builder builder() {
    return ImmutableMaskExpression.ListSlice.builder();
  }

  /**
   * Creates a slice from its two bounds.
   *
   * @param start the start position
   * @param end the end position
   * @return the new slice
   */
  static ListSlice of(int start, int end) {
    final ImmutableMaskExpression.ListSlice.Builder sliceBuilder = builder();
    sliceBuilder.start(start);
    sliceBuilder.end(end);
    return sliceBuilder.build();
  }
}

// ---------------------------------------------------------------------------
// Map selection
// ---------------------------------------------------------------------------

/**
 * Selection of map entries, either by an exact key or by a key expression.
 *
 * <p>Both the key and the expression attribute are optional; the {@code ofKey} and {@code
 * ofExpression} factories each populate exactly one of them.
 */
@Value.Immutable
interface MapSelect extends Select {
  /**
   * Returns the exact-key selection, if this selection uses one.
   *
   * @return the key selection, or empty
   */
  Optional<MapKey> getKey();

  /**
   * Returns the key-expression selection, if this selection uses one.
   *
   * @return the key expression selection, or empty
   */
  Optional<MapKeyExpression> getExpression();

  /**
   * Returns the selection applied to each selected map value, if one was given.
   *
   * @return the child selection, or empty
   */
  Optional<Select> getChild();

  /**
   * Hands this selection to the visitor's {@code MapSelect} overload.
   *
   * @param visitor the visitor to dispatch to
   * @param context the visitation context passed through to the visitor
   * @return whatever the visitor returns
   * @throws E if the visitor throws
   */
  @Override
  default <R, C extends VisitationContext, E extends Throwable> R accept(
      MaskExpressionVisitor<R, C, E> visitor, C context) throws E {
    final R visited = visitor.visit(this, context);
    return visited;
  }

  /**
   * Creates an empty builder for {@link MapSelect} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.MapSelect.Builder builder() {
    return ImmutableMaskExpression.MapSelect.builder();
  }

  /**
   * Creates a selection that matches an exact key.
   *
   * @param key the key selection
   * @return the new selection with only the key attribute set
   */
  static MapSelect ofKey(MapKey key) {
    final ImmutableMaskExpression.MapSelect.Builder selectBuilder = builder();
    selectBuilder.key(key);
    return selectBuilder.build();
  }

  /**
   * Creates a selection that matches keys by expression.
   *
   * @param expression the key expression selection
   * @return the new selection with only the expression attribute set
   */
  static MapSelect ofExpression(MapKeyExpression expression) {
    final ImmutableMaskExpression.MapSelect.Builder selectBuilder = builder();
    selectBuilder.expression(expression);
    return selectBuilder.build();
  }
}

/**
 * Selection of a map entry by an exact key match.
 */
@Value.Immutable
interface MapKey {
  /**
   * Returns the key to match.
   *
   * @return the map key
   */
  String getMapKey();

  /**
   * Creates an empty builder for {@link MapKey} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.MapKey.Builder builder() {
    return ImmutableMaskExpression.MapKey.builder();
  }

  /**
   * Creates an exact-key selection.
   *
   * @param mapKey the key to match
   * @return the new key selection
   */
  static MapKey of(String mapKey) {
    final ImmutableMaskExpression.MapKey.Builder keyBuilder = builder();
    keyBuilder.mapKey(mapKey);
    return keyBuilder.build();
  }
}

/**
 * Selection of map entries by a wildcard key expression.
 */
@Value.Immutable
interface MapKeyExpression {
  /**
   * Returns the key expression to match against.
   *
   * @return the key expression string
   */
  String getMapKeyExpression();

  /**
   * Creates an empty builder for {@link MapKeyExpression} values.
   *
   * @return a new builder
   */
  static ImmutableMaskExpression.MapKeyExpression.Builder builder() {
    return ImmutableMaskExpression.MapKeyExpression.builder();
  }

  /**
   * Creates a key-expression selection.
   *
   * @param mapKeyExpression the key expression to match against
   * @return the new key-expression selection
   */
  static MapKeyExpression of(String mapKeyExpression) {
    final ImmutableMaskExpression.MapKeyExpression.Builder expressionBuilder = builder();
    expressionBuilder.mapKeyExpression(mapKeyExpression);
    return expressionBuilder.build();
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package io.substrait.expression;

import io.substrait.type.Type;
import io.substrait.type.TypeCreator;
import java.util.List;

/**
 * Applies a {@link MaskExpression} projection to a {@link Type.Struct}, returning a pruned struct.
 *
 * <p>A struct selection keeps only the referenced fields, in the order the selection lists them. A
 * list or map selection keeps the container type and, when it carries a child selection, prunes
 * the element (respectively value) type with that child.
 *
 * <p>{@code getMaintainSingularStruct()} is not consulted by this class; the result of {@link
 * #project} is always a struct.
 */
public final class MaskExpressionTypeProjector {

  private MaskExpressionTypeProjector() {}

  /**
   * Prunes {@code baseStruct} according to the top-level struct selection of {@code projection}.
   *
   * @param projection the mask expression whose selection is applied
   * @param baseStruct the struct type to prune
   * @return a struct with the nullability of {@code baseStruct} holding only the selected fields
   * @throws IndexOutOfBoundsException if a struct item references a field index that does not
   *     exist in the struct it is applied to
   * @throws IllegalArgumentException if a struct item carries a child selection whose kind does
   *     not match the type of the selected field (e.g. a list selection on a struct field)
   */
  public static Type.Struct project(MaskExpression.MaskExpr projection, Type.Struct baseStruct) {
    return projectStruct(projection.getSelect(), baseStruct);
  }

  /** Builds the pruned struct by projecting every selected item against its field type. */
  private static Type.Struct projectStruct(
      MaskExpression.StructSelect structSelect, Type.Struct baseStruct) {
    List<Type> fields = baseStruct.fields();
    List<MaskExpression.StructItem> items = structSelect.getStructItems();

    return TypeCreator.of(baseStruct.nullable())
        .struct(items.stream().map(item -> projectItem(item, fieldAt(fields, item.getField()))));
  }

  /** Looks up a field type, failing with a message that names the offending index. */
  private static Type fieldAt(List<Type> fields, int index) {
    if (index < 0 || index >= fields.size()) {
      throw new IndexOutOfBoundsException(
          "Mask field index " + index + " is out of bounds for struct with " + fields.size()
              + " fields");
    }
    return fields.get(index);
  }

  /**
   * Projects a single struct item: returns the field type as is when there is no child selection,
   * otherwise prunes it with the child.
   */
  private static Type projectItem(MaskExpression.StructItem item, Type fieldType) {
    if (!item.getChild().isPresent()) {
      return fieldType;
    }

    MaskExpression.Select select = item.getChild().get();

    if (select instanceof MaskExpression.StructSelect) {
      return projectStruct(
          (MaskExpression.StructSelect) select, requireType(item, fieldType, Type.Struct.class));
    }

    if (select instanceof MaskExpression.ListSelect) {
      return projectList(
          (MaskExpression.ListSelect) select, requireType(item, fieldType, Type.ListType.class));
    }

    if (select instanceof MaskExpression.MapSelect) {
      return projectMap(
          (MaskExpression.MapSelect) select, requireType(item, fieldType, Type.Map.class));
    }

    // Select implementations other than the three known kinds leave the field untouched.
    return fieldType;
  }

  /**
   * Checked cast of a field type to the type required by the item's child selection. Replaces a
   * bare cast so that a mismatch reports the field index and both types involved.
   */
  private static <T extends Type> T requireType(
      MaskExpression.StructItem item, Type fieldType, Class<T> expected) {
    if (!expected.isInstance(fieldType)) {
      throw new IllegalArgumentException(
          "Mask selection for struct field "
              + item.getField()
              + " requires a "
              + expected.getSimpleName()
              + " type but the field type is "
              + fieldType);
    }
    return expected.cast(fieldType);
  }

  /**
   * Prunes the element (list) or value (map) type of a container with its child selection. Returns
   * the very same {@code type} instance when the selection kind does not match the type, so that
   * callers can detect "nothing to do" by identity.
   */
  private static Type projectChild(MaskExpression.Select select, Type type) {
    if (select instanceof MaskExpression.StructSelect && type instanceof Type.Struct) {
      return projectStruct((MaskExpression.StructSelect) select, (Type.Struct) type);
    }

    if (select instanceof MaskExpression.ListSelect && type instanceof Type.ListType) {
      return projectList((MaskExpression.ListSelect) select, (Type.ListType) type);
    }

    if (select instanceof MaskExpression.MapSelect && type instanceof Type.Map) {
      return projectMap((MaskExpression.MapSelect) select, (Type.Map) type);
    }

    return type;
  }

  /**
   * Projects a list type. The list itself is kept; its element type is pruned when a child
   * selection is present and applicable (struct, list and map elements are all supported).
   */
  private static Type.ListType projectList(
      MaskExpression.ListSelect listSelect, Type.ListType listType) {
    if (!listSelect.getChild().isPresent()) {
      return listType;
    }

    Type elementType = listType.elementType();
    Type prunedElement = projectChild(listSelect.getChild().get(), elementType);

    // Intentional identity check: projectChild hands back the same instance when nothing applied.
    if (prunedElement == elementType) {
      return listType;
    }

    return TypeCreator.of(listType.nullable()).list(prunedElement);
  }

  /**
   * Projects a map type. The map and its key type are kept; its value type is pruned when a child
   * selection is present and applicable (struct, list and map values are all supported).
   */
  private static Type.Map projectMap(MaskExpression.MapSelect mapSelect, Type.Map mapType) {
    if (!mapSelect.getChild().isPresent()) {
      return mapType;
    }

    Type valueType = mapType.value();
    Type prunedValue = projectChild(mapSelect.getChild().get(), valueType);

    // Intentional identity check: projectChild hands back the same instance when nothing applied.
    if (prunedValue == valueType) {
      return mapType;
    }

    return TypeCreator.of(mapType.nullable()).map(mapType.key(), prunedValue);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package io.substrait.expression;

import io.substrait.util.VisitationContext;

/**
* Visitor for {@link MaskExpression} select nodes.
*
* <p>Each {@code Select} implementation declared in {@link MaskExpression} calls the overload
* matching its own type from its {@code accept} method.
*
* @param <R> result type returned by each visit
* @param <C> visitation context type
* @param <E> throwable type that visit methods may throw
*/
public interface MaskExpressionVisitor<R, C extends VisitationContext, E extends Throwable> {

/**
* Visits a struct selection.
*
* @param structSelect the struct selection being visited
* @param context the visitation context
* @return the result of the visit
* @throws E if the visit fails
*/
R visit(MaskExpression.StructSelect structSelect, C context) throws E;

/**
* Visits a list selection.
*
* @param listSelect the list selection being visited
* @param context the visitation context
* @return the result of the visit
* @throws E if the visit fails
*/
R visit(MaskExpression.ListSelect listSelect, C context) throws E;

/**
* Visits a map selection.
*
* @param mapSelect the map selection being visited
* @param context the visitation context
* @return the result of the visit
* @throws E if the visit fails
*/
R visit(MaskExpression.MapSelect mapSelect, C context) throws E;
}
Loading