diff --git a/paimon-common/src/main/java/org/apache/paimon/data/columnar/heap/CastedVectorColumnVector.java b/paimon-common/src/main/java/org/apache/paimon/data/columnar/heap/CastedVectorColumnVector.java new file mode 100644 index 000000000000..a7cd1b001bd0 --- /dev/null +++ b/paimon-common/src/main/java/org/apache/paimon/data/columnar/heap/CastedVectorColumnVector.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.data.columnar.heap; + +import org.apache.paimon.data.InternalVector; +import org.apache.paimon.data.columnar.ColumnVector; +import org.apache.paimon.data.columnar.ColumnarVec; +import org.apache.paimon.data.columnar.VecColumnVector; + +/** + * Cast internal Vector to paimon readable vector(cast for Timestamp type and Decimal type) for + * vector type. + */ +public class CastedVectorColumnVector implements VecColumnVector { + + private final HeapArrayVector heapArrayVector; + private final ColumnVector[] children; + private final int vectorSize; + + public CastedVectorColumnVector( + HeapArrayVector heapArrayVector, ColumnVector child, int vectorSize) { + this.heapArrayVector = heapArrayVector; + this.children = new ColumnVector[] {child}; + this.vectorSize = vectorSize; + } + + @Override + public InternalVector getVector(int i) { + long offset = heapArrayVector.offsets[i]; + long length = heapArrayVector.lengths[i]; + if (length != vectorSize) { + throw new IllegalArgumentException( + "Vector length mismatch: expected " + vectorSize + " but got " + length); + } + return ColumnarVec.DEFAULT_FACTORY.create(children[0], (int) offset, (int) length); + } + + @Override + public ColumnVector getColumnVector() { + return children[0]; + } + + @Override + public int getVectorSize() { + return vectorSize; + } + + @Override + public boolean isNullAt(int i) { + return heapArrayVector.isNullAt(i); + } + + @Override + public int getCapacity() { + return heapArrayVector.getCapacity(); + } + + @Override + public ColumnVector[] getChildren() { + return children; + } +} diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java index e7941b2e1cd8..fad603e74c4b 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetReaderFactory.java @@ -35,6 +35,7 @@ import org.apache.paimon.types.DataType; import org.apache.paimon.types.MapType; import org.apache.paimon.types.RowType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.Pair; import org.apache.paimon.utils.Preconditions; @@ -225,7 +226,11 @@ private Type clipParquetType(DataType readType, Type parquetType) { clipParquetType(mapType.getKeyType(), keyValueType.getLeft()), clipParquetType(mapType.getValueType(), keyValueType.getRight())); case ARRAY: - ArrayType arrayType = (ArrayType) readType; + case VECTOR: + DataType elementReadType = + readType instanceof ArrayType + ? ((ArrayType) readType).getElementType() + : ((VectorType) readType).getElementType(); GroupType arrayGroup = (GroupType) parquetType; int listSubFields = arrayGroup.getFieldCount(); Preconditions.checkArgument( @@ -236,8 +241,7 @@ private Type clipParquetType(DataType readType, Type parquetType) { // https://impala.apache.org/docs/build/html/topics/impala_parquet_array_resolution.html. int level = arrayGroup.getType(0) instanceof GroupType ? 3 : 2; Type elementType = - clipParquetType( - arrayType.getElementType(), parquetListElementType(arrayGroup)); + clipParquetType(elementReadType, parquetListElementType(arrayGroup)); if (level == 3) { // In case that the name in middle level is not "list". diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetSchemaConverter.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetSchemaConverter.java index 640081cd5002..3ce514cb759e 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetSchemaConverter.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/ParquetSchemaConverter.java @@ -31,6 +31,7 @@ import org.apache.paimon.types.MultisetType; import org.apache.paimon.types.RowType; import org.apache.paimon.types.TimestampType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.Pair; import org.apache.parquet.schema.ConversionPatterns; @@ -159,13 +160,13 @@ public static Type convertToParquetType(String name, DataType type, int fieldId, name, localZonedTimestampType.getPrecision(), repetition, true) .withId(fieldId); case ARRAY: - ArrayType arrayType = (ArrayType) type; + case VECTOR: + DataType listElementType = + type instanceof ArrayType + ? ((ArrayType) type).getElementType() + : ((VectorType) type).getElementType(); Type elementParquetType = - convertToParquetType( - LIST_ELEMENT_NAME, - arrayType.getElementType(), - fieldId, - depth + 1) + convertToParquetType(LIST_ELEMENT_NAME, listElementType, fieldId, depth + 1) .withId(SpecialFields.getArrayElementFieldId(fieldId, depth + 1)); return ConversionPatterns.listOfElements(repetition, name, elementParquetType) .withId(fieldId); diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetColumnVector.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetColumnVector.java index 37fc4272c6d3..f469e772b045 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetColumnVector.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetColumnVector.java @@ -191,6 +191,7 @@ void assemble() { DataTypeRoot type = column.getType().getTypeRoot(); if (type == DataTypeRoot.ARRAY + || type == DataTypeRoot.VECTOR || type == DataTypeRoot.MAP || type == DataTypeRoot.MULTISET) { for (ParquetColumnVector child : children) { diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetReaderUtil.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetReaderUtil.java index a2741f869ab6..316cc2e4fede 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetReaderUtil.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/reader/ParquetReaderUtil.java @@ -22,6 +22,7 @@ import org.apache.paimon.data.columnar.heap.CastedArrayColumnVector; import org.apache.paimon.data.columnar.heap.CastedMapColumnVector; import org.apache.paimon.data.columnar.heap.CastedRowColumnVector; +import org.apache.paimon.data.columnar.heap.CastedVectorColumnVector; import org.apache.paimon.data.columnar.heap.HeapArrayVector; import org.apache.paimon.data.columnar.heap.HeapBooleanVector; import org.apache.paimon.data.columnar.heap.HeapByteVector; @@ -51,6 +52,7 @@ import org.apache.paimon.types.MultisetType; import org.apache.paimon.types.RowType; import org.apache.paimon.types.VariantType; +import org.apache.paimon.types.VectorType; import org.apache.paimon.utils.Pair; import org.apache.paimon.utils.StringUtils; @@ -126,6 +128,11 @@ public static WritableColumnVector createWritableColumnVector( return new HeapArrayVector( batchSize, createWritableColumnVector(batchSize, arrayType.getElementType())); + case VECTOR: + VectorType vectorType = (VectorType) fieldType; + return new HeapArrayVector( + batchSize, + createWritableColumnVector(batchSize, vectorType.getElementType())); case MAP: MapType mapType = (MapType) fieldType; return new HeapMapVector( @@ -188,6 +195,16 @@ public static ColumnVector createReadableColumnVector( Arrays.stream(writableVector.getChildren()) .map(WritableColumnVector.class::cast) .toArray(WritableColumnVector[]::new))); + case VECTOR: + VectorType vectorType = (VectorType) type; + return new CastedVectorColumnVector( + (HeapArrayVector) writableVector, + createReadableColumnVectors( + Collections.singletonList(vectorType.getElementType()), + Arrays.stream(writableVector.getChildren()) + .map(WritableColumnVector.class::cast) + .toArray(WritableColumnVector[]::new))[0], + vectorType.getLength()); case MAP: MapType mapType = (MapType) type; return new CastedMapColumnVector( @@ -322,8 +339,11 @@ private static ParquetField constructField( groupColumnIO.getFieldPath()); } - if (type instanceof ArrayType) { - ArrayType arrayType = (ArrayType) type; + if (type instanceof ArrayType || type instanceof VectorType) { + DataType elementType = + type instanceof ArrayType + ? ((ArrayType) type).getElementType() + : ((VectorType) type).getElementType(); ColumnIO elementTypeColumnIO; if (columnIO instanceof GroupColumnIO) { GroupColumnIO groupColumnIO = (GroupColumnIO) columnIO; @@ -333,7 +353,7 @@ private static ParquetField constructField( } elementTypeColumnIO = groupColumnIO; } else { - if (arrayType.getElementType() instanceof RowType) { + if (elementType instanceof RowType) { elementTypeColumnIO = groupColumnIO; } else { elementTypeColumnIO = groupColumnIO.getChild(0); @@ -347,7 +367,7 @@ private static ParquetField constructField( ParquetField field = constructField( - new DataField(0, "", arrayType.getElementType()), + new DataField(0, "", elementType), getArrayElementColumn(elementTypeColumnIO), parquetListElementType(parquetType.asGroupType())); if (repetitionLevel == field.getRepetitionLevel()) { diff --git a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java index 80b788733342..6a77b6cf20f1 100644 --- a/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java +++ b/paimon-format/src/main/java/org/apache/paimon/format/parquet/writer/ParquetRowDataWriter.java @@ -40,6 +40,7 @@ import org.apache.paimon.types.RowType; import org.apache.paimon.types.TimestampType; import org.apache.paimon.types.VariantType; +import org.apache.paimon.types.VectorType; import org.apache.hadoop.conf.Configuration; import org.apache.parquet.io.api.Binary; @@ -138,9 +139,15 @@ private FieldWriter createWriter(DataType t, Type type) { GroupType groupType = type.asGroupType(); LogicalTypeAnnotation annotation = type.getLogicalTypeAnnotation(); - if (t instanceof ArrayType + if ((t instanceof ArrayType || t instanceof VectorType) && annotation instanceof LogicalTypeAnnotation.ListLogicalTypeAnnotation) { - return new ArrayWriter(((ArrayType) t).getElementType(), groupType); + DataType elementType = + t instanceof ArrayType + ? ((ArrayType) t).getElementType() + : ((VectorType) t).getElementType(); + Integer expectedVectorLength = + t instanceof VectorType ? ((VectorType) t).getLength() : null; + return new ArrayWriter(elementType, groupType, expectedVectorLength); } else if (t instanceof MapType && annotation instanceof LogicalTypeAnnotation.MapLogicalTypeAnnotation) { return new MapWriter( @@ -511,8 +518,10 @@ private class ArrayWriter implements FieldWriter { private final String elementName; private final FieldWriter elementWriter; private final String repeatedGroupName; + @Nullable private final Integer expectedVectorLength; - private ArrayWriter(DataType t, GroupType groupType) { + private ArrayWriter( + DataType t, GroupType groupType, @Nullable Integer expectedVectorLength) { // Get the internal array structure GroupType repeatedType = groupType.getType(0).asGroupType(); this.repeatedGroupName = repeatedType.getName(); @@ -521,22 +530,35 @@ private ArrayWriter(DataType t, GroupType groupType) { this.elementName = elementType.getName(); this.elementWriter = createWriter(t, elementType); + this.expectedVectorLength = expectedVectorLength; } @Override public void write(InternalRow row, int ordinal) { - writeArrayData(row.getArray(ordinal)); + writeArrayData( + expectedVectorLength != null ? row.getVector(ordinal) : row.getArray(ordinal)); } @Override public void write(InternalArray arrayData, int ordinal) { - writeArrayData(arrayData.getArray(ordinal)); + writeArrayData( + expectedVectorLength != null + ? arrayData.getVector(ordinal) + : arrayData.getArray(ordinal)); } private void writeArrayData(InternalArray arrayData) { recordConsumer.startGroup(); int listLength = arrayData.size(); + if (expectedVectorLength != null && listLength != expectedVectorLength) { + throw new IllegalArgumentException( + "Vector length mismatch: expected " + + expectedVectorLength + + " but got " + + listLength); + } + if (listLength > 0) { recordConsumer.startField(repeatedGroupName, 0); for (int i = 0; i < listLength; i++) { diff --git a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetReadWriteTest.java b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetReadWriteTest.java index 5851ef7db5f4..ef18dd0c29fb 100644 --- a/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetReadWriteTest.java +++ b/paimon-format/src/test/java/org/apache/paimon/format/parquet/ParquetReadWriteTest.java @@ -19,12 +19,14 @@ package org.apache.paimon.format.parquet; import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.BinaryVector; import org.apache.paimon.data.Decimal; import org.apache.paimon.data.GenericArray; import org.apache.paimon.data.GenericMap; import org.apache.paimon.data.GenericRow; import org.apache.paimon.data.InternalMap; import org.apache.paimon.data.InternalRow; +import org.apache.paimon.data.InternalVector; import org.apache.paimon.data.Timestamp; import org.apache.paimon.data.serializer.InternalRowSerializer; import org.apache.paimon.format.FormatReaderContext; @@ -741,6 +743,52 @@ public void testReadTimestampNanosWrittenByParquet() throws Exception { assertThat(count.get()).isEqualTo(nanosValues.length); } + @Test + public void testReadWriteVector() throws Exception { + RowType rowType = + RowType.builder() + .fields(DataTypes.INT(), DataTypes.VECTOR(3, DataTypes.FLOAT())) + .build(); + List rows = + Arrays.asList( + GenericRow.of(1, BinaryVector.fromPrimitiveArray(new float[] {1, 2, 3})), + GenericRow.of(2, BinaryVector.fromPrimitiveArray(new float[] {4, 5, 6}))); + + Path path = createTempParquetFileByPaimon(folder, rows, 1024, rowType); + ParquetReaderFactory format = + new ParquetReaderFactory(new Options(), rowType, 500, FilterCompat.NOOP); + + RecordReader reader = + format.createReader( + new FormatReaderContext( + new LocalFileIO(), path, new LocalFileIO().getFileSize(path))); + List results = new ArrayList<>(); + InternalRowSerializer serializer = new InternalRowSerializer(rowType); + reader.forEachRemaining(row -> results.add(serializer.copy(row))); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getInt(0)).isEqualTo(1); + assertVector(results.get(0).getVector(1), new float[] {1, 2, 3}); + assertThat(results.get(1).getInt(0)).isEqualTo(2); + assertVector(results.get(1).getVector(1), new float[] {4, 5, 6}); + } + + @Test + public void testWriteVectorLengthMismatch() { + RowType rowType = + RowType.builder() + .fields(DataTypes.INT(), DataTypes.VECTOR(3, DataTypes.FLOAT())) + .build(); + List rows = + Collections.singletonList( + GenericRow.of( + 1, BinaryVector.fromPrimitiveArray(new float[] {1, 2, 3, 4}))); + + assertThatThrownBy(() -> createTempParquetFileByPaimon(folder, rows, 1024, rowType)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("Vector length mismatch: expected 3 but got 4"); + } + private void innerTestTypes(File folder, List records, int rowGroupSize) throws IOException { List rows = records.stream().map(this::newRow).collect(Collectors.toList()); @@ -749,6 +797,10 @@ private void innerTestTypes(File folder, List records, int rowGroupSize assertThat(len).isEqualTo(records.size()); } + private static void assertVector(InternalVector vector, float[] expected) { + Assertions.assertArrayEquals(expected, vector.toFloatArray()); + } + private Path createTempParquetFileByPaimon( File folder, List rows, int rowGroupSize, RowType rowType) throws IOException {