diff --git a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java index af80391f3d..fc8506584a 100644 --- a/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java +++ b/src/main/java/com/nvidia/spark/rapids/jni/kudo/KudoTableMerger.java @@ -16,50 +16,36 @@ package com.nvidia.spark.rapids.jni.kudo; +import static com.nvidia.spark.rapids.jni.Preconditions.ensure; +import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; +import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; +import static java.lang.Math.min; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + import ai.rapids.cudf.HostMemoryBuffer; import ai.rapids.cudf.Schema; import com.nvidia.spark.rapids.jni.Arms; import com.nvidia.spark.rapids.jni.schema.Visitors; - +import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.IntBuffer; import java.util.ArrayList; import java.util.List; import java.util.OptionalInt; -import static com.nvidia.spark.rapids.jni.Preconditions.ensure; -import static com.nvidia.spark.rapids.jni.kudo.ColumnOffsetInfo.INVALID_OFFSET; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.getValidityLengthInBytes; -import static com.nvidia.spark.rapids.jni.kudo.KudoSerializer.padFor64byteAlignment; -import static java.lang.Math.min; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; - /** * This class is used to merge multiple KudoTables into a single contiguous buffer, e.g. {@link KudoHostMergeResult}, * which could be easily converted to a {@link ai.rapids.cudf.ContiguousTable}. */ class KudoTableMerger extends MultiKudoTableVisitor { - // Number of 1s in a byte - private static final int[] NUMBER_OF_ONES = new int[256]; - - static { - for (int i = 0; i < NUMBER_OF_ONES.length; i += 1) { - int count = 0; - for (int j = 0; j < 8; j += 1) { - if ((i & (1 << j)) != 0) { - count += 1; - } - } - NUMBER_OF_ONES[i] = count; - } - } - private final List columnOffsets; private final HostMemoryBuffer buffer; private final List colViewInfoList; - public KudoTableMerger(List tables, HostMemoryBuffer buffer, List columnOffsets) { + public KudoTableMerger(List tables, HostMemoryBuffer buffer, + List columnOffsets) { super(tables); requireNonNull(buffer, "buffer can't be null!"); ensure(columnOffsets != null, "column offsets cannot be null"); @@ -155,80 +141,64 @@ private static int copyValidityBuffer(HostMemoryBuffer dest, int startBit, HostMemoryBuffer src, int srcOffset, SliceInfo sliceInfo) { int nullCount = 0; - int totalRowCount = sliceInfo.getRowCount(); - int curIdx = 0; - int curSrcByteIdx = srcOffset; - int curSrcBitIdx = sliceInfo.getValidityBufferInfo().getBeginBit(); - int curDestByteIdx = startBit / 8; - int curDestBitIdx = startBit % 8; - - while (curIdx < totalRowCount) { - int leftRowCount = totalRowCount - curIdx; - int appendCount; - if (curDestBitIdx == 0) { - appendCount = min(8, leftRowCount); - } else { - appendCount = min(8 - curDestBitIdx, leftRowCount); - } - - int leftBitsInCurSrcByte = 8 - curSrcBitIdx; - byte srcByte = src.getByte(curSrcByteIdx); - if (leftBitsInCurSrcByte >= appendCount) { - // Extract appendCount bits from srcByte, starting from curSrcBitIdx - byte mask = (byte) (((1 << appendCount) - 1) & 0xFF); - srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); + int totalRowCount = toIntExact(sliceInfo.getRowCount() + sliceInfo.getValidityBufferInfo().getBeginBit()); + int curSrcIdx = sliceInfo.getValidityBufferInfo().getBeginBit(); + int curDestIdx = startBit; - nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); - // Sets the bits in destination buffer starting from curDestBitIdx to 0 - byte destByte = dest.getByte(curDestByteIdx); - destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1) & 0xFF); + while (curSrcIdx < totalRowCount) { + int leftRowCount = totalRowCount - curSrcIdx; - // Update destination byte with the bits from source byte - destByte = (byte) ((destByte | (srcByte << curDestBitIdx)) & 0xFF); - dest.setByte(curDestByteIdx, destByte); + int curDestOffset = (curDestIdx / 32) * Integer.BYTES; + int curDestBitIdx = curDestIdx % 32; - curSrcBitIdx += appendCount; - if (curSrcBitIdx == 8) { - curSrcBitIdx = 0; - curSrcByteIdx += 1; - } - } else { - // Extract appendCount bits from srcByte, starting from curSrcBitIdx - byte mask = (byte) (((1 << leftBitsInCurSrcByte) - 1) & 0xFF); - srcByte = (byte) ((srcByte >>> curSrcBitIdx) & mask); - - byte nextSrcByte = src.getByte(curSrcByteIdx + 1); - byte nextSrcByteMask = (byte) ((1 << (appendCount - leftBitsInCurSrcByte)) - 1); - nextSrcByte = (byte) (nextSrcByte & nextSrcByteMask); - nextSrcByte = (byte) (nextSrcByte << leftBitsInCurSrcByte); - srcByte = (byte) (srcByte | nextSrcByte); + int curSrcOffset = srcOffset + (curSrcIdx / 32) * Integer.BYTES; + int curSrcBitIdx = curSrcIdx % 32; - nullCount += (appendCount - NUMBER_OF_ONES[srcByte & 0xFF]); + // This is safe since we always have validity buffer 4 bytes padded + int srcInt = src.getInt(curSrcOffset); + srcInt = srcInt >>> curSrcBitIdx; - // Sets the bits in destination buffer starting from curDestBitIdx to 0 - byte destByte = dest.getByte(curDestByteIdx); - destByte = (byte) (destByte & ((1 << curDestBitIdx) - 1)); + if (dest.getLength() >= (curDestOffset + Integer.BYTES)) { + // We have enough room to get an int + int destInt = dest.getInt(curDestOffset); + destInt &= (1 << curDestBitIdx) - 1; + destInt |= srcInt << curDestBitIdx; + dest.setInt(curDestOffset, destInt); - // Update destination byte with the bits from source byte - destByte = (byte) (destByte | (srcByte << curDestBitIdx)); - dest.setByte(curDestByteIdx, destByte); - - // Update the source byte index and bit index - curSrcByteIdx += 1; - curSrcBitIdx = appendCount - leftBitsInCurSrcByte; - } + int appendCount = min(leftRowCount, 32 - Math.max(curSrcBitIdx, curDestBitIdx)); - curIdx += appendCount; - - // Update the destination byte index and bit index - curDestBitIdx += appendCount; - if (curDestBitIdx == 8) { - curDestBitIdx = 0; - curDestByteIdx += 1; + curDestIdx += appendCount; + curSrcIdx += appendCount; + if (appendCount == 32) { + nullCount += 32 - Integer.bitCount(srcInt); + } else { + int mask = (1 << appendCount) - 1; + nullCount += (appendCount - Integer.bitCount(srcInt & mask)); + } + } else { + int destBufRemBytes = toIntExact(dest.getLength() - curDestOffset); + byte[] destBytes = new byte[4]; + dest.getBytes(destBytes, 0, curDestOffset, destBufRemBytes); + int destInt = ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).getInt(); + destInt &= (1 << curDestBitIdx) - 1; + destInt |= srcInt << curDestBitIdx; + + ByteBuffer.wrap(destBytes).order(ByteOrder.LITTLE_ENDIAN).putInt(destInt); + dest.setBytes(curDestOffset, destBytes, 0, destBufRemBytes); + + int appendCount = min(leftRowCount, destBufRemBytes * 8 - Math.max(curSrcBitIdx, curDestBitIdx)); + + curDestIdx += appendCount; + curSrcIdx += appendCount; + int mask = (1 << appendCount) - 1; + nullCount += (appendCount - Integer.bitCount(srcInt & mask)); } } + int srcIdx = curSrcIdx; + ensure(curSrcIdx == totalRowCount, () -> "Did not copy all of the validity buffer, total row count: " + totalRowCount + + " current src idx: " + srcIdx); return nullCount; } @@ -325,7 +295,8 @@ static KudoHostMergeResult merge(Schema schema, MergedInfoCalc mergedInfo) { List serializedTables = mergedInfo.getTables(); return Arms.closeIfException(HostMemoryBuffer.allocate(mergedInfo.getTotalDataLen()), buffer -> { - KudoTableMerger merger = new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); + KudoTableMerger merger = + new KudoTableMerger(serializedTables, buffer, mergedInfo.getColumnOffsets()); return Visitors.visitSchema(schema, merger); }); } diff --git a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java index 210777accf..1953e221ba 100644 --- a/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java +++ b/src/test/java/com/nvidia/spark/rapids/jni/kudo/KudoSerializerTest.java @@ -39,7 +39,8 @@ public class KudoSerializerTest { public void testSerializeAndDeserializeTable() { try(Table expected = buildTestTable()) { int rowCount = toIntExact(expected.getRowCount()); - for (int sliceSize = 1; sliceSize <= rowCount; sliceSize++) { + IntStream sliceSizes = IntStream.range(1, rowCount + 1); + for (int sliceSize: sliceSizes.toArray()) { List tableSlices = new ArrayList<>(); for (int startRow = 0; startRow < rowCount; startRow += sliceSize) { tableSlices.add(new TableSlice(startRow, Math.min(sliceSize, rowCount - startRow), expected));