Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ jobs:
run: mvn -f java/pom.xml test

- name: Build JNI library
run: cargo build -p paimon-vindex-jni --release
run: cargo build -p paimon-vindex-jni --release --features panic-boundary-test-hook

- name: Test JNI native behavior
run: |
Expand Down
36 changes: 29 additions & 7 deletions core/src/index_io_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,10 @@ pub(crate) fn read_i64_le<R: SeekRead + ?Sized>(
Ok(i64::from_le_bytes(buf))
}

pub(crate) fn read_f32_vec<R: SeekRead + ?Sized>(
pub(crate) fn read_f32_vec_checked<R: SeekRead + ?Sized>(
reader: &mut PreadCursor<'_, R>,
count: usize,
section: &str,
) -> io::Result<Vec<f32>> {
let byte_len = count.checked_mul(4).ok_or_else(|| {
io::Error::new(
Expand All @@ -446,20 +447,41 @@ pub(crate) fn read_f32_vec<R: SeekRead + ?Sized>(
})?;
let mut buf = vec![0u8; byte_len];
reader.read_exact(&mut buf)?;
bytes_to_f32_vec(&buf)
bytes_to_f32_vec_checked(&buf, section)
}

pub(crate) fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
pub(crate) fn bytes_to_f32_vec_checked(bytes: &[u8], section: &str) -> io::Result<Vec<f32>> {
if !bytes.len().is_multiple_of(4) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"f32 byte section is not 4-byte aligned",
));
}
Ok(bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect())
let mut values = Vec::with_capacity(bytes.len() / 4);
for (offset, chunk) in bytes.chunks_exact(4).enumerate() {
let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
if !value.is_finite() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"{} contains non-finite value at offset {}: {}",
section, offset, value
),
));
}
values.push(value);
}
Ok(values)
}

pub(crate) fn validate_finite_f32_value(value: f32, section: &str) -> io::Result<()> {
if !value.is_finite() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("{} contains non-finite value: {}", section, value),
));
}
Ok(())
}

pub(crate) fn validate_positive_i32(val: i32, field: &str) -> io::Result<i32> {
Expand Down
103 changes: 98 additions & 5 deletions core/src/io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,19 +207,42 @@ fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize>
})
}

fn read_f32_vec<R: SeekRead + ?Sized>(
fn read_f32_vec_checked<R: SeekRead + ?Sized>(
reader: &mut PreadCursor<'_, R>,
count: usize,
section: &str,
) -> io::Result<Vec<f32>> {
let mut buf = vec![0u8; count * 4];
let byte_len = count.checked_mul(4).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"f32 section byte length overflow",
)
})?;
let mut buf = vec![0u8; byte_len];
reader.read_exact(&mut buf)?;
let floats: Vec<f32> = buf
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect();
validate_finite_f32_values(&floats, section)?;
Ok(floats)
}

fn validate_finite_f32_values(values: &[f32], section: &str) -> io::Result<()> {
for (offset, &value) in values.iter().enumerate() {
if !value.is_finite() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"{} contains non-finite value at offset {}: {}",
section, offset, value
),
));
}
}
Ok(())
}

/// Write a complete IVF-PQ index with delta-varint ID encoding.
pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()> {
let d = index.d;
Expand Down Expand Up @@ -594,7 +617,7 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
let mut cursor = PreadCursor::new(&mut self.reader, self.centroids_offset);
if self.has_opq {
cursor.seek(HEADER_SIZE as u64);
let rotation = read_f32_vec(&mut cursor, rotation_count)?;
let rotation = read_f32_vec_checked(&mut cursor, rotation_count, "IVFPQ OPQ matrix")?;
self.opq = Some(OPQMatrix {
d,
m,
Expand All @@ -607,9 +630,11 @@ impl<R: SeekRead> IVFPQIndexReader<R> {
});
}

self.quantizer_centroids = read_f32_vec(&mut cursor, centroids_count)?;
self.quantizer_centroids =
read_f32_vec_checked(&mut cursor, centroids_count, "IVFPQ coarse centroids")?;

let pq_centroids = read_f32_vec(&mut cursor, pq_centroids_count)?;
let pq_centroids =
read_f32_vec_checked(&mut cursor, pq_centroids_count, "IVFPQ PQ codebooks")?;
self.pq = ProductQuantizer {
d,
m,
Expand Down Expand Up @@ -1313,6 +1338,48 @@ mod tests {
);
}

#[test]
fn test_reader_rejects_non_finite_opq_matrix() {
let mut buf = build_small_ivfpq_index_bytes(true);
write_f32_at(&mut buf, HEADER_SIZE, f32::NAN);

let mut cursor = Cursor::new(&buf);
let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap();
let err = reader.ensure_loaded().unwrap_err();

assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("IVFPQ OPQ matrix"));
}

#[test]
fn test_reader_rejects_non_finite_coarse_centroid() {
let mut buf = build_small_ivfpq_index_bytes(false);
write_f32_at(&mut buf, HEADER_SIZE, f32::INFINITY);

let mut cursor = Cursor::new(&buf);
let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap();
let err = reader.ensure_loaded().unwrap_err();

assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("IVFPQ coarse centroids"));
}

