Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -116,42 +116,38 @@ public FieldNestedUpdateAgg(

@Override
public Object agg(Object accumulator, Object inputField) {
if (accumulator == null || inputField == null) {
return accumulator == null ? inputField : accumulator;
}

InternalArray acc = (InternalArray) accumulator;
InternalArray input = (InternalArray) inputField;

if (acc.size() >= countLimit) {
if (inputField == null) {
return accumulator;
}

int remainCount = countLimit - acc.size();
InternalArray input = (InternalArray) inputField;

List<InternalRow> rows = new ArrayList<>(acc.size() + input.size());
addNonNullRows(acc, rows);
addNonNullRows(input, rows, remainCount);
if (keyProjection == null) {
if (accumulator == null) {
List<InternalRow> rows = new ArrayList<>(input.size());
addNonNullRows(input, rows, countLimit);
return new GenericArray(rows.toArray());
}

if (keyProjection != null) {
Map<BinaryRow, InternalRow> map = new HashMap<>();
for (InternalRow row : rows) {
BinaryRow key = keyProjection.apply(row).copy();
if (hasSequenceField) {
// When sequence field is configured, only update if the new sequence is greater
InternalRow existing = map.get(key);
if (existing == null || compareSequence(row, existing) >= 0) {
map.put(key, row);
}
} else {
map.put(key, row);
}
InternalArray acc = (InternalArray) accumulator;
if (acc.size() >= countLimit) {
return accumulator;
}

rows = new ArrayList<>(map.values());
int remainCount = countLimit - acc.size();

List<InternalRow> rows = new ArrayList<>(acc.size() + input.size());
addNonNullRows(acc, rows);
addNonNullRows(input, rows, remainCount);
return new GenericArray(rows.toArray());
}

return new GenericArray(rows.toArray());
Map<BinaryRow, InternalRow> map = new HashMap<>();
if (accumulator != null) {
addNestedRows((InternalArray) accumulator, map, false);
}
addNestedRows(input, map, true);
return new GenericArray(new ArrayList<>(map.values()).toArray());
}

@Override
Expand Down Expand Up @@ -235,4 +231,26 @@ private void addNonNullRows(InternalArray array, List<InternalRow> rows, int rem
count++;
}
}

private void addNestedRows(
InternalArray array, Map<BinaryRow, InternalRow> rows, boolean limitNewKeys) {
checkNotNull(keyProjection);

for (int i = 0; i < array.size(); i++) {
if (array.isNullAt(i)) {
continue;
}

InternalRow row = array.getRow(i, nestedFields);
BinaryRow key = keyProjection.apply(row).copy();
InternalRow existing = rows.get(key);
if (existing != null) {
if (!hasSequenceField || compareSequence(row, existing) >= 0) {
rows.put(key, row);
}
} else if (!limitNewKeys || rows.size() < countLimit) {
rows.put(key, row);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,95 @@ public void testFieldNestedAppendAggWithCountLimit() {
.containsExactlyInAnyOrderElementsOf(Arrays.asList(row(0, 1, "B"), row(0, 1, "b")));
}

@Test
public void testFieldNestedAppendAggWithCountLimitOnFirstInputArray() {
DataType elementRowType =
DataTypes.ROW(
DataTypes.FIELD(0, "k0", DataTypes.INT()),
DataTypes.FIELD(1, "k1", DataTypes.INT()),
DataTypes.FIELD(2, "v", DataTypes.STRING()));
FieldNestedUpdateAgg agg =
new FieldNestedUpdateAgg(
FieldNestedUpdateAggFactory.NAME,
DataTypes.ARRAY(elementRowType),
Collections.emptyList(),
2);

InternalArray.ElementGetter elementGetter =
InternalArray.createElementGetter(elementRowType);
InternalArray accumulator =
(InternalArray)
agg.agg(null, array(row(0, 1, "B"), null, row(0, 1, "b"), row(0, 1, "C")));

assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(Arrays.asList(row(0, 1, "B"), row(0, 1, "b")));
}

@Test
public void testFieldNestedUpdateAggWithCountLimitUpdatesExistingKeyAtLimitWithoutSequence() {
DataType elementRowType =
DataTypes.ROW(
DataTypes.FIELD(0, "k0", DataTypes.INT()),
DataTypes.FIELD(1, "k1", DataTypes.INT()),
DataTypes.FIELD(2, "v", DataTypes.STRING()));

FieldNestedUpdateAgg agg =
new FieldNestedUpdateAgg(
FieldNestedUpdateAggFactory.NAME,
DataTypes.ARRAY(elementRowType),
Arrays.asList("k0", "k1"),
2);

InternalArray accumulator = null;
InternalArray.ElementGetter elementGetter =
InternalArray.createElementGetter(elementRowType);

accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(0, 1, "B")));
accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(1, 2, "C")));

accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(0, 1, "B_updated")));
assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated"), row(1, 2, "C")));

accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(2, 3, "D")));
assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated"), row(1, 2, "C")));
}

