diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 23b0801..b012931 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -108,7 +108,7 @@ jobs: run: mvn -f java/pom.xml test - name: Build JNI library - run: cargo build -p paimon-vindex-jni --release + run: cargo build -p paimon-vindex-jni --release --features panic-boundary-test-hook - name: Test JNI native behavior run: | diff --git a/core/src/index_io_util.rs b/core/src/index_io_util.rs index e5f2796..964eb28 100644 --- a/core/src/index_io_util.rs +++ b/core/src/index_io_util.rs @@ -434,9 +434,10 @@ pub(crate) fn read_i64_le( Ok(i64::from_le_bytes(buf)) } -pub(crate) fn read_f32_vec( +pub(crate) fn read_f32_vec_checked( reader: &mut PreadCursor<'_, R>, count: usize, + section: &str, ) -> io::Result> { let byte_len = count.checked_mul(4).ok_or_else(|| { io::Error::new( @@ -446,20 +447,41 @@ pub(crate) fn read_f32_vec( })?; let mut buf = vec![0u8; byte_len]; reader.read_exact(&mut buf)?; - bytes_to_f32_vec(&buf) + bytes_to_f32_vec_checked(&buf, section) } -pub(crate) fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result> { +pub(crate) fn bytes_to_f32_vec_checked(bytes: &[u8], section: &str) -> io::Result> { if !bytes.len().is_multiple_of(4) { return Err(io::Error::new( io::ErrorKind::InvalidData, "f32 byte section is not 4-byte aligned", )); } - Ok(bytes - .chunks_exact(4) - .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) - .collect()) + let mut values = Vec::with_capacity(bytes.len() / 4); + for (offset, chunk) in bytes.chunks_exact(4).enumerate() { + let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + if !value.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "{} contains non-finite value at offset {}: {}", + section, offset, value + ), + )); + } + values.push(value); + } + Ok(values) +} + +pub(crate) fn validate_finite_f32_value(value: f32, section: &str) -> io::Result<()> { + if !value.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("{} contains non-finite value: {}", section, value), + )); + } + Ok(()) } pub(crate) fn validate_positive_i32(val: i32, field: &str) -> io::Result { diff --git a/core/src/io.rs b/core/src/io.rs index 7570c82..c9931bb 100644 --- a/core/src/io.rs +++ b/core/src/io.rs @@ -207,19 +207,42 @@ fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result }) } -fn read_f32_vec( +fn read_f32_vec_checked( reader: &mut PreadCursor<'_, R>, count: usize, + section: &str, ) -> io::Result> { - let mut buf = vec![0u8; count * 4]; + let byte_len = count.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "f32 section byte length overflow", + ) + })?; + let mut buf = vec![0u8; byte_len]; reader.read_exact(&mut buf)?; let floats: Vec = buf .chunks_exact(4) .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) .collect(); + validate_finite_f32_values(&floats, section)?; Ok(floats) } +fn validate_finite_f32_values(values: &[f32], section: &str) -> io::Result<()> { + for (offset, &value) in values.iter().enumerate() { + if !value.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "{} contains non-finite value at offset {}: {}", + section, offset, value + ), + )); + } + } + Ok(()) +} + /// Write a complete IVF-PQ index with delta-varint ID encoding. pub fn write_index(index: &IVFPQIndex, out: &mut dyn SeekWrite) -> io::Result<()> { let d = index.d; @@ -594,7 +617,7 @@ impl IVFPQIndexReader { let mut cursor = PreadCursor::new(&mut self.reader, self.centroids_offset); if self.has_opq { cursor.seek(HEADER_SIZE as u64); - let rotation = read_f32_vec(&mut cursor, rotation_count)?; + let rotation = read_f32_vec_checked(&mut cursor, rotation_count, "IVFPQ OPQ matrix")?; self.opq = Some(OPQMatrix { d, m, @@ -607,9 +630,11 @@ impl IVFPQIndexReader { }); } - self.quantizer_centroids = read_f32_vec(&mut cursor, centroids_count)?; + self.quantizer_centroids = + read_f32_vec_checked(&mut cursor, centroids_count, "IVFPQ coarse centroids")?; - let pq_centroids = read_f32_vec(&mut cursor, pq_centroids_count)?; + let pq_centroids = + read_f32_vec_checked(&mut cursor, pq_centroids_count, "IVFPQ PQ codebooks")?; self.pq = ProductQuantizer { d, m, @@ -1313,6 +1338,48 @@ mod tests { ); } + #[test] + fn test_reader_rejects_non_finite_opq_matrix() { + let mut buf = build_small_ivfpq_index_bytes(true); + write_f32_at(&mut buf, HEADER_SIZE, f32::NAN); + + let mut cursor = Cursor::new(&buf); + let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap(); + let err = reader.ensure_loaded().unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVFPQ OPQ matrix")); + } + + #[test] + fn test_reader_rejects_non_finite_coarse_centroid() { + let mut buf = build_small_ivfpq_index_bytes(false); + write_f32_at(&mut buf, HEADER_SIZE, f32::INFINITY); + + let mut cursor = Cursor::new(&buf); + let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap(); + let err = reader.ensure_loaded().unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVFPQ coarse centroids")); + } + + #[test] + fn test_reader_rejects_non_finite_pq_codebook() { + let mut buf = build_small_ivfpq_index_bytes(false); + let d = read_i32_at(&buf, 8) as usize; + let nlist = read_i32_at(&buf, 12) as usize; + let pq_codebook_offset = HEADER_SIZE + nlist * d * 4; + write_f32_at(&mut buf, pq_codebook_offset, f32::NAN); + + let mut cursor = Cursor::new(&buf); + let mut reader = IVFPQIndexReader::open(&mut cursor).unwrap(); + let err = reader.ensure_loaded().unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVFPQ PQ codebooks")); + } + #[test] fn test_negative_header_d_returns_error() { let mut buf = Vec::new(); @@ -1498,4 +1565,30 @@ mod tests { let result = IVFPQIndexReader::open(&mut cursor); assert!(result.is_err(), "d != m*dsub should return error"); } + + fn build_small_ivfpq_index_bytes(use_opq: bool) -> Vec { + let d = 4; + let nlist = 1; + let m = 1; + let n = 300; + let mut rng = rand::rngs::StdRng::seed_from_u64(7); + let data: Vec = (0..n * d).map(|_| rng.gen::()).collect(); + let ids: Vec = (0..n as i64).collect(); + + let mut index = IVFPQIndex::new(d, nlist, m, MetricType::L2, use_opq); + index.train(&data, n); + index.add(&data, &ids, n); + + let mut buf = Vec::new(); + write_index(&index, &mut PosWriter::new(&mut buf)).unwrap(); + buf + } + + fn read_i32_at(buf: &[u8], offset: usize) -> i32 { + i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap()) + } + + fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) { + buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes()); + } } diff --git a/core/src/ivfflat_io.rs b/core/src/ivfflat_io.rs index 56d860f..dc356b8 100644 --- a/core/src/ivfflat_io.rs +++ b/core/src/ivfflat_io.rs @@ -230,8 +230,11 @@ impl IVFFlatIndexReader { } let mut cursor = PreadCursor::new(&mut self.reader, IVFFLAT_HEADER_SIZE as u64); - self.quantizer_centroids = - read_f32_vec(&mut cursor, checked_section_size(self.nlist, self.d)?)?; + self.quantizer_centroids = read_f32_vec_checked( + &mut cursor, + checked_section_size(self.nlist, self.d)?, + "IVF-FLAT coarse centroids", + )?; self.list_offsets = vec![0; self.nlist]; self.list_counts = vec![0; self.nlist]; self.list_id_bytes_lens = vec![0; self.nlist]; @@ -294,7 +297,10 @@ impl IVFFlatIndexReader { )); } let ids = decode_delta_varint_ids(base_id, &payload[12..12 + id_bytes_len], count)?; - let vectors = bytes_to_f32_vec(&payload[12 + id_bytes_len..])?; + let vectors = bytes_to_f32_vec_checked( + &payload[12 + id_bytes_len..], + &format!("IVF-FLAT list {} vectors", list_id), + )?; Ok((ids, vectors)) } else { Err(io::Error::new( @@ -542,7 +548,7 @@ impl ReaderTopKHeap { .data .iter() .enumerate() - .max_by(|(_, a), (_, b)| a.0.partial_cmp(&b.0).unwrap()) + .max_by(|(_, a), (_, b)| a.0.total_cmp(&b.0)) { if dist < self.data[worst_idx].0 { self.data[worst_idx] = (dist, id); @@ -551,7 +557,7 @@ impl ReaderTopKHeap { } fn into_sorted(mut self) -> Vec<(f32, i64)> { - self.data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + self.data.sort_by(|a, b| a.0.total_cmp(&b.0)); self.data } } @@ -727,26 +733,44 @@ fn checked_list_bytes(count: usize, bytes_per_entry: usize) -> io::Result }) } -fn read_f32_vec( +fn read_f32_vec_checked( reader: &mut PreadCursor<'_, R>, count: usize, + section: &str, ) -> io::Result> { - let mut buf = vec![0u8; count * 4]; + let byte_len = count.checked_mul(4).ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidData, + "f32 section byte length overflow", + ) + })?; + let mut buf = vec![0u8; byte_len]; reader.read_exact(&mut buf)?; - bytes_to_f32_vec(&buf) + bytes_to_f32_vec_checked(&buf, section) } -fn bytes_to_f32_vec(bytes: &[u8]) -> io::Result> { +fn bytes_to_f32_vec_checked(bytes: &[u8], section: &str) -> io::Result> { if !bytes.len().is_multiple_of(4) { return Err(io::Error::new( io::ErrorKind::InvalidData, "f32 byte section is not 4-byte aligned", )); } - Ok(bytes - .chunks_exact(4) - .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]])) - .collect()) + let mut values = Vec::with_capacity(bytes.len() / 4); + for (offset, chunk) in bytes.chunks_exact(4).enumerate() { + let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); + if !value.is_finite() { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!( + "{} contains non-finite value at offset {}: {}", + section, offset, value + ), + )); + } + values.push(value); + } + Ok(values) } fn encode_varint(mut val: u64, buf: &mut Vec) { @@ -1039,6 +1063,31 @@ mod tests { assert!(reader.search(&[0.0, 0.0], 1, 0).is_err()); } + #[test] + fn test_ivfflat_reader_rejects_non_finite_centroid() { + let mut buf = build_small_ivfflat_index_bytes(); + write_f32_at(&mut buf, IVFFLAT_HEADER_SIZE, f32::NAN); + + let mut reader = IVFFlatIndexReader::open(Cursor::new(buf)).unwrap(); + let err = reader.search(&[0.0, 0.0], 1, 1).unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVF-FLAT coarse centroids")); + } + + #[test] + fn test_ivfflat_reader_rejects_non_finite_list_vector() { + let mut buf = build_small_ivfflat_index_bytes(); + let vector_offset = first_ivfflat_vector_offset(&buf); + write_f32_at(&mut buf, vector_offset, f32::NAN); + + let mut reader = IVFFlatIndexReader::open(Cursor::new(buf)).unwrap(); + let err = reader.search(&[0.0, 0.0], 1, 1).unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVF-FLAT list 0 vectors")); + } + #[test] fn test_ivfflat_writer_validates_shape_before_writing() { let mut index = IVFFlatIndex::new(2, 1, MetricType::L2); @@ -1120,4 +1169,39 @@ mod tests { }; assert!(err.to_string().contains("reserved bytes must be zero")); } + + fn build_small_ivfflat_index_bytes() -> Vec { + let d = 2; + let data = vec![0.0, 0.0, 1.0, 0.0]; + let ids = vec![10, 11]; + + let mut index = IVFFlatIndex::new(d, 1, MetricType::L2); + index.train(&data, 2); + index.add(&data, &ids, 2); + + let mut buf = Vec::new(); + write_ivfflat_index(&index, &mut PosWriter::new(&mut buf)).unwrap(); + buf + } + + fn first_ivfflat_vector_offset(buf: &[u8]) -> usize { + let d = read_i32_at(buf, 8) as usize; + let nlist = read_i32_at(buf, 12) as usize; + let offset_table = IVFFLAT_HEADER_SIZE + d * nlist * 4; + let list_offset = read_i64_at(buf, offset_table) as usize; + let id_bytes_len = read_i32_at(buf, list_offset + 8) as usize; + list_offset + 12 + id_bytes_len + } + + fn read_i32_at(buf: &[u8], offset: usize) -> i32 { + i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap()) + } + + fn read_i64_at(buf: &[u8], offset: usize) -> i64 { + i64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap()) + } + + fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) { + buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes()); + } } diff --git a/core/src/ivfhnswflat_io.rs b/core/src/ivfhnswflat_io.rs index f87e413..a4e049b 100644 --- a/core/src/ivfhnswflat_io.rs +++ b/core/src/ivfhnswflat_io.rs @@ -19,11 +19,11 @@ use crate::distance::{fvec_distance, preprocess_vectors, MetricType}; use crate::hnsw::{HnswBuildParams, HnswGraph}; use crate::hnsw_search::{search_hnsw_lists, HnswSearchList}; use crate::index_io_util::{ - bytes_to_f32_vec, checked_list_bytes, checked_list_offset, checked_section_size, + bytes_to_f32_vec_checked, checked_list_bytes, checked_list_offset, checked_section_size, decode_delta_varint_ids, decode_graph, decode_roaring_filter, encode_delta_varint_ids, - encode_graph, read_f32_vec, read_i32_le, read_i64_le, read_u32_le, u64_to_i64, usize_to_i32, - usize_to_i64, validate_positive_i32, validate_reserved_zero, validate_search_inputs, - write_f32_slice, write_i32_le, write_i64_le, write_u32_le, + encode_graph, read_f32_vec_checked, read_i32_le, read_i64_le, read_u32_le, u64_to_i64, + usize_to_i32, usize_to_i64, validate_positive_i32, validate_reserved_zero, + validate_search_inputs, write_f32_slice, write_i32_le, write_i64_le, write_u32_le, }; use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite}; use crate::ivfhnswflat::IVFHNSWFlatIndex; @@ -237,8 +237,11 @@ impl IVFHNSWFlatIndexReader { } let mut cursor = PreadCursor::new(&mut self.reader, IVF_HNSW_FLAT_HEADER_SIZE as u64); - self.quantizer_centroids = - read_f32_vec(&mut cursor, checked_section_size(self.nlist, self.d)?)?; + self.quantizer_centroids = read_f32_vec_checked( + &mut cursor, + checked_section_size(self.nlist, self.d)?, + "IVF-HNSW-FLAT coarse centroids", + )?; self.list_offsets = vec![0; self.nlist]; self.list_counts = vec![0; self.nlist]; self.list_graph_bytes_lens = vec![0; self.nlist]; @@ -532,7 +535,10 @@ impl IVFHNSWFlatIndexReader { )); } let ids = decode_delta_varint_ids(base_id, &payload[base_header_len..ids_end], meta.count)?; - let vectors = bytes_to_f32_vec(&payload[ids_end..vectors_end])?; + let vectors = bytes_to_f32_vec_checked( + &payload[ids_end..vectors_end], + &format!("IVF-HNSW-FLAT list {} vectors", meta.list_id), + )?; let graph = decode_graph( &payload[vectors_end..], vectors.clone(), @@ -1552,6 +1558,31 @@ mod tests { assert!(err.to_string().contains("missing HNSW graph")); } + #[test] + fn test_ivfhnswflat_reader_rejects_non_finite_centroid() { + let mut buf = build_small_ivfhnswflat_index_bytes(); + write_f32_at(&mut buf, IVF_HNSW_FLAT_HEADER_SIZE, f32::NAN); + + let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap(); + let err = reader.search(&[0.0, 0.0], 1, 1, 4).unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVF-HNSW-FLAT coarse centroids")); + } + + #[test] + fn test_ivfhnswflat_reader_rejects_non_finite_list_vector() { + let mut buf = build_small_ivfhnswflat_index_bytes(); + let vector_offset = first_ivfhnswflat_vector_offset(&buf); + write_f32_at(&mut buf, vector_offset, f32::INFINITY); + + let mut reader = IVFHNSWFlatIndexReader::open(Cursor::new(buf)).unwrap(); + let err = reader.search(&[0.0, 0.0], 1, 1, 4).unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVF-HNSW-FLAT list 0 vectors")); + } + #[test] fn test_ivfhnswflat_graph_delta_varint_reduces_graph_bytes() { let d = 2; @@ -1743,4 +1774,41 @@ mod tests { buf.extend_from_slice(&1u32.to_le_bytes()); buf.extend_from_slice(&1u32.to_le_bytes()); } + + fn build_small_ivfhnswflat_index_bytes() -> Vec { + let d = 2; + let nlist = 1; + let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0]; + let ids = vec![10, 11, 12, 13]; + + let mut index = IVFHNSWFlatIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default()); + index.train(&data, 4); + index.add(&data, &ids, 4); + index.build_graphs().unwrap(); + + let mut buf = Vec::new(); + write_ivfhnswflat_index(&index, &mut PosWriter::new(&mut buf)).unwrap(); + buf + } + + fn first_ivfhnswflat_vector_offset(buf: &[u8]) -> usize { + let d = read_i32_at(buf, 8) as usize; + let nlist = read_i32_at(buf, 12) as usize; + let offset_table = IVF_HNSW_FLAT_HEADER_SIZE + d * nlist * 4; + let list_offset = read_i64_at(buf, offset_table) as usize; + let id_bytes_len = read_i32_at(buf, list_offset + 8) as usize; + list_offset + 12 + id_bytes_len + } + + fn read_i32_at(buf: &[u8], offset: usize) -> i32 { + i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap()) + } + + fn read_i64_at(buf: &[u8], offset: usize) -> i64 { + i64::from_le_bytes(buf[offset..offset + 8].try_into().unwrap()) + } + + fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) { + buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes()); + } } diff --git a/core/src/ivfhnswsq_io.rs b/core/src/ivfhnswsq_io.rs index d5ea2dc..ffa93c6 100644 --- a/core/src/ivfhnswsq_io.rs +++ b/core/src/ivfhnswsq_io.rs @@ -20,10 +20,10 @@ use crate::hnsw::{HnswBuildParams, HnswGraph}; use crate::hnsw_search::{search_hnsw_lists, HnswSearchList}; use crate::index_io_util::{ checked_list_bytes, checked_list_offset, checked_section_size, decode_delta_varint_ids, - decode_graph, decode_roaring_filter, encode_delta_varint_ids, encode_graph, read_f32_vec, - read_i32_le, read_i64_le, read_u32_le, u64_to_i64, usize_to_i32, usize_to_i64, - validate_positive_i32, validate_reserved_zero, validate_search_inputs, write_f32_slice, - write_i32_le, write_i64_le, write_u32_le, + decode_graph, decode_roaring_filter, encode_delta_varint_ids, encode_graph, + read_f32_vec_checked, read_i32_le, read_i64_le, read_u32_le, u64_to_i64, usize_to_i32, + usize_to_i64, validate_finite_f32_value, validate_positive_i32, validate_reserved_zero, + validate_search_inputs, write_f32_slice, write_i32_le, write_i64_le, write_u32_le, }; use crate::io::{PreadCursor, ReadRequest, SeekRead, SeekWrite}; use crate::ivfhnswsq::IVFHNSWSQIndex; @@ -206,6 +206,8 @@ impl IVFHNSWSQIndexReader { .sanitized(); let sq_min_summary = read_f32_le(&mut cursor)?; let sq_max_summary = read_f32_le(&mut cursor)?; + validate_finite_f32_value(sq_min_summary, "IVF-HNSW-SQ min bound summary")?; + validate_finite_f32_value(sq_max_summary, "IVF-HNSW-SQ max bound summary")?; let flags = read_u32_le(&mut cursor)?; let mut reserved = [0u8; 12]; cursor.read_exact(&mut reserved)?; @@ -224,9 +226,9 @@ impl IVFHNSWSQIndexReader { )); } - let mins = read_f32_vec(&mut cursor, d)?; - let maxs = read_f32_vec(&mut cursor, d)?; - validate_sq_bounds(d, &mins, &maxs)?; + let mins = read_f32_vec_checked(&mut cursor, d, "IVF-HNSW-SQ global min bounds")?; + let maxs = read_f32_vec_checked(&mut cursor, d, "IVF-HNSW-SQ global max bounds")?; + validate_sq_bounds(d, &mins, &maxs, io::ErrorKind::InvalidData)?; let (sq_min, sq_max) = sq_global_bounds(&mins, &maxs); if sq_min.to_bits() != sq_min_summary.to_bits() || sq_max.to_bits() != sq_max_summary.to_bits() @@ -238,10 +240,18 @@ impl IVFHNSWSQIndexReader { } let sq = ScalarQuantizer::with_dimension_bounds(d, mins, maxs); let mut list_sqs = Vec::with_capacity(nlist); - for _ in 0..nlist { - let mins = read_f32_vec(&mut cursor, d)?; - let maxs = read_f32_vec(&mut cursor, d)?; - validate_sq_bounds(d, &mins, &maxs)?; + for list_id in 0..nlist { + let mins = read_f32_vec_checked( + &mut cursor, + d, + &format!("IVF-HNSW-SQ list {} min bounds", list_id), + )?; + let maxs = read_f32_vec_checked( + &mut cursor, + d, + &format!("IVF-HNSW-SQ list {} max bounds", list_id), + )?; + validate_sq_bounds(d, &mins, &maxs, io::ErrorKind::InvalidData)?; list_sqs.push(ScalarQuantizer::with_dimension_bounds(d, mins, maxs)); } @@ -271,8 +281,11 @@ impl IVFHNSWSQIndexReader { let quantizer_centroids_offset = IVF_HNSW_SQ_HEADER_SIZE as u64 + (self.d as u64) * 8 * (self.nlist as u64 + 1); let mut cursor = PreadCursor::new(&mut self.reader, quantizer_centroids_offset); - self.quantizer_centroids = - read_f32_vec(&mut cursor, checked_section_size(self.nlist, self.d)?)?; + self.quantizer_centroids = read_f32_vec_checked( + &mut cursor, + checked_section_size(self.nlist, self.d)?, + "IVF-HNSW-SQ coarse centroids", + )?; self.list_offsets = vec![0; self.nlist]; self.list_counts = vec![0; self.nlist]; self.list_graph_bytes_lens = vec![0; self.nlist]; @@ -920,7 +933,12 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> { "SQ dimension does not match index dimension", )); } - validate_sq_bounds(index.d, &index.sq.mins, &index.sq.maxs)?; + validate_sq_bounds( + index.d, + &index.sq.mins, + &index.sq.maxs, + io::ErrorKind::InvalidInput, + )?; if index.list_sqs.len() != index.nlist { return Err(io::Error::new( io::ErrorKind::InvalidInput, @@ -937,7 +955,7 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> { ), )); } - validate_sq_bounds(index.d, &sq.mins, &sq.maxs)?; + validate_sq_bounds(index.d, &sq.mins, &sq.maxs, io::ErrorKind::InvalidInput)?; } let centroid_len = checked_section_size(index.nlist, index.d)?; if index.quantizer_centroids.len() != centroid_len { @@ -1004,10 +1022,15 @@ fn validate_index_shape(index: &IVFHNSWSQIndex) -> io::Result<()> { Ok(()) } -fn validate_sq_bounds(d: usize, mins: &[f32], maxs: &[f32]) -> io::Result<()> { +fn validate_sq_bounds( + d: usize, + mins: &[f32], + maxs: &[f32], + error_kind: io::ErrorKind, +) -> io::Result<()> { if mins.len() != d || maxs.len() != d { return Err(io::Error::new( - io::ErrorKind::InvalidInput, + error_kind, format!( "SQ bounds length mismatch: d={}, mins={}, maxs={}", d, @@ -1019,7 +1042,7 @@ fn validate_sq_bounds(d: usize, mins: &[f32], maxs: &[f32]) -> io::Result<()> { for (dim, (&min, &max)) in mins.iter().zip(maxs.iter()).enumerate() { if !min.is_finite() || !max.is_finite() { return Err(io::Error::new( - io::ErrorKind::InvalidInput, + error_kind, format!("SQ bounds at dimension {} must be finite", dim), )); } @@ -1357,6 +1380,65 @@ mod tests { assert!(err.to_string().contains("SQ bounds summary")); } + #[test] + fn test_ivfhnswsq_reader_rejects_non_finite_sq_bounds_summary() { + let mut buf = build_small_ivfhnswsq_index_bytes(); + write_f32_at(&mut buf, 40, f32::NAN); + + let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) { + Ok(_) => panic!("non-finite SQ bounds summary should be rejected"), + Err(err) => err, + }; + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("bound summary")); + } + + #[test] + fn test_ivfhnswsq_reader_rejects_non_finite_global_sq_bound() { + let mut buf = build_small_ivfhnswsq_index_bytes(); + write_f32_at(&mut buf, IVF_HNSW_SQ_HEADER_SIZE, f32::INFINITY); + + let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) { + Ok(_) => panic!("non-finite global SQ bound should be rejected"), + Err(err) => err, + }; + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("global min bounds")); + } + + #[test] + fn test_ivfhnswsq_reader_rejects_non_finite_list_sq_bound() { + let mut buf = build_small_ivfhnswsq_index_bytes(); + let d = read_i32_at(&buf, 8) as usize; + let list_min_bounds_offset = IVF_HNSW_SQ_HEADER_SIZE + d * 8; + write_f32_at(&mut buf, list_min_bounds_offset, f32::NAN); + + let err = match IVFHNSWSQIndexReader::open(Cursor::new(buf)) { + Ok(_) => panic!("non-finite list SQ bound should be rejected"), + Err(err) => err, + }; + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("list 0 min bounds")); + } + + #[test] + fn test_ivfhnswsq_reader_rejects_non_finite_centroid() { + let mut buf = build_small_ivfhnswsq_index_bytes(); + let d = read_i32_at(&buf, 8) as usize; + let nlist = read_i32_at(&buf, 12) as usize; + let centroid_offset = IVF_HNSW_SQ_HEADER_SIZE + d * 8 * (nlist + 1); + write_f32_at(&mut buf, centroid_offset, f32::NAN); + + let mut reader = IVFHNSWSQIndexReader::open(Cursor::new(buf)).unwrap(); + let err = reader.search(&[0.0, 0.0], 1, 1, 4).unwrap_err(); + + assert_eq!(err.kind(), io::ErrorKind::InvalidData); + assert!(err.to_string().contains("IVF-HNSW-SQ coarse centroids")); + } + #[test] fn test_ivfhnswsq_reader_search_with_roaring_filter() { let d = 2; @@ -1616,4 +1698,28 @@ mod tests { Ok(()) } } + + fn build_small_ivfhnswsq_index_bytes() -> Vec { + let d = 2; + let nlist = 1; + let data = vec![0.0, 0.0, 1.0, 0.0, 2.0, 0.0, 3.0, 0.0]; + let ids = vec![10, 11, 12, 13]; + + let mut index = IVFHNSWSQIndex::new(d, nlist, MetricType::L2, HnswBuildParams::default()); + index.train(&data, 4); + index.add(&data, &ids, 4); + index.build_graphs().unwrap(); + + let mut buf = Vec::new(); + write_ivfhnswsq_index(&index, &mut PosWriter::new(&mut buf)).unwrap(); + buf + } + + fn read_i32_at(buf: &[u8], offset: usize) -> i32 { + i32::from_le_bytes(buf[offset..offset + 4].try_into().unwrap()) + } + + fn write_f32_at(buf: &mut [u8], offset: usize, value: f32) { + buf[offset..offset + 4].copy_from_slice(&value.to_le_bytes()); + } } diff --git a/core/src/ivfpq.rs b/core/src/ivfpq.rs index 9971998..ade51dd 100644 --- a/core/src/ivfpq.rs +++ b/core/src/ivfpq.rs @@ -1411,7 +1411,7 @@ impl TopKHeap { } fn into_sorted(mut self) -> Vec<(f32, i64)> { - self.data.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); + self.data.sort_by(|a, b| a.0.total_cmp(&b.0)); self.data } } diff --git a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java index 739defe..c898e73 100644 --- a/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java +++ b/java/src/test/java/org/apache/paimon/index/vector/VectorIndexNativePanicBoundaryTest.java @@ -32,6 +32,7 @@ public static void main(String[] args) { testVoidEntrypointPanicBecomesRuntimeException(); testObjectEntrypointPanicBecomesRuntimeException(); + testMalformedIvfFlatPayloadReturnsRuntimeException(); VectorIndexWriter survivor = new VectorIndexWriter(ivfFlatOptions()); survivor.close(); @@ -40,7 +41,7 @@ public static void main(String[] args) { private static void testVoidEntrypointPanicBecomesRuntimeException() { final VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions()); try { - assertThrows(RuntimeException.class, new ThrowingRunnable() { + assertThrowsPanic(RuntimeException.class, new ThrowingRunnable() { @Override public void run() { writer.addVectors(new long[] {1L}, new float[] {1.0f}, 1); @@ -52,6 +53,15 @@ public void run() { } private static void testObjectEntrypointPanicBecomesRuntimeException() { + assertThrowsPanic(RuntimeException.class, new ThrowingRunnable() { + @Override + public void run() { + objectEntrypointPanicForTesting(); + } + }); + } + + private static void testMalformedIvfFlatPayloadReturnsRuntimeException() { ByteArrayPositionOutputStream output = new ByteArrayPositionOutputStream(); VectorIndexWriter writer = new VectorIndexWriter(ivfFlatOptions()); try { @@ -68,12 +78,14 @@ private static void testObjectEntrypointPanicBecomesRuntimeException() { new VectorIndexReader(new ByteArraySeekableInputStream(indexBytes)); try { assertEquals(1, reader.dimension()); - assertThrows(RuntimeException.class, new ThrowingRunnable() { - @Override - public void run() { - reader.search(new float[] {0.0f}, 2, 1); - } - }); + assertThrowsMalformedIndex( + "IVF-FLAT list 0 vectors contains non-finite value at offset 0: NaN", + new ThrowingRunnable() { + @Override + public void run() { + reader.search(new float[] {0.0f}, 2, 1); + } + }); assertEquals(2L, reader.totalVectors()); } finally { reader.close(); @@ -101,7 +113,8 @@ private static void assertEquals(long expected, long actual) { } } - private static void assertThrows(Class expected, ThrowingRunnable runnable) { + private static void assertThrowsPanic( + Class expected, ThrowingRunnable runnable) { try { runnable.run(); } catch (Throwable t) { @@ -112,11 +125,30 @@ private static void assertThrows(Class expected, ThrowingRu } return; } - throw new AssertionError("expected " + expected.getName() + " but got " + t.getClass().getName(), t); + throw new AssertionError( + "expected " + expected.getName() + " but got " + t.getClass().getName(), t); } throw new AssertionError("expected " + expected.getName()); } + private static void assertThrowsMalformedIndex(String expectedMessage, ThrowingRunnable runnable) { + try { + runnable.run(); + } catch (RuntimeException e) { + String message = e.getMessage(); + if (message == null || !message.contains(expectedMessage)) { + throw new AssertionError("unexpected exception message: " + message, e); + } + if (message.contains("Rust panic in JNI call")) { + throw new AssertionError("malformed index should not cross the panic boundary", e); + } + return; + } catch (Throwable t) { + throw new AssertionError("expected RuntimeException but got " + t.getClass().getName(), t); + } + throw new AssertionError("expected RuntimeException"); + } + private static void corruptFirstIvfFlatVector(byte[] indexBytes, float value) { int dimension = readIntLe(indexBytes, 8); int nlist = readIntLe(indexBytes, 12); @@ -154,6 +186,8 @@ private interface ThrowingRunnable { void run() throws Throwable; } + private static native VectorSearchResult objectEntrypointPanicForTesting(); + public static final class ByteArrayPositionOutputStream { private final ByteArrayOutputStream out = new ByteArrayOutputStream(); diff --git a/jni/Cargo.toml b/jni/Cargo.toml index e785f52..0a4f46c 100644 --- a/jni/Cargo.toml +++ b/jni/Cargo.toml @@ -25,6 +25,9 @@ license = "Apache-2.0" name = "paimon_vindex_jni" crate-type = ["cdylib"] +[features] +panic-boundary-test-hook = [] + [dependencies] paimon-vindex-core = { path = "../core" } jni = "0.21" diff --git a/jni/src/lib.rs b/jni/src/lib.rs index dd4aa8a..179abcd 100644 --- a/jni/src/lib.rs +++ b/jni/src/lib.rs @@ -690,3 +690,15 @@ pub extern "system" fn Java_org_apache_paimon_index_vector_VectorIndexNative_fre } }) } + +#[cfg(feature = "panic-boundary-test-hook")] +#[no_mangle] +pub extern "system" fn Java_org_apache_paimon_index_vector_VectorIndexNativePanicBoundaryTest_objectEntrypointPanicForTesting( + env: JNIEnv, + _class: JClass, +) -> jobject { + // Feature-gated test hook used by the Java panic-boundary test for object returns. + jni_call(env, |_env| -> jobject { + panic!("intentional object-return panic for JNI boundary test") + }) +}