From 6bdcd32b62d3852247efe48a7a3a2af7d7fd922c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 17:35:52 -0400 Subject: [PATCH 1/6] Support dense grouped aggregate accumulation Signed-off-by: "Nicholas Gates" --- vortex-array/benches/aggregate_grouped.rs | 80 ++- vortex-array/src/aggregate_fn/accumulator.rs | 2 +- .../src/aggregate_fn/accumulator_grouped.rs | 478 ++++++++++-------- .../src/aggregate_fn/fns/count/mod.rs | 56 ++ vortex-array/src/aggregate_fn/fns/sum/mod.rs | 267 +++++++--- vortex-array/src/aggregate_fn/kernels.rs | 64 ++- vortex-array/src/aggregate_fn/vtable.rs | 29 ++ 7 files changed, 637 insertions(+), 339 deletions(-) diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index b067314c1d9..2d46a5cce8a 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -18,10 +18,8 @@ use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::GroupedAccumulator; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; -use vortex_array::arrays::ListViewArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::VarBinViewArray; -use vortex_array::dtype::DType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -45,44 +43,42 @@ fn total_element_count(group_sizes: &[usize]) -> usize { group_sizes.iter().sum() } -fn contiguous_list_view(elements: ArrayRef, group_sizes: &[usize]) -> ArrayRef { - let mut offset = 0usize; - let offsets: Buffer = group_sizes +struct DenseGroupedInput { + values: ArrayRef, + group_ids: Vec, + num_groups: usize, +} + +fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput { + assert_eq!(values.len(), total_element_count(group_sizes)); + + let group_ids = group_sizes .iter() - .map(|&size| { - let current_offset = offset; - offset += size; - current_offset as u32 - }) + .enumerate() + .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)) .collect(); - let sizes: Buffer = group_sizes.iter().map(|&size| size as u32).collect(); - assert_eq!(elements.len(), total_element_count(group_sizes)); - - ListViewArray::try_new( - elements, - offsets.into_array(), - sizes.into_array(), - Validity::NonNullable, - ) - .unwrap() - .into_array() + DenseGroupedInput { + values, + group_ids, + num_groups: group_sizes.len(), + } } -fn i32_nullable_all_valid_input() -> ArrayRef { +fn i32_nullable_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Buffer = (0..element_count) .map(|i| (i % 1024) as i32 - 512) .collect(); let validity = Validity::from_iter(std::iter::repeat_n(true, element_count)); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, validity).into_array(), &group_sizes, ) } -fn i32_clustered_nulls_input() -> ArrayRef { +fn i32_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values = (0..element_count).map(|i| { @@ -92,26 +88,26 @@ fn i32_clustered_nulls_input() -> ArrayRef { Some((i % 1024) as i32 - 512) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn f64_all_valid_input() -> ArrayRef { +fn f64_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); let values: Buffer = (0..element_count) .map(|_| rng.random_range(-1000.0..1000.0)) .collect(); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, Validity::NonNullable).into_array(), &group_sizes, ) } -fn f64_clustered_nulls_input() -> ArrayRef { +fn f64_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); @@ -122,40 +118,38 @@ fn f64_clustered_nulls_input() -> ArrayRef { Some(rng.random_range(-1000.0f64..1000.0)) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn varbinview_input() -> ArrayRef { +fn varbinview_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Vec = (0..element_count) .map(|i| format!("value-{i:06}")) .collect(); - contiguous_list_view( + dense_grouped_input( VarBinViewArray::from_iter_str(values.iter().map(String::as_str)).into_array(), &group_sizes, ) } -fn list_element_dtype(list_view: &ArrayRef) -> DType { - match list_view.dtype() { - DType::List(element_dtype, _) => element_dtype.as_ref().clone(), - dtype => unreachable!("expected List dtype, got {dtype}"), - } -} - -fn grouped_accumulator(list_view: &ArrayRef, vtable: V) -> ArrayRef +fn grouped_accumulator(input: &DenseGroupedInput, vtable: V) -> ArrayRef where V: AggregateFnVTable + Clone, { let mut acc = - GroupedAccumulator::try_new(vtable, EmptyOptions, list_element_dtype(list_view)).unwrap(); - acc.accumulate_list(list_view, &mut LEGACY_SESSION.create_execution_ctx()) - .unwrap(); - divan::black_box(acc.finish().unwrap()) + GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap(); + acc.accumulate( + &input.values, + &input.group_ids, + input.num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + divan::black_box(acc.finish(input.num_groups).unwrap()) } #[divan::bench] diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index c89418e67a6..ab4e0ee26ba 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -172,7 +172,7 @@ impl DynAccumulator for Accumulator { } // 3. Iteratively check the registry against each intermediate encoding, executing one - // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`. + // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate`. // Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of // keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant) // since the vtable's `accumulate(&Columnar)` handles both cases directly. diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 4b94159127b..66da32b4085 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -1,19 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use arrow_buffer::ArrowNativeType; use vortex_buffer::Buffer; -use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_error::vortex_panic; -use vortex_mask::Mask; -use crate::AnyCanonical; use crate::ArrayRef; -use crate::Canonical; use crate::Columnar; use crate::ExecutionCtx; use crate::IntoArray; @@ -22,26 +15,21 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; -use crate::arrays::ChunkedArray; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; -use crate::arrays::fixed_size_list::FixedSizeListArrayExt; -use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; -use crate::builtins::ArrayBuiltins; +use crate::columnar::AnyColumnar; use crate::dtype::DType; -use crate::dtype::IntegerPType; use crate::executor::max_iterations; -use crate::match_each_integer_ptype; +use crate::scalar::Scalar; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; -/// An accumulator used for computing grouped aggregates. +/// An accumulator used for computing aggregates over dense group ids. /// -/// Note that the groups must be processed in order, and the accumulator does not support random -/// access to groups. +/// Group ids are dense `u32` slots in the range `0..num_groups`. The accumulator keeps one partial +/// state per slot, so ordered and unordered grouping only differ in how the caller assigns ids. pub struct GroupedAccumulator { /// The vtable of the aggregate function. vtable: V, @@ -55,8 +43,8 @@ pub struct GroupedAccumulator { return_dtype: DType, /// The DType of the partial accumulator state. partial_dtype: DType, - /// The accumulated state for prior batches of groups. - partials: Vec, + /// Dense per-group partial state. + partials: Vec, } impl GroupedAccumulator { @@ -84,249 +72,315 @@ impl GroupedAccumulator { dtype, return_dtype, partial_dtype, - partials: vec![], + partials: Vec::new(), }) } -} - -/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate -/// function is not known at compile time. -pub trait DynGroupedAccumulator: 'static + Send { - /// Accumulate a list of groups into the accumulator. - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; - /// Finish the accumulation and return the partial aggregate results for all groups. - /// Resets the accumulator state for the next round of accumulation. - fn flush(&mut self) -> VortexResult; + fn ensure_groups(&mut self, num_groups: usize) -> VortexResult<()> { + vortex_ensure!( + num_groups <= (u32::MAX as usize) + 1, + "num_groups {} exceeds dense u32 group id capacity", + num_groups + ); - /// Finish the accumulation and return the final aggregate results for all groups. - /// Resets the accumulator state for the next round of accumulation. - fn finish(&mut self) -> VortexResult; -} + while self.partials.len() < num_groups { + self.partials + .push(self.vtable.empty_partial(&self.options, &self.dtype)?); + } + Ok(()) + } -impl DynGroupedAccumulator for GroupedAccumulator { - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - let elements_dtype = match groups.dtype() { - DType::List(elem, _) => elem, - DType::FixedSizeList(elem, ..) => elem, - _ => vortex_bail!( - "Input DType mismatch: expected List or FixedSizeList, got {}", - groups.dtype() - ), - }; + fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { vortex_ensure!( - elements_dtype.as_ref() == &self.dtype, - "Input DType mismatch: expected {}, got {}", - self.dtype, - elements_dtype + num_groups <= (u32::MAX as usize) + 1, + "num_groups {} exceeds dense u32 group id capacity", + num_groups ); - - // We first execute the groups until it is a ListView or FixedSizeList, since we only - // dispatch the aggregate kernel over the elements of these arrays. - let canonical = match groups.clone().execute::(ctx)? { - Columnar::Canonical(c) => c, - Columnar::Constant(c) => c.into_array().execute::(ctx)?, - }; - match canonical { - Canonical::List(groups) => self.accumulate_list_view(&groups, ctx), - Canonical::FixedSizeList(groups) => self.accumulate_fixed_size_list(&groups, ctx), - _ => vortex_panic!("We checked the DType above, so this should never happen"), + for &group_id in group_ids { + vortex_ensure!( + (group_id as usize) < num_groups, + "Group id {} out of range for {} groups", + group_id, + num_groups + ); } + Ok(()) } - fn flush(&mut self) -> VortexResult { - let states = std::mem::take(&mut self.partials); - Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array()) + fn accumulate_kernel_result( + &mut self, + result: GroupedAggregateKernelResult, + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx) } - fn finish(&mut self) -> VortexResult { - let states = self.flush()?; - let results = self.vtable.finalize(states)?; + fn accumulate_fallback( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + let Some((&first, rest)) = group_ids.split_first() else { + return Ok(()); + }; + let mut first = first; + let mut last = first; + for &group_id in rest { + first = first.min(group_id); + last = last.max(group_id); + } - vortex_ensure!( - results.dtype() == &self.return_dtype, - "Return DType mismatch: expected {}, got {}", - self.return_dtype, - results.dtype() - ); + let first = first as usize; + let mut buckets = vec![Vec::new(); last as usize - first + 1]; + for (row_idx, &group_id) in group_ids.iter().enumerate() { + buckets[group_id as usize - first].push(row_idx as u64); + } - Ok(results) + for (offset, rows) in buckets.into_iter().enumerate() { + if rows.is_empty() { + continue; + } + + let group = first + offset; + if self.vtable.is_saturated(&self.partials[group]) { + continue; + } + + let taken = batch.clone().take(Buffer::from_iter(rows).into_array())?; + let mut accumulator = Accumulator::try_new( + self.vtable.clone(), + self.options.clone(), + self.dtype.clone(), + )?; + accumulator.accumulate(&taken, ctx)?; + let partial = accumulator.flush()?; + self.vtable + .combine_partials(&mut self.partials[group], partial)?; + } + Ok(()) } } -impl GroupedAccumulator { - fn accumulate_list_view( +/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the +/// aggregate function is not known at compile time. +pub trait DynGroupedAccumulator: 'static + Send { + /// Accumulate a values batch into dense group state. + fn accumulate( &mut self, - groups: &ListViewArray, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Fold columnar partial states into dense group state. + fn accumulate_partials( + &mut self, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Merge one group from another grouped accumulator into this accumulator. + fn merge_group( + &mut self, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, + ) -> VortexResult<()>; + + /// Return this accumulator's partial dtype. + fn partial_dtype(&self) -> &DType; + + /// Read one group's current partial state. + fn partial_scalar(&self, group_id: u32) -> VortexResult; + + /// Finish the accumulation and return partial aggregate results for all groups. + /// + /// Resets the accumulator state for the next round of accumulation. + fn flush_partials(&mut self, num_groups: usize) -> VortexResult; + + /// Finish the accumulation and return final aggregate results for all groups. + /// + /// Resets the accumulator state for the next round of accumulation. + fn finish(&mut self, num_groups: usize) -> VortexResult; +} + +impl DynGroupedAccumulator for GroupedAccumulator { + fn accumulate( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; + vortex_ensure!( + batch.dtype() == &self.dtype, + "Input DType mismatch: expected {}, got {}", + self.dtype, + batch.dtype() + ); + vortex_ensure!( + batch.len() == group_ids.len(), + "Grouped aggregate input length mismatch: {} values, {} group ids", + batch.len(), + group_ids.len() + ); + + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; + let session = ctx.session().clone(); + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(batch.encoding_id(), self.aggregate_fn.id()) + && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? + { + return self.accumulate_kernel_result(result, num_groups, ctx); + } + + if self.vtable.try_accumulate_grouped( + &mut self.partials[..num_groups], + batch, + group_ids, + ctx, + )? { + return Ok(()); + } + + let input = batch.clone(); + let mut batch = batch.clone(); for _ in 0..max_iterations() { - if elements.is::() { + if batch.is::() { break; } - if let Some(result) = session + if let Some(kernel) = session .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - ListViewArray::new_unchecked( - elements.clone(), - groups.offsets().clone(), - groups.sizes().clone(), - groups_validity.clone(), - ) - }; - kernel - .grouped_aggregate(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? + .find_grouped_kernel(batch.encoding_id(), self.aggregate_fn.id()) + && let Some(result) = kernel.grouped_aggregate( + &self.aggregate_fn, + &batch, + group_ids, + num_groups, + ctx, + )? { - return self.push_result(result); + return self.accumulate_kernel_result(result, num_groups, ctx); } - // Execute one step and try again - elements = elements.execute(ctx)?; + batch = batch.execute(ctx)?; } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let offsets = groups.offsets(); - let sizes = groups.sizes().cast(offsets.dtype().clone())?; - let validity = groups_validity.execute_mask(offsets.len(), ctx)?; - - match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets = offsets.clone().execute::>(ctx)?; - let sizes = sizes.execute::>(ctx)?; - self.accumulate_list_view_typed( - &elements, - offsets.as_ref(), - sizes.as_ref(), - &validity, - ctx, - ) - }) + let columnar = batch.clone().execute::(ctx)?; + if self.vtable.accumulate_grouped( + &mut self.partials[..num_groups], + &columnar, + group_ids, + ctx, + )? { + return Ok(()); + } + + self.accumulate_fallback(&input, group_ids, ctx) } - fn accumulate_list_view_typed( + fn accumulate_partials( &mut self, - elements: &ArrayRef, - offsets: &[O], - sizes: &[O], - validity: &Mask, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, offsets.len()); - - // `validity` is the per-group list-view validity, so it is zipped element-wise with the - // offsets and sizes (one entry per group). - for ((offset, size), valid) in offsets.iter().zip(sizes.iter()).zip(validity.iter()) { - let offset = offset.to_usize().vortex_expect("Offset value is not usize"); - let size = size.to_usize().vortex_expect("Size value is not usize"); - - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - } + vortex_ensure!( + partials.dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", + self.partial_dtype, + partials.dtype() + ); + vortex_ensure!( + partials.len() == group_ids.len(), + "Grouped aggregate partial length mismatch: {} partials, {} group ids", + partials.len(), + group_ids.len() + ); - self.push_result(states.finish()) + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; + + for (row_idx, &group_id) in group_ids.iter().enumerate() { + let partial = partials.execute_scalar(row_idx, ctx)?; + self.vtable + .combine_partials(&mut self.partials[group_id as usize], partial)?; + } + Ok(()) } - fn accumulate_fixed_size_list( + fn merge_group( &mut self, - groups: &FixedSizeListArray, - ctx: &mut ExecutionCtx, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let groups_validity = groups.validity()?; - let session = ctx.session().clone(); - - for _ in 0..64 { - if elements.is::() { - break; - } + vortex_ensure!( + other.partial_dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", + self.partial_dtype, + other.partial_dtype() + ); + self.ensure_groups((into as usize) + 1)?; + let partial = other.partial_scalar(from)?; + self.vtable + .combine_partials(&mut self.partials[into as usize], partial) + } - if let Some(result) = session - .aggregate_fns() - .find_grouped_kernel(elements.encoding_id(), self.aggregate_fn.id()) - .and_then(|kernel| { - // SAFETY: we assume that elements execution is safe - let groups = unsafe { - FixedSizeListArray::new_unchecked( - elements.clone(), - groups.list_size(), - groups_validity.clone(), - groups.len(), - ) - }; - - kernel - .grouped_aggregate_fixed_size(&self.aggregate_fn, &groups) - .transpose() - }) - .transpose()? - { - return self.push_result(result); - } + fn partial_dtype(&self) -> &DType { + &self.partial_dtype + } - // Execute one step and try again - elements = elements.execute(ctx)?; + fn partial_scalar(&self, group_id: u32) -> VortexResult { + if let Some(partial) = self.partials.get(group_id as usize) { + self.vtable.to_scalar(partial) + } else { + let partial = self.vtable.empty_partial(&self.options, &self.dtype)?; + self.vtable.to_scalar(&partial) } + } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - let elements = elements.execute::(ctx)?.into_array(); - let validity = groups_validity.execute_mask(groups.len(), ctx)?; - - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, groups.len()); - - let mut offset = 0; - let size = groups - .list_size() - .to_usize() - .vortex_expect("List size is not usize"); - - for valid in validity.iter() { - if valid { - let group = elements.slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - offset += size; + fn flush_partials(&mut self, num_groups: usize) -> VortexResult { + vortex_ensure!( + num_groups >= self.partials.len(), + "Cannot flush {} groups after accumulating {} groups", + num_groups, + self.partials.len() + ); + self.ensure_groups(num_groups)?; + + let mut states = builder_with_capacity(&self.partial_dtype, num_groups); + for partial in &self.partials { + states.append_scalar(&self.vtable.to_scalar(partial)?)?; } + self.partials.clear(); - self.push_result(states.finish()) + Ok(states.finish()) } - fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> { + fn finish(&mut self, num_groups: usize) -> VortexResult { + let states = self.flush_partials(num_groups)?; + let results = self.vtable.finalize(states)?; + vortex_ensure!( - state.dtype() == &self.partial_dtype, - "State DType mismatch: expected {}, got {}", - self.partial_dtype, - state.dtype() + results.dtype() == &self.return_dtype, + "Return DType mismatch: expected {}, got {}", + self.return_dtype, + results.dtype() ); - self.partials.push(state); - Ok(()) + + Ok(results) } } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index e25c42e0845..53afa28d912 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -82,6 +82,22 @@ impl AggregateFnVTable for Count { Ok(true) } + fn try_accumulate_grouped( + &self, + states: &mut [Self::Partial], + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; + for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + if valid { + states[group_id as usize] += 1; + } + } + Ok(true) + } + fn accumulate( &self, _partial: &mut Self::Partial, @@ -114,11 +130,14 @@ mod tests { use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -225,6 +244,43 @@ mod tests { Ok(()) } + #[test] + fn grouped_count_dense_ids() -> VortexResult<()> { + let values = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None, Some(6)]) + .into_array(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + &values, + &[0, 0, 1, 1, 2, 2], + 3, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + + let actual = acc.finish(3)?; + let expected = PrimitiveArray::from_iter([1u64, 2, 1]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_accumulate_partials_and_merge_group() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let partials = PrimitiveArray::from_iter([2u64, 3, 5]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let mut left = GroupedAccumulator::try_new(Count, EmptyOptions, dtype.clone())?; + left.accumulate_partials(&partials, &[0, 1, 1], 2, &mut ctx)?; + + let mut right = GroupedAccumulator::try_new(Count, EmptyOptions, dtype)?; + right.merge_group(0, &left, 1)?; + + let actual = right.finish(1)?; + let expected = PrimitiveArray::from_iter([8u64]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn count_constant_non_null() -> VortexResult<()> { let array = ConstantArray::new(42i32, 10); diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 24799570ff7..4e75a1390f3 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -6,11 +6,15 @@ mod constant; mod decimal; mod primitive; +use num_traits::AsPrimitive; +use num_traits::ToPrimitive; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; use vortex_error::vortex_panic; +use vortex_mask::AllOr; +use vortex_mask::Mask; use self::bool::accumulate_bool; use self::constant::multiply_constant; @@ -25,14 +29,19 @@ use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::BoolArray; +use crate::arrays::PrimitiveArray; +use crate::arrays::bool::BoolArrayExt; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; +use crate::dtype::NativePType; use crate::dtype::Nullability; use crate::dtype::PType; use crate::expr::stats::Precision; use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; +use crate::match_each_native_ptype; use crate::scalar::DecimalValue; use crate::scalar::Scalar; @@ -253,6 +262,30 @@ impl AggregateFnVTable for Sum { Ok(()) } + fn accumulate_grouped( + &self, + partials: &mut [Self::Partial], + batch: &Columnar, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + match batch { + Columnar::Canonical(Canonical::Primitive(p)) => { + accumulate_grouped_primitive(partials, p, group_ids, ctx)?; + Ok(true) + } + Columnar::Canonical(Canonical::Bool(b)) => { + accumulate_grouped_bool(partials, b, group_ids, ctx)?; + Ok(true) + } + // Decimal and constants still use the universal grouped fallback. + Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => Ok(false), + Columnar::Canonical(_) => { + vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()) + } + } + } + fn finalize(&self, partials: ArrayRef) -> VortexResult { Ok(partials) } @@ -299,6 +332,146 @@ fn make_zero_state(return_dtype: &DType) -> SumState { } } +fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { + match validity.indices() { + AllOr::All => { + for idx in 0..len { + f(idx); + } + } + AllOr::None => {} + AllOr::Some(indices) => { + for &idx in indices { + f(idx); + } + } + } +} + +fn accumulate_grouped_unsigned(partials: &mut [SumPartial], group_id: u32, value: u64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Unsigned(acc)) => checked_add_u64(acc, value), + Some(_) => vortex_panic!("unsigned sum state with non-unsigned input"), + }; + if saturated { + partial.current = None; + } +} + +fn accumulate_grouped_signed(partials: &mut [SumPartial], group_id: u32, value: i64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Signed(acc)) => checked_add_i64(acc, value), + Some(_) => vortex_panic!("signed sum state with non-signed input"), + }; + if saturated { + partial.current = None; + } +} + +fn accumulate_grouped_float(partials: &mut [SumPartial], group_id: u32, value: f64) { + if value.is_nan() { + return; + } + + match partials[group_id as usize].current.as_mut() { + None => {} + Some(SumState::Float(acc)) => *acc += value, + Some(_) => vortex_panic!("float sum state with non-float input"), + } +} + +fn accumulate_grouped_primitive( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = primitive + .as_ref() + .validity()? + .execute_mask(primitive.as_ref().len(), ctx)?; + match_each_native_ptype!(primitive.ptype(), + unsigned: |T| { + accumulate_grouped_primitive_unsigned::(partials, primitive, group_ids, &validity); + Ok(()) + }, + signed: |T| { + accumulate_grouped_primitive_signed::(partials, primitive, group_ids, &validity); + Ok(()) + }, + floating: |T| { + accumulate_grouped_primitive_float::(partials, primitive, group_ids, &validity); + Ok(()) + } + ) +} + +fn accumulate_grouped_primitive_unsigned( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_unsigned(partials, group_ids[idx], values[idx].as_()); + }); +} + +fn accumulate_grouped_primitive_signed( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_signed(partials, group_ids[idx], values[idx].as_()); + }); +} + +fn accumulate_grouped_primitive_float( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + ToPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + let value = values[idx].to_f64().vortex_expect("float to f64"); + accumulate_grouped_float(partials, group_ids[idx], value); + }); +} + +fn accumulate_grouped_bool( + partials: &mut [SumPartial], + bools: &BoolArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = bools + .as_ref() + .validity()? + .execute_mask(bools.as_ref().len(), ctx)?; + let values = bools.to_bit_buffer(); + for_each_valid_idx(&validity, values.len(), |idx| { + if values.value(idx) { + accumulate_grouped_unsigned(partials, group_ids[idx], 1); + } + }); + Ok(()) +} + /// Checked add for u64, returning true if overflow occurred. #[inline(always)] fn checked_add_u64(acc: &mut u64, val: u64) -> bool { @@ -346,8 +519,6 @@ mod tests { use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; - use crate::arrays::FixedSizeListArray; - use crate::arrays::ListViewArray; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::dtype::DType; @@ -512,20 +683,26 @@ mod tests { // Grouped sum tests - fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; - acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; - acc.finish() + fn run_grouped_sum( + values: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + values, + group_ids, + num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + acc.finish(num_groups) } #[test] - fn grouped_sum_fixed_size_list() -> VortexResult<()> { - let elements = + fn grouped_sum_dense_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -534,13 +711,10 @@ mod tests { #[test] fn grouped_sum_with_null_elements() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)]) .into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -548,30 +722,22 @@ mod tests { } #[test] - fn grouped_sum_with_null_group() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable) - .into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + fn grouped_sum_empty_group() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![1i32, 2, 3, 7, 8, 9], Validity::NonNullable).into_array(); + let result = run_grouped_sum(&values, &[0, 0, 0, 2, 2, 2], 3)?; let expected = - PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(6i64), Some(0i64), Some(24i64)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } #[test] fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([None::, None, Some(3), Some(4)]).into_array(); - let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -580,12 +746,8 @@ mod tests { #[test] fn grouped_sum_bool() -> VortexResult<()> { - let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect(); - let groups = - FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Bool(Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let values: BoolArray = [true, false, true, true, true, true].into_iter().collect(); + let result = run_grouped_sum(&values.into_array(), &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -598,19 +760,17 @@ mod tests { let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; - let elements1 = + let values1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; - acc.accumulate_list(&groups1.into_array(), &mut ctx)?; - let result1 = acc.finish()?; + acc.accumulate(&values1, &[0, 0, 1, 1], 2, &mut ctx)?; + let result1 = acc.finish(2)?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result1, &expected1); - let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; - acc.accumulate_list(&groups2.into_array(), &mut ctx)?; - let result2 = acc.finish()?; + let values2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&values2, &[0, 0], 1, &mut ctx)?; + let result2 = acc.finish(1)?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); assert_arrays_eq!(&result2, &expected2); @@ -618,20 +778,13 @@ mod tests { } #[test] - fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { - let elements = + fn grouped_sum_out_of_order_group_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array(); - let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array(); - let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array(); - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups, &elem_dtype)?; + let result = run_grouped_sum(&values, &[2, 0, 1], 3)?; - // group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200. let expected = - PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(200i64), Some(300), Some(100)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index d806b18d84d..23ad6a934e5 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -6,13 +6,12 @@ use std::fmt::Debug; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -28,36 +27,49 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { ) -> VortexResult>; } -/// A pluggable kernel for batch aggregation of many groups. +/// Partial grouped aggregate output produced by an encoding-specific grouped kernel. /// -/// The kernel is matched on the encoding of the _elements_ array, which is the inner array of the -/// provided `ListViewArray`. This is more pragmatic than having every kernel match on the outer -/// list encoding and having to deal with the possibility of multiple list encodings. +/// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the +/// corresponding dense group id. The grouped accumulator merges this batch through +/// `accumulate_partials`. +#[derive(Clone, Debug)] +pub struct GroupedAggregateKernelResult { + group_ids: Buffer, + partials: ArrayRef, +} + +impl GroupedAggregateKernelResult { + pub fn new(group_ids: Buffer, partials: ArrayRef) -> Self { + Self { + group_ids, + partials, + } + } + + pub fn group_ids(&self) -> &[u32] { + self.group_ids.as_ref() + } + + pub fn partials(&self) -> &ArrayRef { + &self.partials + } +} + +/// A pluggable kernel for batch aggregation of many groups. /// -/// Each element of the list array represents a group and the result of the grouped aggregate -/// should be an array of the same length, where each element is the aggregate state of the -/// corresponding group. +/// The kernel is matched on the encoding of the values array. It receives the same dense group ids +/// that the caller passed to the grouped accumulator and may aggregate directly in the encoded +/// domain. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate each group in the provided `ListViewArray` and return an array of the - /// aggregate states. + /// Aggregate values into a partial-state batch keyed by dense group id. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, - groups: &ListViewArray, - ) -> VortexResult>; - - /// Aggregate each group in the provided `FixedSizeListArray` and return an array of the - /// aggregate states. - fn grouped_aggregate_fixed_size( - &self, - aggregate_fn: &AggregateFnRef, - groups: &FixedSizeListArray, - ) -> VortexResult> { - // TODO(ngates): we could automatically delegate to `grouped_aggregate` if SequenceArray - // was in the vortex-array crate - let _ = (aggregate_fn, groups); - Ok(None) - } + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult>; } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 28b91d45166..b6c0915f2e7 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -146,6 +146,35 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { ctx: &mut ExecutionCtx, ) -> VortexResult<()>; + /// Try to accumulate a raw values batch into dense per-group states before decompression. + /// + /// `group_ids` is parallel to `batch` and contains dense ids in `0..states.len()`. Returns + /// `true` when the batch was fully handled. + fn try_accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &ArrayRef, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + + /// Accumulate a canonical values batch into dense per-group states. + /// + /// `group_ids` is parallel to `batch` and contains dense ids in `0..states.len()`. Returns + /// `true` when the batch was fully handled. The provided default preserves universal + /// correctness through [`GroupedAccumulator`]'s fallback. + fn accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &Columnar, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + /// Finalize an array of accumulator states into an array of aggregate results. /// /// The provides `states` array has dtype as specified by `state_dtype`, the result array From 9bd157fe5670f24592c01f7cb9fe35b9058a1df6 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 17:42:21 -0400 Subject: [PATCH 2/6] Clarify dense grouped aggregate ids Signed-off-by: "Nicholas Gates" --- .../src/aggregate_fn/accumulator_grouped.rs | 12 ++++++++++-- vortex-array/src/aggregate_fn/fns/count/mod.rs | 10 ++++++++++ vortex-array/src/aggregate_fn/kernels.rs | 13 +++++++------ vortex-array/src/aggregate_fn/vtable.rs | 10 ++++++---- 4 files changed, 33 insertions(+), 12 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 66da32b4085..990a420eecb 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -28,8 +28,10 @@ pub type GroupedAccumulatorRef = Box; /// An accumulator used for computing aggregates over dense group ids. /// -/// Group ids are dense `u32` slots in the range `0..num_groups`. The accumulator keeps one partial -/// state per slot, so ordered and unordered grouping only differ in how the caller assigns ids. +/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches +/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a +/// raw group key. The accumulator keeps one partial state per slot, so ordered and unordered +/// grouping only differ in how the caller assigns ids. pub struct GroupedAccumulator { /// The vtable of the aggregate function. vtable: V, @@ -167,6 +169,9 @@ impl GroupedAccumulator { /// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { /// Accumulate a values batch into dense group state. + /// + /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in + /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch. fn accumulate( &mut self, batch: &ArrayRef, @@ -176,6 +181,9 @@ pub trait DynGroupedAccumulator: 'static + Send { ) -> VortexResult<()>; /// Fold columnar partial states into dense group state. + /// + /// `group_ids` is parallel to `partials` and follows the same dense ordinal contract as + /// [`Self::accumulate`]. fn accumulate_partials( &mut self, partials: &ArrayRef, diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 53afa28d912..07395211ca8 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -263,6 +263,16 @@ mod tests { Ok(()) } + #[test] + fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + assert!(acc.accumulate(&values, &[0, 2], 2, &mut ctx).is_err()); + Ok(()) + } + #[test] fn grouped_count_accumulate_partials_and_merge_group() -> VortexResult<()> { let dtype = DType::Primitive(PType::I32, Nullability::Nullable); diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index 23ad6a934e5..91248091437 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -30,8 +30,9 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { /// Partial grouped aggregate output produced by an encoding-specific grouped kernel. /// /// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the -/// corresponding dense group id. The grouped accumulator merges this batch through -/// `accumulate_partials`. +/// corresponding dense group ordinal. The ids may repeat, omit, and reorder groups, but must be +/// valid slots in the accumulator's `0..num_groups` range. The grouped accumulator merges this +/// batch through `accumulate_partials`. #[derive(Clone, Debug)] pub struct GroupedAggregateKernelResult { group_ids: Buffer, @@ -57,13 +58,13 @@ impl GroupedAggregateKernelResult { /// A pluggable kernel for batch aggregation of many groups. /// -/// The kernel is matched on the encoding of the values array. It receives the same dense group ids -/// that the caller passed to the grouped accumulator and may aggregate directly in the encoded -/// domain. +/// The kernel is matched on the encoding of the values array. It receives the same dense group +/// ordinals that the caller passed to the grouped accumulator and may aggregate directly in the +/// encoded domain. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate values into a partial-state batch keyed by dense group id. + /// Aggregate values into a partial-state batch keyed by dense group ordinal. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index b6c0915f2e7..e30f41f012e 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -148,8 +148,9 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// Try to accumulate a raw values batch into dense per-group states before decompression. /// - /// `group_ids` is parallel to `batch` and contains dense ids in `0..states.len()`. Returns - /// `true` when the batch was fully handled. + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. fn try_accumulate_grouped( &self, _states: &mut [Self::Partial], @@ -162,8 +163,9 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// Accumulate a canonical values batch into dense per-group states. /// - /// `group_ids` is parallel to `batch` and contains dense ids in `0..states.len()`. Returns - /// `true` when the batch was fully handled. The provided default preserves universal + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. The provided default preserves universal /// correctness through [`GroupedAccumulator`]'s fallback. fn accumulate_grouped( &self, From 50701b29f24915144a50e4fe6d88942b5d0808cc Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 18:04:41 -0400 Subject: [PATCH 3/6] Fix grouped aggregate CI checks Signed-off-by: "Nicholas Gates" --- AGENTS.md | 8 +++++++ .../src/aggregate_fn/accumulator_grouped.rs | 21 ++++++++++--------- vortex-array/src/aggregate_fn/vtable.rs | 2 +- 3 files changed, 20 insertions(+), 11 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index e5c3d0cc13b..2a1ad73df22 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -113,6 +113,14 @@ cargo +nightly fmt --all cargo clippy --all-targets --all-features ``` +Do not push Rust code changes before running the applicable lint command above. If the change adds +or edits Rustdoc on public APIs, also run the CI docs command so broken intra-doc links are caught +locally: + +```bash +RUSTDOCFLAGS="-D warnings" cargo doc --profile ci --no-deps +``` + Notes: - For `.github/` changes, follow `.github/AGENTS.md` and run diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index 990a420eecb..938ee2b0dd6 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -79,11 +79,7 @@ impl GroupedAccumulator { } fn ensure_groups(&mut self, num_groups: usize) -> VortexResult<()> { - vortex_ensure!( - num_groups <= (u32::MAX as usize) + 1, - "num_groups {} exceeds dense u32 group id capacity", - num_groups - ); + validate_num_groups(num_groups)?; while self.partials.len() < num_groups { self.partials @@ -93,11 +89,7 @@ impl GroupedAccumulator { } fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { - vortex_ensure!( - num_groups <= (u32::MAX as usize) + 1, - "num_groups {} exceeds dense u32 group id capacity", - num_groups - ); + validate_num_groups(num_groups)?; for &group_id in group_ids { vortex_ensure!( (group_id as usize) < num_groups, @@ -165,6 +157,15 @@ impl GroupedAccumulator { } } +fn validate_num_groups(num_groups: usize) -> VortexResult<()> { + vortex_ensure!( + num_groups == 0 || u32::try_from(num_groups - 1).is_ok(), + "num_groups {} exceeds dense u32 group id capacity", + num_groups + ); + Ok(()) +} + /// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the /// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index e30f41f012e..24c2113e64a 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -166,7 +166,7 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. /// Returns `true` when the batch was fully handled. The provided default preserves universal - /// correctness through [`GroupedAccumulator`]'s fallback. + /// correctness through [`crate::aggregate_fn::GroupedAccumulator`]'s fallback. fn accumulate_grouped( &self, _states: &mut [Self::Partial], From adae76e3122b5c71387eaef1de187b07694bb6e1 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Thu, 11 Jun 2026 18:05:45 -0400 Subject: [PATCH 4/6] DCO Remediation Commit for Nicholas Gates I, Nicholas Gates , hereby add my Signed-off-by to this commit: 6bdcd32b62d3852247efe48a7a3a2af7d7fd922c I, Nicholas Gates , hereby add my Signed-off-by to this commit: 9bd157fe5670f24592c01f7cb9fe35b9058a1df6 I, Nicholas Gates , hereby add my Signed-off-by to this commit: 50701b29f24915144a50e4fe6d88942b5d0808cc Signed-off-by: Nicholas Gates --- AGENTS.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index 2a1ad73df22..759008d730b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -198,5 +198,8 @@ you ran and call out any checks you could not run. All commits must be signed off by the committers in this form: ```text -Signed-off-by: "COMMITTER" +Signed-off-by: COMMITTER ``` + +Do not wrap the committer name in quotes; the DCO check expects the exact unquoted name/email +pair from the commit author. From c61b7bbabb015e7518a4138d85c34623d4200435 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 14:26:49 -0400 Subject: [PATCH 5/6] Restore dense grouped aggregate test coverage Signed-off-by: Nicholas Gates --- .../src/aggregate_fn/fns/count/mod.rs | 53 ++++++++++++++++--- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 33 ++++++++++++ 2 files changed, 78 insertions(+), 8 deletions(-) diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 1ce3588a6ef..60896476e91 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -132,6 +132,7 @@ mod tests { use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinViewArray; use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; @@ -239,25 +240,61 @@ mod tests { Ok(()) } + fn run_grouped_count( + values: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + values, + group_ids, + num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + acc.finish(num_groups) + } + #[test] fn grouped_count_dense_ids() -> VortexResult<()> { let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None, Some(6)]) .into_array(); - let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; - acc.accumulate( - &values, - &[0, 0, 1, 1, 2, 2], - 3, - &mut LEGACY_SESSION.create_execution_ctx(), - )?; + let actual = run_grouped_count(&values, &[0, 0, 1, 1, 2, 2], 3)?; - let actual = acc.finish(3)?; let expected = PrimitiveArray::from_iter([1u64, 2, 1]).into_array(); assert_arrays_eq!(&actual, &expected); Ok(()) } + #[test] + fn grouped_count_omitted_group() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let actual = run_grouped_count(&values, &[0, 0, 1, 2, 2, 2], 4)?; + + let expected = PrimitiveArray::from_iter([2u64, 1, 3, 0]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_varbinview_with_nulls() -> VortexResult<()> { + let values = VarBinViewArray::from_iter_nullable_str([ + Some("a"), + None, + Some("bbb"), + None, + Some("cc"), + ]) + .into_array(); + let actual = run_grouped_count(&values, &[0, 0, 1, 1, 2], 3)?; + + let expected = PrimitiveArray::from_iter([1u64, 1, 1]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index e26c07c1f1a..8626e810399 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -641,6 +641,39 @@ mod tests { Ok(()) } + #[test] + fn grouped_sum_overflow_group_is_null() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let result = run_grouped_sum(&values, &[0, 0, 1, 1], 2)?; + + let expected = PrimitiveArray::from_option_iter([None, Some(5i64)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_float_nan_and_inf() -> VortexResult<()> { + let values = PrimitiveArray::new( + buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0], + Validity::NonNullable, + ) + .into_array(); + let actual = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let g0 = actual.execute_scalar(0, &mut ctx)?; + assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); + + let g1 = actual.execute_scalar(1, &mut ctx)?; + let g1_value = g1 + .as_primitive() + .typed_value::() + .vortex_expect("group sum should be non-null"); + assert!(g1_value.is_nan()); + Ok(()) + } + // Chunked array tests #[test] From 2ef64b2397d508460bda2f6d27a22b4835f20a3a Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 16:46:40 -0400 Subject: [PATCH 6/6] Optimize dense grouped aggregates Signed-off-by: Nicholas Gates --- .../src/aggregate_fn/accumulator_grouped.rs | 14 ++ .../src/aggregate_fn/fns/count/mod.rs | 12 ++ .../src/aggregate_fn/fns/sum/grouped.rs | 142 +++++++++++++++++- vortex-array/src/aggregate_fn/fns/sum/mod.rs | 89 +++++++++++ vortex-array/src/aggregate_fn/vtable.rs | 11 ++ 5 files changed, 265 insertions(+), 3 deletions(-) diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index bf771609ac8..7a614ceed63 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -385,6 +385,20 @@ impl DynGroupedAccumulator for GroupedAccumulator { ); self.ensure_groups(num_groups)?; + if let Some(states) = self + .vtable + .partials_to_array(&self.partials, &self.partial_dtype)? + { + vortex_ensure!( + states.dtype() == &self.partial_dtype, + "Partial array DType mismatch: expected {}, got {}", + self.partial_dtype, + states.dtype() + ); + self.partials.clear(); + return Ok(states); + } + let mut states = builder_with_capacity(&self.partial_dtype, num_groups); for partial in &self.partials { states.append_scalar(&self.vtable.to_scalar(partial)?)?; diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 60896476e91..e53a378b5a9 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -8,9 +8,11 @@ use vortex_error::VortexResult; use crate::ArrayRef; use crate::Columnar; use crate::ExecutionCtx; +use crate::IntoArray; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -64,6 +66,16 @@ impl AggregateFnVTable for Count { Ok(Scalar::primitive(*partial, Nullability::NonNullable)) } + fn partials_to_array( + &self, + partials: &[Self::Partial], + _partial_dtype: &DType, + ) -> VortexResult> { + Ok(Some( + PrimitiveArray::from_iter(partials.iter().copied()).into_array(), + )) + } + fn reset(&self, partial: &mut Self::Partial) { *partial = 0; } diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 432c48b4cf4..81304f1eb9f 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -13,6 +13,9 @@ use super::SumPartial; use super::SumState; use super::checked_add_i64; use super::checked_add_u64; +use super::primitive::sum_float_all; +use super::primitive::sum_signed_all; +use super::primitive::sum_unsigned_all; use crate::ExecutionCtx; use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; @@ -20,6 +23,8 @@ use crate::arrays::bool::BoolArrayExt; use crate::dtype::NativePType; use crate::match_each_native_ptype; +const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4; + fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { match validity.indices() { AllOr::All => { @@ -36,6 +41,41 @@ fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { } } +fn should_accumulate_group_runs(group_ids: &[u32]) -> bool { + let Some((&first, rest)) = group_ids.split_first() else { + return false; + }; + + let mut run_count = 1usize; + let mut group_id = first; + for &next_group_id in rest { + if next_group_id != group_id { + run_count += 1; + group_id = next_group_id; + } + } + + run_count * MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS <= group_ids.len() +} + +fn for_each_group_run(group_ids: &[u32], mut f: impl FnMut(u32, usize, usize)) { + let Some((&first, rest)) = group_ids.split_first() else { + return; + }; + + let mut group_id = first; + let mut start = 0usize; + for (idx, &next_group_id) in rest.iter().enumerate() { + let idx = idx + 1; + if next_group_id != group_id { + f(group_id, start, idx); + group_id = next_group_id; + start = idx; + } + } + f(group_id, start, group_ids.len()); +} + fn accumulate_grouped_unsigned(partials: &mut [SumPartial], group_id: u32, value: u64) { let partial = &mut partials[group_id as usize]; let saturated = match partial.current.as_mut() { @@ -48,6 +88,21 @@ fn accumulate_grouped_unsigned(partials: &mut [SumPartial], group_id: u32, value } } +fn accumulate_grouped_unsigned_run(partials: &mut [SumPartial], group_id: u32, values: &[T]) +where + T: NativePType + AsPrimitive, +{ + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Unsigned(acc)) => sum_unsigned_all(acc, values), + Some(_) => vortex_panic!("unsigned sum state with non-unsigned input"), + }; + if saturated { + partial.current = None; + } +} + fn accumulate_grouped_signed(partials: &mut [SumPartial], group_id: u32, value: i64) { let partial = &mut partials[group_id as usize]; let saturated = match partial.current.as_mut() { @@ -60,6 +115,21 @@ fn accumulate_grouped_signed(partials: &mut [SumPartial], group_id: u32, value: } } +fn accumulate_grouped_signed_run(partials: &mut [SumPartial], group_id: u32, values: &[T]) +where + T: NativePType + AsPrimitive, +{ + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Signed(acc)) => sum_signed_all(acc, values), + Some(_) => vortex_panic!("signed sum state with non-signed input"), + }; + if saturated { + partial.current = None; + } +} + fn accumulate_grouped_float(partials: &mut [SumPartial], group_id: u32, value: f64) { if value.is_nan() { return; @@ -72,6 +142,18 @@ fn accumulate_grouped_float(partials: &mut [SumPartial], group_id: u32, value: f } } +fn accumulate_grouped_float_run( + partials: &mut [SumPartial], + group_id: u32, + values: &[T], +) { + match partials[group_id as usize].current.as_mut() { + None => {} + Some(SumState::Float(acc)) => sum_float_all(acc, values), + Some(_) => vortex_panic!("float sum state with non-float input"), + } +} + pub(super) fn accumulate_grouped_primitive( partials: &mut [SumPartial], primitive: &PrimitiveArray, @@ -82,17 +164,32 @@ pub(super) fn accumulate_grouped_primitive( .as_ref() .validity()? .execute_mask(primitive.as_ref().len(), ctx)?; + let use_runs = + matches!(validity.slices(), AllOr::All) && should_accumulate_group_runs(group_ids); + match_each_native_ptype!(primitive.ptype(), unsigned: |T| { - accumulate_grouped_primitive_unsigned::(partials, primitive, group_ids, &validity); + if use_runs { + accumulate_grouped_primitive_unsigned_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_unsigned::(partials, primitive, group_ids, &validity); + } Ok(()) }, signed: |T| { - accumulate_grouped_primitive_signed::(partials, primitive, group_ids, &validity); + if use_runs { + accumulate_grouped_primitive_signed_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_signed::(partials, primitive, group_ids, &validity); + } Ok(()) }, floating: |T| { - accumulate_grouped_primitive_float::(partials, primitive, group_ids, &validity); + if use_runs { + accumulate_grouped_primitive_float_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_float::(partials, primitive, group_ids, &validity); + } Ok(()) } ) @@ -112,6 +209,19 @@ fn accumulate_grouped_primitive_unsigned( }); } +fn accumulate_grouped_primitive_unsigned_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_unsigned_run(partials, group_id, &values[start..end]); + }); +} + fn accumulate_grouped_primitive_signed( partials: &mut [SumPartial], primitive: &PrimitiveArray, @@ -126,6 +236,19 @@ fn accumulate_grouped_primitive_signed( }); } +fn accumulate_grouped_primitive_signed_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_signed_run(partials, group_id, &values[start..end]); + }); +} + fn accumulate_grouped_primitive_float( partials: &mut [SumPartial], primitive: &PrimitiveArray, @@ -141,6 +264,19 @@ fn accumulate_grouped_primitive_float( }); } +fn accumulate_grouped_primitive_float_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_float_run(partials, group_id, &values[start..end]); + }); +} + pub(super) fn accumulate_grouped_bool( partials: &mut [SumPartial], bools: &BoolArray, diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 8626e810399..eff487d55e8 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -7,6 +7,7 @@ mod decimal; mod grouped; mod primitive; +use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -21,11 +22,13 @@ use crate::ArrayRef; use crate::Canonical; use crate::Columnar; use crate::ExecutionCtx; +use crate::IntoArray; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; @@ -36,6 +39,7 @@ use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; use crate::scalar::DecimalValue; use crate::scalar::Scalar; +use crate::validity::Validity; /// Return the sum of an array. /// @@ -201,6 +205,29 @@ impl AggregateFnVTable for Sum { }) } + fn partials_to_array( + &self, + partials: &[Self::Partial], + partial_dtype: &DType, + ) -> VortexResult> { + Ok(match partial_dtype { + DType::Primitive(PType::U64, _) => Some(sum_primitive_partials_to_array( + partials, + unsigned_sum_state_value, + )), + DType::Primitive(PType::I64, _) => Some(sum_primitive_partials_to_array( + partials, + signed_sum_state_value, + )), + DType::Primitive(PType::F64, _) => Some(sum_primitive_partials_to_array( + partials, + float_sum_state_value, + )), + DType::Decimal(..) => None, + _ => vortex_bail!("Unsupported sum partial dtype: {}", partial_dtype), + }) + } + fn reset(&self, partial: &mut Self::Partial) { partial.current = Some(make_zero_state(&partial.return_dtype)); } @@ -309,6 +336,54 @@ pub enum SumState { }, } +fn sum_primitive_partials_to_array( + partials: &[SumPartial], + value_from_state: fn(&SumState) -> T, +) -> ArrayRef +where + T: crate::dtype::NativePType, +{ + if partials.iter().all(|partial| partial.current.is_some()) { + let values = Buffer::from_iter(partials.iter().map(|partial| { + value_from_state( + partial + .current + .as_ref() + .vortex_expect("checked non-null partial"), + ) + })); + return PrimitiveArray::new(values, Validity::AllValid).into_array(); + } + + PrimitiveArray::from_option_iter( + partials + .iter() + .map(|partial| partial.current.as_ref().map(value_from_state)), + ) + .into_array() +} + +fn unsigned_sum_state_value(state: &SumState) -> u64 { + match state { + SumState::Unsigned(v) => *v, + _ => vortex_panic!("unsigned sum state with non-unsigned partial dtype"), + } +} + +fn signed_sum_state_value(state: &SumState) -> i64 { + match state { + SumState::Signed(v) => *v, + _ => vortex_panic!("signed sum state with non-signed partial dtype"), + } +} + +fn float_sum_state_value(state: &SumState) -> f64 { + match state { + SumState::Float(v) => *v, + _ => vortex_panic!("float sum state with non-float partial dtype"), + } +} + fn make_zero_state(return_dtype: &DType) -> SumState { match return_dtype { DType::Primitive(ptype, _) => match ptype { @@ -641,6 +716,20 @@ mod tests { Ok(()) } + #[test] + fn grouped_sum_contiguous_group_runs() -> VortexResult<()> { + let values = PrimitiveArray::new( + buffer![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + Validity::NonNullable, + ) + .into_array(); + let result = run_grouped_sum(&values, &[0, 0, 0, 0, 1, 1, 1, 1], 2)?; + + let expected = PrimitiveArray::from_option_iter([Some(10.0f64), Some(26.0)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + #[test] fn grouped_sum_overflow_group_is_null() -> VortexResult<()> { let values = diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 24c2113e64a..09eab6c5a9c 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -115,6 +115,17 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// options and input dtype used to construct the state. fn to_scalar(&self, partial: &Self::Partial) -> VortexResult; + /// Try to convert dense partial states directly into a partial-state array. + /// + /// Returning `Ok(None)` falls back to scalarizing each partial with [`Self::to_scalar`]. + fn partials_to_array( + &self, + _partials: &[Self::Partial], + _partial_dtype: &DType, + ) -> VortexResult> { + Ok(None) + } + /// Reset the state of the accumulator to an empty group. fn reset(&self, partial: &mut Self::Partial);