@Test
public void testFieldNestedUpdateAggWithCountLimitOnFirstInputArrayWithoutSequence() {
DataType elementRowType =
DataTypes.ROW(
DataTypes.FIELD(0, "k0", DataTypes.INT()),
DataTypes.FIELD(1, "k1", DataTypes.INT()),
DataTypes.FIELD(2, "v", DataTypes.STRING()));

FieldNestedUpdateAgg agg =
new FieldNestedUpdateAgg(
FieldNestedUpdateAggFactory.NAME,
DataTypes.ARRAY(elementRowType),
Arrays.asList("k0", "k1"),
2);

InternalArray.ElementGetter elementGetter =
InternalArray.createElementGetter(elementRowType);
InternalArray accumulator =
(InternalArray)
agg.agg(
null,
array(
row(0, 1, "B"),
row(1, 2, "C"),
row(2, 3, "D"),
row(0, 1, "B_updated")));

assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated"), row(1, 2, "C")));
}

@Test
public void testFieldNestedUpdateAggWithSequenceField() {
DataType elementRowType =
Expand Down Expand Up @@ -1076,12 +1165,86 @@ public void testFieldNestedUpdateAggWithCountLimitWithSequenceField() {
Arrays.asList(row(0, 1, "B_updated", 2), row(1, 2, "C", 3)));
}

@Test
public void testFieldNestedUpdateAggWithCountLimitUpdatesExistingKeyAtLimit() {
DataType elementRowType =
DataTypes.ROW(
DataTypes.FIELD(0, "k0", DataTypes.INT()),
DataTypes.FIELD(1, "k1", DataTypes.INT()),
DataTypes.FIELD(2, "v", DataTypes.STRING()),
DataTypes.FIELD(3, "seq", DataTypes.INT()));

FieldNestedUpdateAgg agg =
new FieldNestedUpdateAgg(
FieldNestedUpdateAggFactory.NAME,
DataTypes.ARRAY(elementRowType),
Arrays.asList("k0", "k1"),
Collections.singletonList("seq"),
2);

InternalArray accumulator = null;
InternalArray.ElementGetter elementGetter =
InternalArray.createElementGetter(elementRowType);

accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(0, 1, "B", 1)));
accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(1, 2, "C", 3)));

accumulator =
(InternalArray) agg.agg(accumulator, singletonArray(row(0, 1, "B_updated", 4)));
assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated", 4), row(1, 2, "C", 3)));

accumulator = (InternalArray) agg.agg(accumulator, singletonArray(row(2, 3, "D", 5)));
assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated", 4), row(1, 2, "C", 3)));
}

@Test
public void testFieldNestedUpdateAggWithCountLimitOnFirstInputArrayWithSequence() {
DataType elementRowType =
DataTypes.ROW(
DataTypes.FIELD(0, "k0", DataTypes.INT()),
DataTypes.FIELD(1, "k1", DataTypes.INT()),
DataTypes.FIELD(2, "v", DataTypes.STRING()),
DataTypes.FIELD(3, "seq", DataTypes.INT()));

FieldNestedUpdateAgg agg =
new FieldNestedUpdateAgg(
FieldNestedUpdateAggFactory.NAME,
DataTypes.ARRAY(elementRowType),
Arrays.asList("k0", "k1"),
Collections.singletonList("seq"),
2);

InternalArray.ElementGetter elementGetter =
InternalArray.createElementGetter(elementRowType);
InternalArray accumulator =
(InternalArray)
agg.agg(
null,
array(
row(0, 1, "B", 1),
row(1, 2, "C", 3),
row(2, 3, "D", 5),
row(0, 1, "B_updated", 4)));

assertThat(unnest(accumulator, elementGetter))
.containsExactlyInAnyOrderElementsOf(
Arrays.asList(row(0, 1, "B_updated", 4), row(1, 2, "C", 3)));
}

private List<Object> unnest(InternalArray array, InternalArray.ElementGetter elementGetter) {
return IntStream.range(0, array.size())
.mapToObj(i -> elementGetter.getElementOrNull(array, i))
.collect(Collectors.toList());
}

private GenericArray array(InternalRow... rows) {
return new GenericArray(rows);
}

private GenericArray singletonArray(InternalRow row) {
return new GenericArray(new InternalRow[] {row});
}
Expand Down