#[test]
fn test_reader_rejects_non_finite_pq_codebook() {
let mut buf = build_small_ivfpq_index_bytes(false);
let d = read_i32_at(&buf, 8) as usize;
let nlist = read_i32_at(&buf, 12) as usize;
let pq_codebook_offset = HEADER_SIZE + nlist * d * 4;
write_f32_at(&mut buf, pq_codebook_offset, f32::NAN);

let mut cursor = Cursor::new(&buf);
let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap();
let err = reader.ensure_loaded().unwrap_err();

assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("IVFPQ PQ codebooks"));
}

#[test]
fn test_negative_header_d_returns_error() {
let mut buf = Vec::new();
Expand Down Expand Up @@ -1498,4 +1565,30 @@ mod tests {
let result = IVFPQIndexReader::open(&mut cursor);
assert!(result.is_err(), "d != m*dsub should return error");
}

fn build_small_ivfpq_index_bytes(use_opq: bool) -> Vec<u8> {
let d = 4;
let nlist = 1;
let m = 1;
let n = 300;
let mut rng = rand::rngs::StdRng::seed_from_u64(7);
let data: Vec<f32> = (0..n * d).map(|_| rng.gen::<f32>()).collect();
let ids: Vec<i64> = (0..n as i64).collect();

let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, use_opq);
index.train(&data, n);
index.add(&data, &ids, n);

let mut buf = Vec::new();
write_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
buf
}

fn read_i32_at(buf: &[u8], offset: usize) -> i32 {
i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap())
}

fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) {
buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
}
}
110 changes: 97 additions & 13 deletions core/src/ivfflat_io.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,11 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
}

let mut cursor = PreadCursor::new(&mut self.reader, IVFFLAT_HEADER_SIZE as u64);
self.quantizer_centroids =
read_f32_vec(&mut cursor, checked_section_size(self.nlist, self.d)?)?;
self.quantizer_centroids = read_f32_vec_checked(
&mut cursor,
checked_section_size(self.nlist, self.d)?,
"IVF-FLAT coarse centroids",
)?;
self.list_offsets = vec![0; self.nlist];
self.list_counts = vec![0; self.nlist];
self.list_id_bytes_lens = vec![0; self.nlist];
Expand Down Expand Up @@ -294,7 +297,10 @@ impl<R: SeekRead> IVFFlatIndexReader<R> {
));
}
let ids = decode_delta_varint_ids(base_id, &payload[12..12 + id_bytes_len], count)?;
let vectors = bytes_to_f32_vec(&payload[12 + id_bytes_len..])?;
let vectors = bytes_to_f32_vec_checked(
&payload[12 + id_bytes_len..],
&format!("IVF-FLAT list {} vectors", list_id),
)?;
Ok((ids, vectors))
} else {
Err(io::Error::new(
Expand Down Expand Up @@ -542,7 +548,7 @@ impl ReaderTopKHeap {
.data
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap())
.max_by(|(_, a), (_, b)| a.0.total_cmp(&b.0))
{
if dist < self.data[worst_idx].0 {
self.data[worst_idx] = (dist, id);
Expand All @@ -551,7 +557,7 @@ impl ReaderTopKHeap {
}

fn into_sorted(mut self) -> Vec<(f32, i64)> {
self.data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
self.data.sort_by(|a, b| a.0.total_cmp(&b.0));
self.data
}
}
Expand Down Expand Up @@ -727,26 +733,44 @@ fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result<usize>
})
}

fn read_f32_vec<R: SeekRead + ?Sized>(
fn read_f32_vec_checked<R: SeekRead + ?Sized>(
reader: &mut PreadCursor<'_, R>,
count: usize,
section: &str,
) -> io::Result<Vec<f32>> {
let mut buf = vec![0u8; count * 4];
let byte_len = count.checked_mul(4).ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidData,
"f32 section byte length overflow",
)
})?;
let mut buf = vec![0u8; byte_len];
reader.read_exact(&mut buf)?;
bytes_to_f32_vec(&buf)
bytes_to_f32_vec_checked(&buf, section)
}

fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result<Vec<f32>> {
fn bytes_to_f32_vec_checked(bytes: &[u8], section: &str) -> io::Result<Vec<f32>> {
if !bytes.len().is_multiple_of(4) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"f32 byte section is not 4-byte aligned",
));
}
Ok(bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect())
let mut values = Vec::with_capacity(bytes.len() / 4);
for (offset, chunk) in bytes.chunks_exact(4).enumerate() {
let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
if !value.is_finite() {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!(
"{} contains non-finite value at offset {}: {}",
section, offset, value
),
));
}
values.push(value);
}
Ok(values)
}

fn encode_varint(mut val: u64, buf: &mut Vec<u8>) {
Expand Down Expand Up @@ -1039,6 +1063,31 @@ mod tests {
assert!(reader.search(&[0.0, 0.0], 1, 0).is_err());
}

#[test]
fn test_ivfflat_reader_rejects_non_finite_centroid() {
let mut buf = build_small_ivfflat_index_bytes();
write_f32_at(&mut buf, IVFFLAT_HEADER_SIZE, f32::NAN);

let mut reader = IVFFlatIndexReader::open(Cursor::new(buf)).unwrap();
let err = reader.search(&[0.0, 0.0], 1, 1).unwrap_err();

assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("IVF-FLAT coarse centroids"));
}

#[test]
fn test_ivfflat_reader_rejects_non_finite_list_vector() {
let mut buf = build_small_ivfflat_index_bytes();
let vector_offset = first_ivfflat_vector_offset(&buf);
write_f32_at(&mut buf, vector_offset, f32::NAN);

let mut reader = IVFFlatIndexReader::open(Cursor::new(buf)).unwrap();
let err = reader.search(&[0.0, 0.0], 1, 1).unwrap_err();

assert_eq!(err.kind(), io::ErrorKind::InvalidData);
assert!(err.to_string().contains("IVF-FLAT list 0 vectors"));
}

#[test]
fn test_ivfflat_writer_validates_shape_before_writing() {
let mut index = IVFFlatIndex::new(2, 1, MetricType::L2);
Expand Down Expand Up @@ -1120,4 +1169,39 @@ mod tests {
};
assert!(err.to_string().contains("reserved bytes must be zero"));
}

fn build_small_ivfflat_index_bytes() -> Vec<u8> {
let d = 2;
let data = vec![0.0, 0.0, 1.0, 0.0];
let ids = vec![10, 11];

let mut index = IVFFlatIndex::new(d, 1, MetricType::L2);
index.train(&data, 2);
index.add(&data, &ids, 2);

let mut buf = Vec::new();
write_ivfflat_index(&index, &mut PosWriter::new(&mut buf)).unwrap();
buf
}

fn first_ivfflat_vector_offset(buf: &[u8]) -> usize {
let d = read_i32_at(buf, 8) as usize;
let nlist = read_i32_at(buf, 12) as usize;
let offset_table = IVFFLAT_HEADER_SIZE + d * nlist * 4;
let list_offset = read_i64_at(buf, offset_table) as usize;
let id_bytes_len = read_i32_at(buf, list_offset + 8) as usize;
list_offset + 12 + id_bytes_len
}

fn read_i32_at(buf: &[u8], offset: usize) -> i32 {
i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap())
}

fn read_i64_at(buf: &[u8], offset: usize) -> i64 {
i64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap())
}

fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) {
buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes());
}
}
Loading
Loading