[Parquet] Fix parquet decimal type match. #5001

Open · wants to merge 3 commits into base: master

File: ParquetVectorUpdaterFactory.java

@@ -20,6 +20,7 @@
 import org.apache.paimon.data.Timestamp;
 import org.apache.paimon.data.columnar.heap.HeapIntVector;
+import org.apache.paimon.data.columnar.heap.HeapLongVector;
 import org.apache.paimon.data.columnar.writable.WritableBooleanVector;
 import org.apache.paimon.data.columnar.writable.WritableByteVector;
 import org.apache.paimon.data.columnar.writable.WritableBytesVector;
@@ -64,7 +65,6 @@
 import java.math.BigDecimal;
 import java.math.BigInteger;
-import java.math.RoundingMode;
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
 import java.util.concurrent.TimeUnit;
@@ -79,13 +79,8 @@
 /** Updater Factory to get {@link ParquetVectorUpdater}. */
 public class ParquetVectorUpdaterFactory {

-    private final LogicalTypeAnnotation logicalTypeAnnotation;
-
-    ParquetVectorUpdaterFactory(LogicalTypeAnnotation logicalTypeAnnotation) {
-        this.logicalTypeAnnotation = logicalTypeAnnotation;
-    }
-
-    public ParquetVectorUpdater getUpdater(ColumnDescriptor descriptor, DataType paimonType) {
+    public static ParquetVectorUpdater getUpdater(
+            ColumnDescriptor descriptor, DataType paimonType) {
         return paimonType.accept(UpdaterFactoryVisitor.INSTANCE).apply(descriptor);
     }

@@ -614,10 +609,10 @@ public void decodeSingleDictionaryId(
     private abstract static class DecimalUpdater<T extends WritableColumnVector>
             implements ParquetVectorUpdater<T> {

-        private final DecimalType sparkType;
+        protected final DecimalType paimonType;

-        DecimalUpdater(DecimalType sparkType) {
-            this.sparkType = sparkType;
+        DecimalUpdater(DecimalType paimonType) {
+            this.paimonType = paimonType;
         }

         @Override
@@ -627,22 +622,6 @@ public void readValues(
                 readValue(offset + i, values, valuesReader);
             }
         }
-
-        protected void writeDecimal(int offset, WritableColumnVector values, BigDecimal decimal) {
-            BigDecimal scaledDecimal =
-                    decimal.setScale(sparkType.getScale(), RoundingMode.UNNECESSARY);
-            int precision = decimal.precision();
-            if (ParquetSchemaConverter.is32BitDecimal(precision)) {
-                ((WritableIntVector) values)
-                        .setInt(offset, scaledDecimal.unscaledValue().intValue());
-            } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
-                ((WritableLongVector) values)
-                        .setLong(offset, scaledDecimal.unscaledValue().longValue());
-            } else {
-                byte[] bytes = scaledDecimal.unscaledValue().toByteArray();
-                ((WritableBytesVector) values).putByteArray(offset, bytes, 0, bytes.length);
-            }
-        }
     }

     private static class IntegerToDecimalUpdater extends DecimalUpdater<WritableIntVector> {
@@ -687,8 +666,8 @@ public void decodeSingleDictionaryId(
     private static class LongToDecimalUpdater extends DecimalUpdater<WritableLongVector> {
         private final int parquetScale;

-        LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
-            super(sparkType);
+        LongToDecimalUpdater(ColumnDescriptor descriptor, DecimalType paimonType) {
+            super(paimonType);
             LogicalTypeAnnotation typeAnnotation =
                     descriptor.getPrimitiveType().getLogicalTypeAnnotation();
             if (typeAnnotation instanceof DecimalLogicalTypeAnnotation) {
@@ -726,8 +705,8 @@ public void decodeSingleDictionaryId(
     private static class BinaryToDecimalUpdater extends DecimalUpdater<WritableBytesVector> {
         private final int parquetScale;

-        BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
-            super(sparkType);
+        BinaryToDecimalUpdater(ColumnDescriptor descriptor, DecimalType paimonType) {
+            super(paimonType);
             LogicalTypeAnnotation typeAnnotation =
                     descriptor.getPrimitiveType().getLogicalTypeAnnotation();
             this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
@@ -767,14 +746,16 @@ public void decodeSingleDictionaryId(
     private static class FixedLenByteArrayToDecimalUpdater
             extends DecimalUpdater<WritableBytesVector> {
-        private final int parquetScale;
         private final int arrayLen;

-        FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType sparkType) {
-            super(sparkType);
+        FixedLenByteArrayToDecimalUpdater(ColumnDescriptor descriptor, DecimalType paimonType) {
+            super(paimonType);
             LogicalTypeAnnotation typeAnnotation =
                     descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-            this.parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+            int parquetScale = ((DecimalLogicalTypeAnnotation) typeAnnotation).getScale();
+            checkArgument(
+                    parquetScale == paimonType.getScale(),
+                    "Scale should match between Paimon decimal type and Parquet decimal type in file");
             this.arrayLen = descriptor.getPrimitiveType().getTypeLength();
         }

@@ -786,10 +767,33 @@ public void skipValues(int total, VectorizedValuesReader valuesReader) {
         @Override
         public void readValue(
                 int offset, WritableBytesVector values, VectorizedValuesReader valuesReader) {
-            BigInteger value = new BigInteger(valuesReader.readBinary(arrayLen).getBytesUnsafe());
-            BigDecimal decimal = new BigDecimal(value, this.parquetScale);
-            byte[] bytes = decimal.unscaledValue().toByteArray();
-            values.putByteArray(offset, bytes, 0, bytes.length);
+            Binary binary = valuesReader.readBinary(arrayLen);
+
+            int precision = paimonType.getPrecision();
+            if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+                ((HeapIntVector) values).setInt(offset, (int) heapBinaryToLong(binary));
+            } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+                ((HeapLongVector) values).setLong(offset, heapBinaryToLong(binary));
+            } else {
+                byte[] bytes = binary.getBytesUnsafe();
+                values.putByteArray(offset, bytes, 0, bytes.length);
+            }
+        }
+
+        private long heapBinaryToLong(Binary binary) {
+            ByteBuffer buffer = binary.toByteBuffer();
+            byte[] bytes = buffer.array();
+            int start = buffer.arrayOffset() + buffer.position();
+            int end = buffer.arrayOffset() + buffer.limit();
+
+            long unscaled = 0L;
+
+            for (int i = start; i < end; i++) {
+                unscaled = (unscaled << 8) | (bytes[i] & 0xff);
+            }
+
+            int bits = 8 * (end - start);
+            return (unscaled << (64 - bits)) >> (64 - bits);
+        }

         @Override
@@ -798,14 +802,16 @@ public void decodeSingleDictionaryId(
                 WritableBytesVector values,
                 WritableIntVector dictionaryIds,
                 Dictionary dictionary) {
-            BigInteger value =
-                    new BigInteger(
-                            dictionary
-                                    .decodeToBinary(dictionaryIds.getInt(offset))
-                                    .getBytesUnsafe());
-            BigDecimal decimal = new BigDecimal(value, this.parquetScale);
-            byte[] bytes = decimal.unscaledValue().toByteArray();
-            values.putByteArray(offset, bytes, 0, bytes.length);
+            Binary binary = dictionary.decodeToBinary(dictionaryIds.getInt(offset));
+            int precision = paimonType.getPrecision();
+            if (ParquetSchemaConverter.is32BitDecimal(precision)) {
+                ((HeapIntVector) values).setInt(offset, (int) heapBinaryToLong(binary));
+            } else if (ParquetSchemaConverter.is64BitDecimal(precision)) {
+                ((HeapLongVector) values).setLong(offset, heapBinaryToLong(binary));
+            } else {
+                byte[] bytes = binary.getBytesUnsafe();
+                values.putByteArray(offset, bytes, 0, bytes.length);
+            }
         }
     }
 }
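
To see why the new FixedLenByteArrayToDecimalUpdater can skip the BigInteger/BigDecimal round-trip for small precisions, here is a minimal standalone sketch of the same decoding strategy. The class name, sample values, and precision thresholds (up to 9 digits fits an int, up to 18 fits a long, the convention that is32BitDecimal/is64BitDecimal usually encode) are illustrative assumptions, not code from this PR.

import java.math.BigInteger;

public class FixedLenDecimalDecodeDemo {

    // Assumed thresholds, mirroring the usual Parquet convention behind
    // ParquetSchemaConverter.is32BitDecimal / is64BitDecimal.
    static boolean is32BitDecimal(int precision) {
        return precision <= 9;
    }

    static boolean is64BitDecimal(int precision) {
        return precision <= 18;
    }

    // Mirrors heapBinaryToLong above: big-endian two's-complement bytes -> long,
    // sign-extended from the highest bit of the stored value.
    static long binaryToLong(byte[] bytes) {
        long unscaled = 0L;
        for (byte b : bytes) {
            unscaled = (unscaled << 8) | (b & 0xff);
        }
        int bits = 8 * bytes.length;
        // Shift the value up to bit 63, then arithmetic-shift back down
        // so the sign bit propagates through the upper bits.
        return (unscaled << (64 - bits)) >> (64 - bits);
    }

    public static void main(String[] args) {
        // DECIMAL(9, 2): the unscaled value -12345 represents -123.45.
        byte[] negative = new BigInteger("-12345").toByteArray(); // {0xCF, 0xC7}

        if (is32BitDecimal(9)) {
            // What the updater does for precision <= 9: one int per value.
            System.out.println((int) binaryToLong(negative)); // -12345
        }

        // A positive value has a 0 sign bit, so the same fold extends with zeros.
        byte[] positive = new BigInteger("12345").toByteArray();
        System.out.println(binaryToLong(positive)); // 12345
    }
}

The arithmetic right shift is the key step: without it, a negative decimal whose fixed-length encoding is shorter than 8 bytes would decode as a large positive unscaled value.
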

File: VectorizedColumnReader.java

@@ -39,7 +39,6 @@
 import org.apache.parquet.column.page.PageReader;
 import org.apache.parquet.column.values.RequiresPreviousReader;
 import org.apache.parquet.column.values.ValuesReader;
-import org.apache.parquet.schema.LogicalTypeAnnotation;
 import org.apache.parquet.schema.PrimitiveType;

 import java.io.IOException;
@@ -67,9 +66,6 @@ public class VectorizedColumnReader {
     /** Vectorized RLE decoder for repetition levels. */
     private VectorizedRleValuesReader repColumn;

-    /** Factory to get type-specific vector updater. */
-    private final ParquetVectorUpdaterFactory updaterFactory;
-
     /**
      * Helper struct to track intermediate states while reading Parquet pages in the column chunk.
      */
@@ -83,7 +79,6 @@
     private final PageReader pageReader;
     private final ColumnDescriptor descriptor;
-    private final LogicalTypeAnnotation logicalTypeAnnotation;
     private final ParsedVersion writerVersion;

     public VectorizedColumnReader(
@@ -97,8 +92,6 @@ public VectorizedColumnReader(
         this.readState =
                 new ParquetReadState(
                         descriptor, isRequired, pageReadStore.getRowIndexes().orElse(null));
-        this.logicalTypeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
-        this.updaterFactory = new ParquetVectorUpdaterFactory(logicalTypeAnnotation);

         DictionaryPage dictionaryPage = pageReader.readDictionaryPage();
         if (dictionaryPage != null) {
@@ -120,7 +113,7 @@
     }

     private boolean isLazyDecodingSupported(
-            PrimitiveType.PrimitiveTypeName typeName, DataType sparkType) {
+            PrimitiveType.PrimitiveTypeName typeName, DataType paimonType) {
         return true;
     }

@@ -133,7 +126,7 @@ void readBatch(
             WritableIntVector definitionLevels)
             throws IOException {
         WritableIntVector dictionaryIds = null;
-        ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, type);
+        ParquetVectorUpdater updater = ParquetVectorUpdaterFactory.getUpdater(descriptor, type);

         if (dictionary != null) {
             // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be
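
For comparison, the shape of the call-site change in VectorizedColumnReader, as a sketch; both forms appear in the diff above, and the surrounding variables (descriptor, type) are assumed from that context:

// Before this PR: a factory instance was constructed per column reader,
// seeded with the column's logical type annotation, even though updater
// selection only ever depends on the descriptor and the Paimon type.
// ParquetVectorUpdaterFactory factory = new ParquetVectorUpdaterFactory(
//         descriptor.getPrimitiveType().getLogicalTypeAnnotation());
// ParquetVectorUpdater updater = factory.getUpdater(descriptor, type);

// After: a stateless static lookup, so the per-column field and the
// logicalTypeAnnotation state can both be deleted.
ParquetVectorUpdater updater = ParquetVectorUpdaterFactory.getUpdater(descriptor, type);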