diff --git a/native-engine/datafusion-ext-plans/src/flink/serde/json_deserializer.rs b/native-engine/datafusion-ext-plans/src/flink/serde/json_deserializer.rs index cb9a88adf..81fdd49b2 100644 --- a/native-engine/datafusion-ext-plans/src/flink/serde/json_deserializer.rs +++ b/native-engine/datafusion-ext-plans/src/flink/serde/json_deserializer.rs @@ -1034,4 +1034,58 @@ mod tests { // msg2 missing "name" field, should get default empty string from ensure_size assert_eq!(name_col.value(1), ""); } + + /// Pin the omitted-vs-null distinction for a Boolean field that this PR + /// introduced by switching the shared `ensure_output_array_builders_size` + /// boolean default from `append_null()` to `append_value(false)`. + /// + /// After the change, the PB path correctly emits `false` for a proto3 field + /// absent from a message, but the shared default also reaches this JSON + /// path: an *omitted* boolean now yields a non-null `false`, while an + /// *explicit* JSON `null` still yields a null (the JSON handler appends + /// null for explicit nulls). This test locks that behavior so a future + /// change to either side is a conscious decision, not an accident. + #[test] + fn test_parse_json_boolean_omitted_vs_explicit_null() { + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("active", DataType::Boolean, true), + ])); + + let nested_mapping = HashMap::new(); + let mut deserializer = JsonDeserializer::new(schema.clone(), &nested_mapping) + .expect("Failed to create JsonDeserializer"); + + // row0: explicit true; row1: explicit null; row2: field omitted entirely. + let msg0 = br#"{"active": true}"#; + let msg1 = br#"{"active": null}"#; + let msg2 = br#"{}"#; + + let messages = create_binary_array(vec![msg0.as_ref(), msg1.as_ref(), msg2.as_ref()]); + let partitions = create_partition_array(vec![0, 0, 0]); + let offsets = create_offset_array(vec![100, 101, 102]); + let timestamps = create_timestamp_array(vec![1000, 1001, 1002]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to parse messages"); + + assert_eq!(batch.num_rows(), 3); + let active_col = batch + .column(3) + .as_any() + .downcast_ref::() + .expect("active column"); + + // row0: explicit true. + assert!(active_col.value(0)); + assert!(!active_col.is_null(0)); + // row1: explicit null → null (JSON handler's own behavior). + assert!(active_col.is_null(1)); + // row2: omitted → non-null false (shared ensure_size default). + assert!(!active_col.is_null(2)); + assert!(!active_col.value(2)); + } } diff --git a/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs b/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs index aaf82f47b..4b98c6a1a 100644 --- a/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs +++ b/native-engine/datafusion-ext-plans/src/flink/serde/pb_deserializer.rs @@ -15,7 +15,6 @@ use std::{ any::Any, - cell::UnsafeCell, collections::{HashMap, HashSet}, io::Cursor, sync::Arc, @@ -44,14 +43,76 @@ use crate::flink::serde::{ type ValueHandler = Box, u32, WireType) -> Result<()> + Send>; type ValueHandlerMap = hashbrown::HashMap; +/// Adaptive dispatch table for protobuf field handlers keyed by tag. +/// +/// O2 optimization: when the tag space is dense (max_tag is small relative to +/// the number of fields), use a `Vec>` for O(1) array indexing, +/// avoiding the HashMap hashing/probing overhead on the hot path. When tags +/// are sparse (e.g. extensions or large field numbers), fall back to a +/// `HashMap` to avoid wasting memory. +/// +/// The threshold `max_tag <= 64 && max_tag <= 4 * field_count` keeps the Vec +/// path activated for the overwhelmingly common case where producers use +/// small contiguous tags (typically 1..N). +enum ValueHandlers { + Vec(Vec>), + Map(ValueHandlerMap), +} + +impl ValueHandlers { + fn from_map(map: ValueHandlerMap) -> Self { + let max_tag = map.keys().copied().max().unwrap_or(0); + let field_count = map.len(); + // Heuristic: dense enough and within 64-tag bitmap range. We cap at + // 64 so it composes nicely with O3's seen_tags bitmap, but the cap + // is independent — the fallback HashMap remains correct. + if field_count > 0 && max_tag <= 64 && (max_tag as usize) <= field_count.saturating_mul(4) { + let mut vec: Vec> = (0..=max_tag).map(|_| None).collect(); + for (tag, handler) in map.into_iter() { + vec[tag as usize] = Some(handler); + } + ValueHandlers::Vec(vec) + } else { + ValueHandlers::Map(map) + } + } + + #[inline(always)] + fn get(&self, tag: u32) -> Option<&ValueHandler> { + match self { + ValueHandlers::Vec(v) => v.get(tag as usize).and_then(|h| h.as_ref()), + ValueHandlers::Map(m) => m.get(&tag), + } + } + + fn len(&self) -> usize { + match self { + ValueHandlers::Vec(v) => v.iter().filter(|h| h.is_some()).count(), + ValueHandlers::Map(m) => m.len(), + } + } +} + pub struct PbDeserializer { output_schema: SchemaRef, output_schema_without_meta: SchemaRef, pb_schema: SchemaRef, output_array_builders: Vec, ensure_size: Box, - value_handlers: ValueHandlerMap, + value_handlers: ValueHandlers, + /// O(n)/O(1)-read cache of `value_handlers.len()`, computed once in + /// `try_new`. The `Vec` variant of `len()` is an O(max_tag) scan, so + /// recomputing it per batch (as `total_handlers`) is wasteful — it is + /// constant for the deserializer's lifetime. + handler_count: u32, msg_mapping: Vec>, + /// C1 fix: whether any top-level pb_schema column is a List or Map. The O3 + /// ensure_size skip is only sound for scalar/struct columns, which finalize + /// their own per-row slot. List/Map builders rely on ensure_size to append + /// their per-row offset/null entries (the per-value handlers only push to + /// the child values builder, never the parent). When this is true, + /// ensure_size must run every row regardless of how many tags were seen. + top_level_has_list_or_map: bool, } impl FlinkDeserializer for PbDeserializer { @@ -62,48 +123,83 @@ impl FlinkDeserializer for PbDeserializer { kafka_offset: &Int64Array, kafka_timestamp: &Int64Array, ) -> datafusion::common::Result { - let mut msg_cursors = messages - .iter() - .map(|v| { - let s = v.expect("message bytes must not be null"); - Cursor::new(s) - }) - .collect::>(); - for (row_idx, msg_cursor) in msg_cursors.iter_mut().enumerate() { + // O5: inline cursor creation (avoid Vec> preallocation) + // O7/C3 fix: replace `expect("message bytes must not be null")` with `?` + // so that JNI callers don't crash the JVM via process abort. + // O3: track which tags appear via a u64 bitmap (tag 0..63). When all + // schema tags were observed in a row, scalar/struct builders are + // already aligned and ensure_size can be skipped for that row. + // C1 fix: the O3 skip is UNSOUND for top-level List/Map columns. Their + // per-row offset/null slot is finalized only inside ensure_size — + // the per-value handlers append to the child values builder, never + // to the parent SharedListArrayBuilder/SharedMapArrayBuilder. So + // when the schema has any top-level List/Map, ensure_size must run + // every row (see `ensure_size_every_row` below). + // NOTE on builder row-alignment invariant: every row, all builders must + // be padded to length `row_idx + 1`. We therefore must NOT defer + // ensure_size to after the loop — that would let later rows write + // values into the wrong positions. + // NOTE: we cannot use a simple counter because protobuf repeated + // fields (non-packed) emit multiple tag-value pairs for the same tag, + // which would over-count. The bitmap correctly records unique tags. + let total_handlers = self.handler_count; + let ensure_size_every_row = self.top_level_has_list_or_map; + for (row_idx, opt_bytes) in messages.iter().enumerate() { + let bytes = opt_bytes.ok_or_else(|| { + DataFusionError::Execution("message bytes must not be null".to_string()) + })?; + let mut msg_cursor = Cursor::new(bytes); + let mut seen_tags: u64 = 0; while msg_cursor.has_remaining() { - let (tag, wired_type) = prost::encoding::decode_key(msg_cursor).map_err(|e| { - DataFusionError::Execution(format!("Failed to parse protobuf key: {e}")) - })?; - if let Some(value_handler) = self.value_handlers.get_mut(&tag) { - value_handler(msg_cursor, tag, wired_type)?; + let (tag, wired_type) = + prost::encoding::decode_key(&mut msg_cursor).map_err(|e| { + DataFusionError::Execution(format!("Failed to parse protobuf key: {e}")) + })?; + if let Some(value_handler) = self.value_handlers.get(tag) { + value_handler(&mut msg_cursor, tag, wired_type)?; + // Tags >= 64 fall through to ensure_size (always safe). + if tag < 64 { + seen_tags |= 1u64 << tag; + } + } else { + // O1/C1 fix: skip unknown tags so the cursor stays in sync. + skip_pb_value(&mut msg_cursor, tag, wired_type)?; } } - let ensure_size = &mut self.ensure_size; - ensure_size(row_idx + 1); + if ensure_size_every_row || seen_tags.count_ones() < total_handlers { + (self.ensure_size)(row_idx + 1); + } } - let root_struct = StructArray::from({ - RecordBatch::try_new_with_options( - self.pb_schema.clone(), - self.output_array_builders - .iter() - .map(|builder| builder.get_dyn_mut().finish()) - .collect(), - &RecordBatchOptions::new().with_row_count(Some(messages.len())), - )? - }); + // O4 optimization: avoid building an intermediate `RecordBatch` and + // converting it to `StructArray`. We finish builders directly into a + // `Vec` and walk the per-output `msg_mapping` path to + // extract the target column from any nested StructArray. + let pb_top_arrays: Vec = self + .output_array_builders + .iter() + .map(|builder| builder.get_dyn_mut().finish()) + .collect(); let mut output_arrays: Vec = Vec::new(); output_arrays.push(Arc::new(kafka_partition.clone())); output_arrays.push(Arc::new(kafka_offset.clone())); output_arrays.push(Arc::new(kafka_timestamp.clone())); for (field_idx, field) in self.output_schema_without_meta.fields().iter().enumerate() { - let array_ref: ArrayRef = get_output_array(&root_struct, &self.msg_mapping[field_idx])?; + let mapping = &self.msg_mapping[field_idx]; + let array_ref: ArrayRef = get_output_array_from_top(&pb_top_arrays, mapping)?; if array_ref.null_count() == array_ref.len() { output_arrays.push(new_null_array(field.data_type(), array_ref.len())); } else { + // O7/C3 fix: replace `.expect("Failed to cast array")` with + // error propagation so JNI callers don't get a process abort. output_arrays.push( datafusion_ext_commons::arrow::cast::cast(&array_ref, field.data_type()) - .expect("Failed to cast array"), + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to cast array for field {}: {e}", + field.name() + )) + })?, ); } } @@ -164,13 +260,14 @@ impl PbDeserializer { .collect::(), )); // Schema inferred from the PB descriptor. + // O9: pass nested_msg_mapping by reference to avoid a HashMap clone + // on every initialization (and on every recursive nested call). let pb_schema = transfer_output_schema_to_pb_schema( message_descriptor.clone(), &output_schema_without_meta, - nested_msg_mapping.clone(), + nested_msg_mapping, &skip_fields, - ) - .expect("Failed to transfer output schema to pb schema"); + )?; let tag_to_output_mapping = create_tag_to_output_mapping(message_descriptor.clone(), &pb_schema); @@ -179,7 +276,7 @@ impl PbDeserializer { create_output_array_builders(&pb_schema, message_descriptor.clone())?; let ensure_size = ensure_output_array_builders_size(&output_array_builders)?; - let value_handlers = message_descriptor + let value_handlers_map = message_descriptor .fields() .map(|field| { Ok(( @@ -194,6 +291,11 @@ impl PbDeserializer { )) }) .collect::>>()?; + // O2 optimization: switch to Vec> when tags are dense. + let value_handlers = ValueHandlers::from_map(value_handlers_map); + // Precompute the handler count once (the Vec variant's `len()` is an + // O(max_tag) scan); the per-batch hot path reads `handler_count` instead. + let handler_count = value_handlers.len() as u32; // precompute message mappings let msg_mapping = output_schema_without_meta @@ -232,6 +334,13 @@ impl PbDeserializer { }) .collect::>>()?; + // C1 fix: detect top-level List/Map columns that require ensure_size + // every row (their per-row slots are finalized only inside ensure_size). + let top_level_has_list_or_map = pb_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::List(_) | DataType::Map(_, _))); + Ok(Self { output_schema, output_schema_without_meta, @@ -239,7 +348,9 @@ impl PbDeserializer { output_array_builders, ensure_size, value_handlers, + handler_count, msg_mapping, + top_level_has_list_or_map, }) } } @@ -247,7 +358,7 @@ impl PbDeserializer { fn transfer_output_schema_to_pb_schema( message_descriptor: MessageDescriptor, output_schema: &SchemaRef, - nested_msg_mapping: HashMap, + nested_msg_mapping: &HashMap, skip_fields: &[String], ) -> Result { let mut pb_schema_fields: Vec = vec![]; @@ -298,10 +409,12 @@ fn transfer_output_schema_to_pb_schema( let sub_pb_schema = transfer_output_schema_to_pb_schema( sub_message_desc.clone(), &Arc::new(Schema::new(sub_fields)), - sub_pb_nested_msg_mapping.clone(), + // O9 optimization: pass by reference instead of + // cloning the entire HashMap on every recursive + // call. + &sub_pb_nested_msg_mapping, skip_fields, - ) - .expect("transfer_output_schema_to_pb_schema failed"); + )?; pb_schema_fields.push(Field::new( msg_field_name, DataType::Struct(sub_pb_schema.fields.clone()), @@ -845,7 +958,7 @@ pub(crate) fn ensure_output_array_builders_size( .map(|(builder_type, builders)| { Ok(match builder_type { BuilderType::Boolean => { - impl_for_builders!(BooleanBuilder, builders, |b| b.append_null()) + impl_for_builders!(BooleanBuilder, builders, |b| b.append_value(false)) } BuilderType::Int32 => { impl_for_builders!(Int32Builder, builders, |b| b.append_value(0)) @@ -902,6 +1015,24 @@ fn get_output_array(struct_array: &StructArray, nested_field_name: &[usize]) -> Ok(column.clone()) } +/// O4 optimization helper: extract a (possibly nested) column from the list +/// of top-level finished arrays without first building a wrapping +/// `StructArray` for the root level. The first index selects from the top +/// `Vec`; remaining indices descend into nested `StructArray`s. +fn get_output_array_from_top( + top_arrays: &[ArrayRef], + nested_field_indices: &[usize], +) -> Result { + let column = top_arrays[nested_field_indices[0]].clone(); + if nested_field_indices.len() > 1 { + return get_output_array( + downcast_any!(&column, StructArray)?, + &nested_field_indices[1..], + ); + } + Ok(column) +} + fn create_value_handler( message_descriptor: &MessageDescriptor, tag_id: u32, @@ -949,7 +1080,12 @@ fn create_value_handler( return df_execution_err!("buffer underflow"); } let value = &cursor.get_mut()[cursor.position() as usize..][..len as usize]; - $handle_fn(value); + // O7/C3 fix: propagate handle_fn errors instead of + // discarding them, so an invalid UTF-8 string (or any other + // handle_fn failure) surfaces to the caller rather than + // aborting the JVM via JNI. + let res: Result<()> = $handle_fn(value); + res?; cursor.advance(len as usize); Ok(()) }) @@ -958,26 +1094,29 @@ fn create_value_handler( macro_rules! impl_for_repeated_builder { ($encoding_tyname:ident, $handle_fn:expr) => {{ + // O6 optimization: hoist the buffer out of the per-call body so + // its capacity is reused across calls instead of alloc/dealloc + // per repeated field decode. We use `RefCell` because the outer + // ValueHandler is `Box` (immutable closure); the buffer + // is borrowed mut for the duration of decoding/handle_fn, and + // each handler is single-threaded. + let value_buf: std::cell::RefCell> = + std::cell::RefCell::new(Default::default()); Box::new(move |cursor, tag, wire_type| { let merge_method = prost::encoding::$encoding_tyname::merge_repeated; - let value = UnsafeCell::new(Default::default()); - merge_method( - wire_type, - unsafe { &mut *value.get() }, - cursor, - DecodeContext::default(), - ) - .map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode repeated {:?} [{}] and {} field: {}", - wire_type, - tag, - stringify!($encoding_tyname), - e - )) - })?; - $handle_fn(unsafe { &*value.get() }); - unsafe { &mut *value.get() }.clear(); + let mut value = value_buf.borrow_mut(); + value.clear(); + merge_method(wire_type, &mut *value, cursor, DecodeContext::default()) + .map_err(|e| { + DataFusionError::Execution(format!( + "Failed to decode repeated {:?} [{}] and {} field: {}", + wire_type, + tag, + stringify!($encoding_tyname), + e + )) + })?; + $handle_fn(&*value); Ok(()) }) }}; @@ -994,7 +1133,12 @@ fn create_value_handler( return df_execution_err!("buffer underflow"); } - $handle_fn(&cursor.get_mut()[cursor.position() as usize..][..len as usize]); + // O7/C3 fix: handle_fn is now expected to return Result<()> so + // sub-handler errors propagate up through `?` instead of using + // .expect()` which would abort the JVM via JNI. + let res: Result<()> = + $handle_fn(&cursor.get_mut()[cursor.position() as usize..][..len as usize]); + res?; cursor.advance(len as usize); Ok(()) }) @@ -1089,14 +1233,30 @@ fn create_value_handler( .values() .get_mut::()?; return Ok(impl_for_bytes_builder!(string, |value: &[u8]| { - let s = unsafe { str::from_utf8_unchecked(value) }; + // SAFETY: validate on the release path. protobuf 3 says + // `string` fields are UTF-8, but Kafka payloads may come + // from non-conformant producers; an unchecked decode + // would construct an invalid `&str` and violate Arrow's + // UTF-8 invariant (UB). Surface the error instead. + let s = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!( + "protobuf string field contains invalid UTF-8: {e}" + )) + })?; array_builder.get_mut().append_value(s); + Ok(()) })); } else { let array_builder = output_array_builder.get_mut::()?; return Ok(impl_for_bytes_builder!(string, |value: &[u8]| { - let s = unsafe { str::from_utf8_unchecked(value) }; + // SAFETY: see above — validate UTF-8 on the release path. + let s = std::str::from_utf8(value).map_err(|e| { + DataFusionError::Execution(format!( + "protobuf string field contains invalid UTF-8: {e}" + )) + })?; array_builder.get_mut().append_value(s); + Ok(()) })); } } @@ -1205,7 +1365,11 @@ fn create_value_handler( } } Kind::Enum(enum_descriptor) => { - let mut enum_string_mapping = HashMap::new(); + // Build the enum number→name map once for this field and move it + // into the value-handler closure. It is per-field (not shared + // across handlers), so a plain owned HashMap is enough — no Arc + // refcount overhead. + let mut enum_string_mapping: HashMap = HashMap::new(); for enum_value_descriptor in enum_descriptor.values() { enum_string_mapping.insert( enum_value_descriptor.number(), @@ -1283,30 +1447,23 @@ fn create_value_handler( let struct_builder = output_array_builder .get_mut::() .expect("SharedStructArrayBuilder is null"); + let sub_ensure_size = std::cell::RefCell::new( + ensure_output_array_builders_size(&sub_output_array_builders)?, + ); - return Ok(impl_for_message_builder!(|buf: &[u8]| { + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { if buf.is_empty() { + // C2 fix: pad the struct's child builders before + // advancing the struct null buffer, so children + // length stays aligned with the struct length. + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(false); } else { - let mut sub_cursor = Cursor::new(buf); - while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) - { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } - } - } + decode_sub_message(buf, &sub_value_handlers)?; + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(true); } + Ok(()) })); } else if let DataType::List(struct_fields) = output_field.data_type() { if let DataType::Struct(sub_fields) = struct_fields.data_type() { @@ -1343,7 +1500,10 @@ fn create_value_handler( ); } } - return Ok(impl_for_message_builder!(|buf: &[u8]| { + let sub_ensure_size = std::cell::RefCell::new( + ensure_output_array_builders_size(&sub_output_array_builders)?, + ); + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { let struct_builder = output_array_builder .get_mut::() .expect("SharedListArrayBuilder is null") @@ -1352,28 +1512,19 @@ fn create_value_handler( .get_mut::() .expect("SharedStructArrayBuilder is null"); if buf.is_empty() { + // C2 fix: pad child builders before append(false) + // to keep struct children aligned with the + // struct length (symmetric with the non-empty + // branch below). + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(false); } else { // 解析嵌套的 message - let mut sub_cursor = Cursor::new(buf); - while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) - { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } - } - } + decode_sub_message(buf, &sub_value_handlers)?; + (sub_ensure_size.borrow_mut())(struct_builder.get_mut().len() + 1); struct_builder.get_mut().append(true); } + Ok(()) })); } else { return Err(DataFusionError::Execution(format!( @@ -1419,28 +1570,13 @@ fn create_value_handler( .get_mut::() .expect("SharedMapArrayBuilder is null"); - return Ok(impl_for_message_builder!(|buf: &[u8]| { + return Ok(impl_for_message_builder!(|buf: &[u8]| -> Result<()> { if buf.is_empty() { map_builder.get_mut().append(true); } else { - let mut sub_cursor = Cursor::new(buf); - while sub_cursor.has_remaining() { - if let Ok((sub_tag, sub_wire_type)) = - prost::encoding::decode_key(&mut sub_cursor) - { - if let Some(sub_value_handler) = - sub_value_handlers.get(&sub_tag) - { - (*sub_value_handler)( - &mut sub_cursor, - sub_tag, - sub_wire_type, - ) - .expect("Failed to process sub field"); - } - } - } + decode_sub_message(buf, &sub_value_handlers)?; } + Ok(()) })); } else { return Err(DataFusionError::Execution(format!( @@ -1486,46 +1622,7 @@ fn create_value_handler( } Ok(Box::new(|cursor, tag, wire_type| { - let mut skip_value = move || { - match wire_type { - WireType::Varint => { - prost::encoding::decode_varint(cursor) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - } - WireType::ThirtyTwoBit => { - if cursor.remaining() < 4 { - return df_execution_err!("buffer underflow"); - } - cursor.advance(4); - } - WireType::SixtyFourBit => { - if cursor.remaining() < 8 { - return df_execution_err!("buffer underflow"); - } - cursor.advance(8); - } - WireType::LengthDelimited => { - let len = prost::encoding::decode_varint(cursor) - .map_err(|e| DataFusionError::Execution(e.to_string()))? - as usize; - if cursor.remaining() < len { - return df_execution_err!("buffer underflow"); - } - cursor.advance(len); - } - _ => { - UnknownField::decode_value(tag, wire_type, cursor, DecodeContext::default()) - .map_err(|e| { - DataFusionError::Execution(format!( - "Failed to decode unknown value: {e}" - )) - })?; - } - } - Ok(()) - }; - - skip_value() + skip_pb_value(cursor, tag, wire_type) .map_err(|e| DataFusionError::Execution(format!("Failed to decode unknown value: {e}"))) })) } @@ -1537,6 +1634,67 @@ fn get_content_after_last_dot(s: &str) -> &str { } } +/// Skip an unknown protobuf field's value, advancing the cursor past it so the +/// outer parsing loop stays in sync. Used by both the top-level main loop and +/// the fallback handler returned by `create_value_handler` when the field has +/// no associated builder. Without this, an unknown tag (e.g., a new field +/// added by an upstream producer) would leave the cursor positioned at the +/// value bytes and the next `decode_key` would interpret garbage. +fn skip_pb_value(cursor: &mut Cursor<&[u8]>, tag: u32, wire_type: WireType) -> Result<()> { + match wire_type { + WireType::Varint => { + prost::encoding::decode_varint(cursor) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + } + WireType::ThirtyTwoBit => { + if cursor.remaining() < 4 { + return df_execution_err!("buffer underflow"); + } + cursor.advance(4); + } + WireType::SixtyFourBit => { + if cursor.remaining() < 8 { + return df_execution_err!("buffer underflow"); + } + cursor.advance(8); + } + WireType::LengthDelimited => { + let len = prost::encoding::decode_varint(cursor) + .map_err(|e| DataFusionError::Execution(e.to_string()))? + as usize; + if cursor.remaining() < len { + return df_execution_err!("buffer underflow"); + } + cursor.advance(len); + } + _ => { + UnknownField::decode_value(tag, wire_type, cursor, DecodeContext::default()).map_err( + |e| DataFusionError::Execution(format!("Failed to decode unknown value: {e}")), + )?; + } + } + Ok(()) +} + +/// Decode a length-delimited sub-message body, dispatching each known tag to +/// its value handler and skipping unknown tags (C1) with error propagation +/// (O7). Shared by the Struct / List-of-Struct / Map sub-message handlers, +/// which were previously three near-verbatim copies of this loop. +fn decode_sub_message(buf: &[u8], handlers: &ValueHandlerMap) -> Result<()> { + let mut sub_cursor = Cursor::new(buf); + while sub_cursor.has_remaining() { + let (sub_tag, sub_wire_type) = prost::encoding::decode_key(&mut sub_cursor) + .map_err(|e| DataFusionError::Execution(format!("Failed to decode sub key: {e}")))?; + if let Some(sub_value_handler) = handlers.get(&sub_tag) { + (*sub_value_handler)(&mut sub_cursor, sub_tag, sub_wire_type)?; + } else { + // C1 fix: skip unknown sub-tags so the cursor stays in sync. + skip_pb_value(&mut sub_cursor, sub_tag, sub_wire_type)?; + } + } + Ok(()) +} + pub(crate) fn adaptive_append_children( builder: &SharedArrayBuilder, ) -> Option> { @@ -1569,7 +1727,7 @@ mod tests { use prost::Message as ProstMessage; use prost_reflect::prost_types::{DescriptorProto, FileDescriptorProto, FileDescriptorSet}; use prost_types::{ - FieldDescriptorProto, + FieldDescriptorProto, MessageOptions, field_descriptor_proto::{Label, Type}, }; @@ -1789,6 +1947,207 @@ mod tests { buf } + fn create_repeated_test_descriptor() -> Vec { + let field_descriptors = vec![ + FieldDescriptorProto { + name: Some("id".to_string()), + number: Some(1), + label: Some(Label::Optional as i32), + r#type: Some(Type::Int32 as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("id".to_string()), + options: None, + proto3_optional: None, + }, + FieldDescriptorProto { + name: Some("scores".to_string()), + number: Some(2), + label: Some(Label::Repeated as i32), + r#type: Some(Type::Int32 as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("scores".to_string()), + options: None, + proto3_optional: None, + }, + ]; + + let message_descriptor = DescriptorProto { + name: Some("RepeatedMessage".to_string()), + field: field_descriptors, + extension: vec![], + nested_type: vec![], + enum_type: vec![], + extension_range: vec![], + oneof_decl: vec![], + options: None, + reserved_range: vec![], + reserved_name: vec![], + }; + + let file_descriptor = FileDescriptorProto { + name: Some("repeated_test.proto".to_string()), + package: Some("test".to_string()), + dependency: vec![], + public_dependency: vec![], + weak_dependency: vec![], + message_type: vec![message_descriptor], + enum_type: vec![], + service: vec![], + extension: vec![], + options: None, + source_code_info: None, + syntax: Some("proto3".to_string()), + }; + + let descriptor_set = FileDescriptorSet { + file: vec![file_descriptor], + }; + + let mut buf = Vec::new(); + descriptor_set + .encode(&mut buf) + .expect("Failed to encode FileDescriptorSet"); + buf + } + + /// Descriptor for a message with a single top-level `map` + /// field `kv` (number 1). The map entry is a nested message `KvEntry` + /// with the `[map_entry = true]` option, which is what `prost_reflect` + /// keys on to report `field.is_map()` == true (and thus produce an + /// Arrow `DataType::Map` rather than a `DataType::List` of struct). + fn create_map_test_descriptor() -> Vec { + let kv_entry_fields = vec![ + // string key = 1; + FieldDescriptorProto { + name: Some("key".to_string()), + number: Some(1), + label: Some(Label::Optional as i32), + r#type: Some(Type::String as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("key".to_string()), + options: None, + proto3_optional: None, + }, + // int32 value = 2; + FieldDescriptorProto { + name: Some("value".to_string()), + number: Some(2), + label: Some(Label::Optional as i32), + r#type: Some(Type::Int32 as i32), + type_name: None, + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("value".to_string()), + options: None, + proto3_optional: None, + }, + ]; + + let kv_entry_descriptor = DescriptorProto { + name: Some("KvEntry".to_string()), + field: kv_entry_fields, + extension: vec![], + nested_type: vec![], + enum_type: vec![], + extension_range: vec![], + oneof_decl: vec![], + // The marker that makes prost_reflect treat `kv` as a map. + options: Some(MessageOptions { + map_entry: Some(true), + ..Default::default() + }), + reserved_range: vec![], + reserved_name: vec![], + }; + + // map kv = 1; — desugars to `repeated KvEntry kv = 1`. + let kv_field = FieldDescriptorProto { + name: Some("kv".to_string()), + number: Some(1), + label: Some(Label::Repeated as i32), + r#type: Some(Type::Message as i32), + type_name: Some(".test.MapMessage.KvEntry".to_string()), + extendee: None, + default_value: None, + oneof_index: None, + json_name: Some("kv".to_string()), + options: None, + proto3_optional: None, + }; + + let map_message_descriptor = DescriptorProto { + name: Some("MapMessage".to_string()), + field: vec![kv_field], + extension: vec![], + nested_type: vec![kv_entry_descriptor], + enum_type: vec![], + extension_range: vec![], + oneof_decl: vec![], + options: None, + reserved_range: vec![], + reserved_name: vec![], + }; + + let file_descriptor = FileDescriptorProto { + name: Some("map_test.proto".to_string()), + package: Some("test".to_string()), + dependency: vec![], + public_dependency: vec![], + weak_dependency: vec![], + message_type: vec![map_message_descriptor], + enum_type: vec![], + service: vec![], + extension: vec![], + options: None, + source_code_info: None, + syntax: Some("proto3".to_string()), + }; + + let descriptor_set = FileDescriptorSet { + file: vec![file_descriptor], + }; + + let mut buf = Vec::new(); + descriptor_set + .encode(&mut buf) + .expect("Failed to encode FileDescriptorSet"); + buf + } + + /// Encode a `MapMessage` with the given `kv` entries. An empty `entries` + /// slice yields a message with the `kv` field entirely absent (no tag-1 + /// pairs), exercising the absent-map path. + fn create_map_test_message(entries: &[(&str, i32)]) -> Vec { + use prost::encoding::*; + + let mut buf = Vec::new(); + for (k, v) in entries { + let mut entry = Vec::new(); + // key (entry field 1, string) + encode_key(1, WireType::LengthDelimited, &mut entry); + encode_varint(k.len() as u64, &mut entry); + entry.extend_from_slice(k.as_bytes()); + // value (entry field 2, int32) + encode_key(2, WireType::Varint, &mut entry); + encode_varint(*v as u64, &mut entry); + // map field 1 (length-delimited entry) + encode_key(1, WireType::LengthDelimited, &mut buf); + encode_varint(entry.len() as u64, &mut buf); + buf.extend_from_slice(&entry); + } + buf + } + fn create_test_message(id: i32, name: &str, score: f64, active: bool) -> Vec { use prost::encoding::*; @@ -1843,6 +2202,39 @@ mod tests { buf } + fn create_repeated_test_message(id: i32, scores: &[i32]) -> Vec { + use prost::encoding::*; + + let mut buf = Vec::new(); + + encode_key(1, WireType::Varint, &mut buf); + encode_varint(id as u64, &mut buf); + + for score in scores { + encode_key(2, WireType::Varint, &mut buf); + encode_varint(*score as u64, &mut buf); + } + + buf + } + + fn create_empty_nested_test_message(name: &str) -> Vec { + use prost::encoding::*; + + let mut buf = Vec::new(); + + // name (field 1, string) —— present + encode_key(1, WireType::LengthDelimited, &mut buf); + encode_varint(name.len() as u64, &mut buf); + buf.extend_from_slice(name.as_bytes()); + + // address (field 2, message) —— present but length 0(空 sub-message) + encode_key(2, WireType::LengthDelimited, &mut buf); + encode_varint(0, &mut buf); + + buf + } + fn create_binary_array(messages: Vec>) -> BinaryArray { let mut builder = BinaryBuilder::new(); for msg in messages { @@ -2041,6 +2433,129 @@ mod tests { assert_eq!(city_array.value(1), "Los Angeles"); } + #[test] + fn test_parse_messages_with_repeated_field_all_tags_present() { + let descriptor_data = create_repeated_test_descriptor(); + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("id", DataType::Int32, true), + Field::new( + "scores", + DataType::List(Arc::new(Field::new("scores", DataType::Int32, true))), + true, + ), + ])); + + let mut deserializer = PbDeserializer::new( + descriptor_data, + "RepeatedMessage", + schema, + &HashMap::new(), + &[], + ) + .expect("Failed to create deserializer"); + + // Key: fill every row with id + scores, so that seen_tags.count_ones() == + // total_handlers, triggering O3 to skip the ensure_size path (under the + // current bug, list row slots are not finalized). + let messages = create_binary_array(vec![ + create_repeated_test_message(1, &[10, 11]), + create_repeated_test_message(2, &[20, 21, 22]), + ]); + let partitions = create_partition_array(vec![0, 0]); + let offsets = create_offset_array(vec![100, 101]); + let timestamps = create_timestamp_array(vec![1000, 1001]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize repeated message"); + + assert_eq!(batch.num_rows(), 2); + let scores = batch + .column(4) + .as_any() + .downcast_ref::() + .expect("Failed to downcast scores array to ListArray"); + assert_eq!(scores.len(), 2); + + let row0 = scores.value(0); + let row0_values = row0 + .as_any() + .downcast_ref::() + .expect("Failed to downcast row0 scores to Int32Array"); + assert_eq!(row0_values.values(), &[10, 11]); + + let row1 = scores.value(1); + let row1_values = row1 + .as_any() + .downcast_ref::() + .expect("Failed to downcast row1 scores to Int32Array"); + assert_eq!(row1_values.values(), &[20, 21, 22]); + } + + #[test] + fn test_parse_messages_with_empty_struct_message_all_tags_present() { + let descriptor_data = create_nested_test_descriptor(); + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + Field::new("street", DataType::Utf8, true), + Field::new("city", DataType::Utf8, true), + ])); + + let mut nested_mapping = HashMap::new(); + nested_mapping.insert("street".to_string(), "address.street".to_string()); + nested_mapping.insert("city".to_string(), "address.city".to_string()); + + let mut deserializer = + PbDeserializer::new(descriptor_data, "Person", schema, &nested_mapping, &[]) + .expect("Failed to create deserializer"); + + // Both name and address tags are present (address being an empty sub-message), + // triggering the empty struct branch + the O3 all-fields-hit path. + let messages = create_binary_array(vec![ + create_empty_nested_test_message("Alice"), + create_empty_nested_test_message("Bob"), + ]); + let partitions = create_partition_array(vec![0, 0]); + let offsets = create_offset_array(vec![200, 201]); + let timestamps = create_timestamp_array(vec![2000, 2001]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize empty nested message"); + + assert_eq!(batch.num_rows(), 2); + + let street = batch + .column(4) + .as_any() + .downcast_ref::() + .expect("Failed to downcast street array to StringArray"); + assert_eq!(street.len(), 2); + // C2: the empty sub-message pads children to align with the struct + // length. `ensure_output_array_builders_size` pads String children + // with a non-null default (""), consistent with how absent fields are + // already handled everywhere else — so street is non-null empty. + assert_eq!(street.null_count(), 0); + assert_eq!(street.value(0), ""); + assert_eq!(street.value(1), ""); + + let city = batch + .column(5) + .as_any() + .downcast_ref::() + .expect("Failed to downcast city array to StringArray"); + assert_eq!(city.len(), 2); + assert_eq!(city.null_count(), 0); + assert_eq!(city.value(0), ""); + assert_eq!(city.value(1), ""); + } + #[test] fn test_parse_messages_with_kafka_meta_empty() { let descriptor_data = create_test_descriptor(); @@ -2158,4 +2673,164 @@ mod tests { assert_eq!(id_array.value(2), 3); assert_eq!(id_array.value(3), 4); } + + /// Pin the row-alignment invariant for a top-level `DataType::Map` column. + /// + /// A top-level Map is structurally different from List: the per-row + /// offset/null slot is finalized only inside `ensure_size` (which runs + /// every row because `top_level_has_list_or_map` is true), while the + /// per-entry key/value pushes go to the child builders via + /// `decode_sub_message`. This test covers both the present (≥1 entry) + /// and absent (no `kv` tag) cases and asserts `map.len() == num_rows` + /// plus the per-row entry counts, so a regression where the map column + /// ever desyncs from the batch row count + /// (e.g. if `top_level_has_list_or_map` stopped matching `DataType::Map`) + /// fails loudly instead of silently corrupting offsets. + #[test] + fn test_parse_messages_with_top_level_map() { + let descriptor_data = create_map_test_descriptor(); + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new( + "kv", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, true), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ), + ])); + + let mut deserializer = + PbDeserializer::new(descriptor_data, "MapMessage", schema, &HashMap::new(), &[]) + .expect("Failed to create deserializer"); + + // row0: two entries; row1: map absent (no kv tag at all). + let messages = create_binary_array(vec![ + create_map_test_message(&[("a", 1), ("b", 2)]), + create_map_test_message(&[]), + ]); + let partitions = create_partition_array(vec![0, 1]); + let offsets = create_offset_array(vec![10, 20]); + let timestamps = create_timestamp_array(vec![100, 200]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize map message"); + + assert_eq!(batch.num_rows(), 2); + let map_array = batch + .column(3) + .as_any() + .downcast_ref::() + .expect("Failed to downcast kv column to MapArray"); + + // Row-alignment invariant: the map column has exactly one slot per row. + assert_eq!(map_array.len(), 2); + + // row0: present with 2 entries → non-null, offsets span 2 entries. + assert!(!map_array.is_null(0)); + let row0 = map_array.value(0); + let row0_entries = row0 + .as_any() + .downcast_ref::() + .expect("map entries are a StructArray"); + assert_eq!(row0_entries.len(), 2); + let row0_keys = row0_entries + .column(0) + .as_any() + .downcast_ref::() + .expect("map keys are StringArray"); + let row0_values = row0_entries + .column(1) + .as_any() + .downcast_ref::() + .expect("map values are Int32Array"); + assert_eq!(row0_keys.value(0), "a"); + assert_eq!(row0_values.value(0), 1); + assert_eq!(row0_keys.value(1), "b"); + assert_eq!(row0_values.value(1), 2); + + // row1: absent map → ensure_size finalizes one non-null slot with + // 0 entries (current behavior). Pin it so an absent-vs-null change is + // conscious. + assert!(!map_array.is_null(1)); + let row1 = map_array.value(1); + let row1_entries = row1 + .as_any() + .downcast_ref::() + .expect("map entries are a StructArray"); + assert_eq!(row1_entries.len(), 0); + } + + /// Regression test for the #2320 boolean bug, in the specific + /// "top-level field absent from EVERY row of the batch" shape that the + /// dropped O10 short-circuit used to mis-handle. + /// + /// With O10 present, a column never touched in any row short-circuited to + /// `new_null_array`, emitting all-NULL — re-introducing #2320 for the + /// all-absent case (and similarly for int/string/float/binary). With O10 + /// removed, the column falls through to the cast path on a builder that + /// `ensure_size` filled with the proto3 default `false`, so it is + /// non-null all-false. This test pins that: if O10 (or an equivalent + /// all-null short-circuit) is ever re-introduced for a top-level field, + /// `null_count()` would be 2 and the test would fail. + #[test] + fn test_parse_messages_top_level_boolean_absent_in_all_rows() { + let descriptor_data = create_test_descriptor(); + // Schema exposes only `active` (field 4, bool). The messages below + // carry only `id` (field 1), so `active` (tag 4) is never decoded. + let schema = Arc::new(Schema::new(vec![ + Field::new("serialized_kafka_records_partition", DataType::Int32, false), + Field::new("serialized_kafka_records_offset", DataType::Int64, false), + Field::new("serialized_kafka_records_timestamp", DataType::Int64, false), + Field::new("active", DataType::Boolean, true), + ])); + + let mut deserializer = + PbDeserializer::new(descriptor_data, "TestMessage", schema, &HashMap::new(), &[]) + .expect("Failed to create deserializer"); + + // Each message encodes ONLY `id` (tag 1); `active` (tag 4) is absent + // in every row of the batch — the exact shape O10 mishandled. + let only_id_message = |id: i32| { + use prost::encoding::*; + let mut buf = Vec::new(); + encode_key(1, WireType::Varint, &mut buf); + encode_varint(id as u64, &mut buf); + buf + }; + let messages = create_binary_array(vec![only_id_message(1), only_id_message(2)]); + let partitions = create_partition_array(vec![0, 0]); + let offsets = create_offset_array(vec![100, 101]); + let timestamps = create_timestamp_array(vec![1000, 1001]); + + let batch = deserializer + .parse_messages_with_kafka_meta(&messages, &partitions, &offsets, ×tamps) + .expect("Failed to deserialize"); + + assert_eq!(batch.num_rows(), 2); + let active_array = batch + .column(3) + .as_any() + .downcast_ref::() + .expect("Failed to downcast active array to BooleanArray"); + + // All-absent boolean must be the proto3 default `false`, NON-null — + // NOT all-NULL (which is the #2320 regression this PR fixes). + assert_eq!(active_array.null_count(), 0); + assert!(!active_array.is_null(0)); + assert!(!active_array.is_null(1)); + assert!(!active_array.value(0)); + assert!(!active_array.value(1)); + } }