diff --git a/AGENTS.md b/AGENTS.md index e5c3d0cc13b..759008d730b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -113,6 +113,14 @@ cargo +nightly fmt --all cargo clippy --all-targets --all-features ``` +Do not push Rust code changes before running the applicable lint command above. If the change adds +or edits Rustdoc on public APIs, also run the CI docs command so broken intra-doc links are caught +locally: + +```bash +RUSTDOCFLAGS="-D warnings" cargo doc --profile ci --no-deps +``` + Notes: - For `.github/` changes, follow `.github/AGENTS.md` and run @@ -190,5 +198,8 @@ you ran and call out any checks you could not run. All commits must be signed off by the committers in this form: ```text -Signed-off-by: "COMMITTER" +Signed-off-by: COMMITTER ``` + +Do not wrap the committer name in quotes; the DCO check expects the exact unquoted name/email +pair from the commit author. diff --git a/vortex-array/benches/aggregate_grouped.rs b/vortex-array/benches/aggregate_grouped.rs index b067314c1d9..2d46a5cce8a 100644 --- a/vortex-array/benches/aggregate_grouped.rs +++ b/vortex-array/benches/aggregate_grouped.rs @@ -18,10 +18,8 @@ use vortex_array::aggregate_fn::EmptyOptions; use vortex_array::aggregate_fn::GroupedAccumulator; use vortex_array::aggregate_fn::fns::count::Count; use vortex_array::aggregate_fn::fns::sum::Sum; -use vortex_array::arrays::ListViewArray; use vortex_array::arrays::PrimitiveArray; use vortex_array::arrays::VarBinViewArray; -use vortex_array::dtype::DType; use vortex_array::validity::Validity; use vortex_buffer::Buffer; @@ -45,44 +43,42 @@ fn total_element_count(group_sizes: &[usize]) -> usize { group_sizes.iter().sum() } -fn contiguous_list_view(elements: ArrayRef, group_sizes: &[usize]) -> ArrayRef { - let mut offset = 0usize; - let offsets: Buffer = group_sizes +struct DenseGroupedInput { + values: ArrayRef, + group_ids: Vec, + num_groups: usize, +} + +fn dense_grouped_input(values: ArrayRef, group_sizes: &[usize]) -> DenseGroupedInput { + assert_eq!(values.len(), total_element_count(group_sizes)); + + let group_ids = group_sizes .iter() - .map(|&size| { - let current_offset = offset; - offset += size; - current_offset as u32 - }) + .enumerate() + .flat_map(|(group_id, &size)| std::iter::repeat_n(group_id as u32, size)) .collect(); - let sizes: Buffer = group_sizes.iter().map(|&size| size as u32).collect(); - assert_eq!(elements.len(), total_element_count(group_sizes)); - - ListViewArray::try_new( - elements, - offsets.into_array(), - sizes.into_array(), - Validity::NonNullable, - ) - .unwrap() - .into_array() + DenseGroupedInput { + values, + group_ids, + num_groups: group_sizes.len(), + } } -fn i32_nullable_all_valid_input() -> ArrayRef { +fn i32_nullable_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Buffer = (0..element_count) .map(|i| (i % 1024) as i32 - 512) .collect(); let validity = Validity::from_iter(std::iter::repeat_n(true, element_count)); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, validity).into_array(), &group_sizes, ) } -fn i32_clustered_nulls_input() -> ArrayRef { +fn i32_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values = (0..element_count).map(|i| { @@ -92,26 +88,26 @@ fn i32_clustered_nulls_input() -> ArrayRef { Some((i % 1024) as i32 - 512) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn f64_all_valid_input() -> ArrayRef { +fn f64_all_valid_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); let values: Buffer = (0..element_count) .map(|_| rng.random_range(-1000.0..1000.0)) .collect(); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::new(values, Validity::NonNullable).into_array(), &group_sizes, ) } -fn f64_clustered_nulls_input() -> ArrayRef { +fn f64_clustered_nulls_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let mut rng = StdRng::seed_from_u64(GROUP_SIZE_SEED); @@ -122,40 +118,38 @@ fn f64_clustered_nulls_input() -> ArrayRef { Some(rng.random_range(-1000.0f64..1000.0)) } }); - contiguous_list_view( + dense_grouped_input( PrimitiveArray::from_option_iter(values).into_array(), &group_sizes, ) } -fn varbinview_input() -> ArrayRef { +fn varbinview_input() -> DenseGroupedInput { let group_sizes = random_group_sizes(); let element_count = total_element_count(&group_sizes); let values: Vec = (0..element_count) .map(|i| format!("value-{i:06}")) .collect(); - contiguous_list_view( + dense_grouped_input( VarBinViewArray::from_iter_str(values.iter().map(String::as_str)).into_array(), &group_sizes, ) } -fn list_element_dtype(list_view: &ArrayRef) -> DType { - match list_view.dtype() { - DType::List(element_dtype, _) => element_dtype.as_ref().clone(), - dtype => unreachable!("expected List dtype, got {dtype}"), - } -} - -fn grouped_accumulator(list_view: &ArrayRef, vtable: V) -> ArrayRef +fn grouped_accumulator(input: &DenseGroupedInput, vtable: V) -> ArrayRef where V: AggregateFnVTable + Clone, { let mut acc = - GroupedAccumulator::try_new(vtable, EmptyOptions, list_element_dtype(list_view)).unwrap(); - acc.accumulate_list(list_view, &mut LEGACY_SESSION.create_execution_ctx()) - .unwrap(); - divan::black_box(acc.finish().unwrap()) + GroupedAccumulator::try_new(vtable, EmptyOptions, input.values.dtype().clone()).unwrap(); + acc.accumulate( + &input.values, + &input.group_ids, + input.num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + ) + .unwrap(); + divan::black_box(acc.finish(input.num_groups).unwrap()) } #[divan::bench] diff --git a/vortex-array/src/aggregate_fn/accumulator.rs b/vortex-array/src/aggregate_fn/accumulator.rs index c89418e67a6..ab4e0ee26ba 100644 --- a/vortex-array/src/aggregate_fn/accumulator.rs +++ b/vortex-array/src/aggregate_fn/accumulator.rs @@ -172,7 +172,7 @@ impl DynAccumulator for Accumulator { } // 3. Iteratively check the registry against each intermediate encoding, executing one - // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate_list_view`. + // step between checks. Mirrors the loop in `GroupedAccumulator::accumulate`. // Iteration 0 re-checks the initial encoding — a redundant HashMap miss, the price of // keeping the loop body uniform. Terminates on `AnyColumnar` (Canonical or Constant) // since the vtable's `accumulate(&Columnar)` handles both cases directly. diff --git a/vortex-array/src/aggregate_fn/accumulator_grouped.rs b/vortex-array/src/aggregate_fn/accumulator_grouped.rs index b87c04ee204..7a614ceed63 100644 --- a/vortex-array/src/aggregate_fn/accumulator_grouped.rs +++ b/vortex-array/src/aggregate_fn/accumulator_grouped.rs @@ -1,18 +1,12 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use arrow_buffer::ArrowNativeType; use vortex_buffer::Buffer; -use vortex_error::VortexExpect; use vortex_error::VortexResult; -use vortex_error::vortex_bail; use vortex_error::vortex_ensure; use vortex_error::vortex_err; -use vortex_error::vortex_panic; -use vortex_mask::Mask; use crate::ArrayRef; -use crate::Canonical; use crate::Columnar; use crate::ExecutionCtx; use crate::IntoArray; @@ -21,164 +15,23 @@ use crate::aggregate_fn::AggregateFn; use crate::aggregate_fn::AggregateFnRef; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; +use crate::aggregate_fn::kernels::GroupedAggregateKernelResult; use crate::aggregate_fn::session::AggregateFnSessionExt; -use crate::arrays::ChunkedArray; -use crate::arrays::FixedSizeListArray; -use crate::arrays::ListViewArray; -use crate::arrays::fixed_size_list::FixedSizeListArrayExt; -use crate::arrays::listview::ListViewArrayExt; use crate::builders::builder_with_capacity; -use crate::builtins::ArrayBuiltins; use crate::columnar::AnyColumnar; use crate::dtype::DType; use crate::executor::max_iterations; -use crate::match_each_integer_ptype; +use crate::scalar::Scalar; /// Reference-counted type-erased grouped accumulator. pub type GroupedAccumulatorRef = Box; -/// A batch of grouped values to aggregate. +/// An accumulator used for computing aggregates over dense group ids. /// -/// Each outer list value is one group, and the inner element array is shared by all groups. -/// Aggregate implementations can inspect the concrete grouped representation directly, or ask for -/// derived ranges when their algorithm is expressed in terms of `(offset, size)` pairs. -pub enum GroupedArray { - /// Groups represented as a list-view array with per-group offsets and sizes. - ListView(ListViewArray), - /// Groups represented as a fixed-size list array. - FixedSizeList(FixedSizeListArray), -} - -impl From for GroupedArray { - fn from(groups: ListViewArray) -> Self { - Self::ListView(groups) - } -} - -impl From for GroupedArray { - fn from(groups: FixedSizeListArray) -> Self { - Self::FixedSizeList(groups) - } -} - -impl GroupedArray { - /// The inner element array shared by all groups. - pub fn elements(&self) -> &ArrayRef { - match self { - Self::ListView(groups) => groups.elements(), - Self::FixedSizeList(groups) => groups.elements(), - } - } - - /// Return the `(offset, size)` ranges describing each group in `elements`. - pub fn group_ranges(&self, ctx: &mut ExecutionCtx) -> VortexResult { - match self { - Self::ListView(groups) => list_view_group_ranges(groups, ctx), - Self::FixedSizeList(groups) => Ok(fixed_size_list_group_ranges(groups)), - } - } - - /// Return the per-group validity mask. - pub fn group_validity(&self, ctx: &mut ExecutionCtx) -> VortexResult { - match self { - Self::ListView(groups) => groups.validity()?.execute_mask(groups.len(), ctx), - Self::FixedSizeList(groups) => groups.validity()?.execute_mask(groups.len(), ctx), - } - } - - /// The number of groups in this batch. - pub fn len(&self) -> usize { - match self { - Self::ListView(groups) => groups.len(), - Self::FixedSizeList(groups) => groups.len(), - } - } - - /// Returns true when this batch contains no groups. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Returns true when every group is valid. - pub fn all_groups_valid(&self, ctx: &mut ExecutionCtx) -> VortexResult { - Ok(self.group_validity(ctx)?.all_true()) - } - - unsafe fn with_elements_unchecked(&self, elements: ArrayRef) -> VortexResult { - Ok(match self { - Self::ListView(groups) => unsafe { - ListViewArray::new_unchecked( - elements, - groups.offsets().clone(), - groups.sizes().clone(), - groups.validity()?, - ) - } - .into(), - Self::FixedSizeList(groups) => unsafe { - FixedSizeListArray::new_unchecked( - elements, - groups.list_size(), - groups.validity()?, - groups.len(), - ) - } - .into(), - }) - } -} - -/// The physical ranges of a grouped array. -pub enum GroupRanges { - /// Explicit ranges extracted from a list-view array. - ListView { - /// The `(offset, size)` ranges. - ranges: Vec<(usize, usize)>, - }, - /// Uniform ranges derived from a fixed-size list array. - FixedSizeList { - /// The number of groups. - len: usize, - /// The number of elements in each group. - size: usize, - }, -} - -impl GroupRanges { - /// The number of groups described by these ranges. - pub fn len(&self) -> usize { - match self { - Self::ListView { ranges } => ranges.len(), - Self::FixedSizeList { len, .. } => *len, - } - } - - /// Returns true when there are no groups. - pub fn is_empty(&self) -> bool { - self.len() == 0 - } - - /// Return the `(offset, size)` range for the group at `index`. - fn range(&self, index: usize) -> (usize, usize) { - match self { - Self::ListView { ranges } => ranges[index], - Self::FixedSizeList { len, size } => { - assert!(index < *len, "range index out of bounds"); - (index * size, *size) - } - } - } - - /// Iterate over all `(offset, size)` group ranges. - pub fn iter(&self) -> impl Iterator + '_ { - (0..self.len()).map(|index| self.range(index)) - } -} - -/// An accumulator used for computing grouped aggregates. -/// -/// Note that the groups must be processed in order, and the accumulator does not support random -/// access to groups. +/// Group ids are caller-assigned `u32` ordinals in the dense range `0..num_groups`. Input batches +/// may repeat, omit, and reorder those ids, but every id must identify a state slot rather than a +/// raw group key. The accumulator keeps one partial state per slot, so ordered and unordered +/// grouping only differ in how the caller assigns ids. pub struct GroupedAccumulator { /// The vtable of the aggregate function. vtable: V, @@ -192,8 +45,8 @@ pub struct GroupedAccumulator { return_dtype: DType, /// The DType of the partial accumulator state. partial_dtype: DType, - /// The accumulated state for prior batches of groups. - partials: Vec, + /// Dense per-group partial state. + partials: Vec, } impl GroupedAccumulator { @@ -221,199 +74,351 @@ impl GroupedAccumulator { dtype, return_dtype, partial_dtype, - partials: vec![], + partials: Vec::new(), }) } + + fn ensure_groups(&mut self, num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + + while self.partials.len() < num_groups { + self.partials + .push(self.vtable.empty_partial(&self.options, &self.dtype)?); + } + Ok(()) + } + + fn validate_group_ids(&self, group_ids: &[u32], num_groups: usize) -> VortexResult<()> { + validate_num_groups(num_groups)?; + for &group_id in group_ids { + vortex_ensure!( + (group_id as usize) < num_groups, + "Group id {} out of range for {} groups", + group_id, + num_groups + ); + } + Ok(()) + } + + fn accumulate_kernel_result( + &mut self, + result: GroupedAggregateKernelResult, + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + self.accumulate_partials(result.partials(), result.group_ids(), num_groups, ctx) + } + + fn try_accumulate_kernel( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let session = ctx.session().clone(); + + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_encoding_kernel(batch.encoding_id(), self.aggregate_fn.id()) + && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? + { + self.accumulate_kernel_result(result, num_groups, ctx)?; + return Ok(true); + } + + if let Some(kernel) = session + .aggregate_fns() + .find_grouped_kernel(self.aggregate_fn.id()) + && let Some(result) = + kernel.grouped_aggregate(&self.aggregate_fn, batch, group_ids, num_groups, ctx)? + { + self.accumulate_kernel_result(result, num_groups, ctx)?; + return Ok(true); + } + + Ok(false) + } + + fn accumulate_fallback( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { + let Some((&first, rest)) = group_ids.split_first() else { + return Ok(()); + }; + let mut first = first; + let mut last = first; + for &group_id in rest { + first = first.min(group_id); + last = last.max(group_id); + } + + let first = first as usize; + let mut buckets = vec![Vec::new(); last as usize - first + 1]; + for (row_idx, &group_id) in group_ids.iter().enumerate() { + buckets[group_id as usize - first].push(row_idx as u64); + } + + for (offset, rows) in buckets.into_iter().enumerate() { + if rows.is_empty() { + continue; + } + + let group = first + offset; + if self.vtable.is_saturated(&self.partials[group]) { + continue; + } + + let taken = batch.clone().take(Buffer::from_iter(rows).into_array())?; + let mut accumulator = Accumulator::try_new( + self.vtable.clone(), + self.options.clone(), + self.dtype.clone(), + )?; + accumulator.accumulate(&taken, ctx)?; + let partial = accumulator.flush()?; + self.vtable + .combine_partials(&mut self.partials[group], partial)?; + } + Ok(()) + } +} + +fn validate_num_groups(num_groups: usize) -> VortexResult<()> { + vortex_ensure!( + num_groups == 0 || u32::try_from(num_groups - 1).is_ok(), + "num_groups {} exceeds dense u32 group id capacity", + num_groups + ); + Ok(()) } -/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the aggregate -/// function is not known at compile time. +/// A trait object for type-erased grouped accumulators, used for dynamic dispatch when the +/// aggregate function is not known at compile time. pub trait DynGroupedAccumulator: 'static + Send { - /// Accumulate a list of groups into the accumulator. - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()>; + /// Accumulate a values batch into dense group state. + /// + /// `group_ids` is parallel to `batch`. Each id must be a caller-assigned group ordinal in + /// `0..num_groups`; ids may repeat, appear out of order, or be absent from a given batch. + fn accumulate( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Fold columnar partial states into dense group state. + /// + /// `group_ids` is parallel to `partials` and follows the same dense ordinal contract as + /// [`Self::accumulate`]. + fn accumulate_partials( + &mut self, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()>; + + /// Merge one group from another grouped accumulator into this accumulator. + fn merge_group( + &mut self, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, + ) -> VortexResult<()>; + + /// Return this accumulator's partial dtype. + fn partial_dtype(&self) -> &DType; - /// Finish the accumulation and return the partial aggregate results for all groups. + /// Read one group's current partial state. + fn partial_scalar(&self, group_id: u32) -> VortexResult; + + /// Finish the accumulation and return partial aggregate results for all groups. + /// /// Resets the accumulator state for the next round of accumulation. - fn flush(&mut self) -> VortexResult; + fn flush_partials(&mut self, num_groups: usize) -> VortexResult; - /// Finish the accumulation and return the final aggregate results for all groups. + /// Finish the accumulation and return final aggregate results for all groups. + /// /// Resets the accumulator state for the next round of accumulation. - fn finish(&mut self) -> VortexResult; + fn finish(&mut self, num_groups: usize) -> VortexResult; } impl DynGroupedAccumulator for GroupedAccumulator { - fn accumulate_list(&mut self, groups: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<()> { - let elements_dtype = match groups.dtype() { - DType::List(elem, _) => elem, - DType::FixedSizeList(elem, ..) => elem, - _ => vortex_bail!( - "Input DType mismatch: expected List or FixedSizeList, got {}", - groups.dtype() - ), - }; + fn accumulate( + &mut self, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult<()> { vortex_ensure!( - elements_dtype.as_ref() == &self.dtype, + batch.dtype() == &self.dtype, "Input DType mismatch: expected {}, got {}", self.dtype, - elements_dtype + batch.dtype() ); - - // We first execute the groups until it is a ListView or FixedSizeList, since we only - // dispatch the aggregate kernel over the elements of these arrays. - let canonical = match groups.clone().execute::(ctx)? { - Columnar::Canonical(c) => c, - Columnar::Constant(c) => c.into_array().execute::(ctx)?, - }; - match canonical { - Canonical::List(groups) => self.accumulate_grouped_array(groups.into(), ctx), - Canonical::FixedSizeList(groups) => self.accumulate_grouped_array(groups.into(), ctx), - _ => vortex_panic!("We checked the DType above, so this should never happen"), - } - } - - fn flush(&mut self) -> VortexResult { - let states = std::mem::take(&mut self.partials); - Ok(ChunkedArray::try_new(states, self.partial_dtype.clone())?.into_array()) - } - - fn finish(&mut self) -> VortexResult { - let states = self.flush()?; - let results = self.vtable.finalize(states)?; - vortex_ensure!( - results.dtype() == &self.return_dtype, - "Return DType mismatch: expected {}, got {}", - self.return_dtype, - results.dtype() + batch.len() == group_ids.len(), + "Grouped aggregate input length mismatch: {} values, {} group ids", + batch.len(), + group_ids.len() ); - Ok(results) - } -} + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; -impl GroupedAccumulator { - fn accumulate_grouped_array( - &mut self, - groups: GroupedArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult<()> { - let mut elements = groups.elements().clone(); - let session = ctx.session().clone(); + if self.try_accumulate_kernel(batch, group_ids, num_groups, ctx)? { + return Ok(()); + } - for _ in 0..max_iterations() { - // Try a registered grouped kernel for the current element encoding. - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_encoding_kernel(elements.encoding_id(), self.aggregate_fn.id()) - { - // SAFETY: we assume that elements execution is safe - let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; - if let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? - { - return self.push_result(result); - } - } + if self.vtable.try_accumulate_grouped( + &mut self.partials[..num_groups], + batch, + group_ids, + ctx, + )? { + return Ok(()); + } - // Try a grouped kernel for the current aggregate regardless of element encoding. - if let Some(kernel) = session - .aggregate_fns() - .find_grouped_kernel(self.aggregate_fn.id()) - { - // SAFETY: we preserve the grouped shape and validity while replacing the - // elements with another representation of the same logical array. - let kernel_groups = unsafe { groups.with_elements_unchecked(elements.clone())? }; - if let Some(result) = - kernel.grouped_aggregate(&self.aggregate_fn, &kernel_groups, ctx)? - { - return self.push_result(result); - } + let input = batch.clone(); + let mut batch = batch.clone(); + for _ in 0..max_iterations() { + if batch.is::() { + break; } - if elements.is::() { - break; + if self.try_accumulate_kernel(&batch, group_ids, num_groups, ctx)? { + return Ok(()); } - // Execute one step and try again - elements = elements.execute(ctx)?; + batch = batch.execute(ctx)?; } - let elements = elements.execute::(ctx)?.into_array(); - // SAFETY: we preserve the grouped shape and validity while replacing the elements with an - // executed form of the same logical array. - let grouped = unsafe { groups.with_elements_unchecked(elements)? }; + let columnar = batch.clone().execute::(ctx)?; + if self.vtable.accumulate_grouped( + &mut self.partials[..num_groups], + &columnar, + group_ids, + ctx, + )? { + return Ok(()); + } - // Otherwise, we iterate the offsets and sizes and accumulate each group one by one. - self.accumulate_grouped_fallback(&grouped, ctx) + self.accumulate_fallback(&input, group_ids, ctx) } - fn accumulate_grouped_fallback( + fn accumulate_partials( &mut self, - grouped: &GroupedArray, + partials: &ArrayRef, + group_ids: &[u32], + num_groups: usize, ctx: &mut ExecutionCtx, ) -> VortexResult<()> { - let mut accumulator = Accumulator::try_new( - self.vtable.clone(), - self.options.clone(), - self.dtype.clone(), - )?; - let mut states = builder_with_capacity(&self.partial_dtype, grouped.len()); - let group_ranges = grouped.group_ranges(ctx)?; - let group_validity = grouped.group_validity(ctx)?; - - for ((offset, size), valid) in group_ranges.iter().zip(group_validity.iter()) { - if valid { - let group = grouped.elements().slice(offset..offset + size)?; - accumulator.accumulate(&group, ctx)?; - states.append_scalar(&accumulator.flush()?)?; - } else { - states.append_null() - } - } + vortex_ensure!( + partials.dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", + self.partial_dtype, + partials.dtype() + ); + vortex_ensure!( + partials.len() == group_ids.len(), + "Grouped aggregate partial length mismatch: {} partials, {} group ids", + partials.len(), + group_ids.len() + ); + + self.validate_group_ids(group_ids, num_groups)?; + self.ensure_groups(num_groups)?; - self.push_result(states.finish()) + for (row_idx, &group_id) in group_ids.iter().enumerate() { + let partial = partials.execute_scalar(row_idx, ctx)?; + self.vtable + .combine_partials(&mut self.partials[group_id as usize], partial)?; + } + Ok(()) } - fn push_result(&mut self, state: ArrayRef) -> VortexResult<()> { + fn merge_group( + &mut self, + into: u32, + other: &dyn DynGroupedAccumulator, + from: u32, + ) -> VortexResult<()> { vortex_ensure!( - state.dtype() == &self.partial_dtype, - "State DType mismatch: expected {}, got {}", + other.partial_dtype() == &self.partial_dtype, + "Partial DType mismatch: expected {}, got {}", self.partial_dtype, - state.dtype() + other.partial_dtype() ); - self.partials.push(state); - Ok(()) + self.ensure_groups((into as usize) + 1)?; + let partial = other.partial_scalar(from)?; + self.vtable + .combine_partials(&mut self.partials[into as usize], partial) } -} -fn list_view_group_ranges( - groups: &ListViewArray, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let offsets = groups.offsets(); - let sizes = groups.sizes().cast(offsets.dtype().clone())?; - - let ranges = match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { - let offsets = offsets.clone().execute::>(ctx)?; - let sizes = sizes.execute::>(ctx)?; - offsets - .as_ref() - .iter() - .zip(sizes.as_ref().iter()) - .map(|(offset, size)| { - ( - offset.to_usize().vortex_expect("Offset value is not usize"), - size.to_usize().vortex_expect("Size value is not usize"), - ) - }) - .collect::>() - }); - - Ok(GroupRanges::ListView { ranges }) -} -fn fixed_size_list_group_ranges(groups: &FixedSizeListArray) -> GroupRanges { - GroupRanges::FixedSizeList { - len: groups.len(), - size: groups.list_size() as usize, + fn partial_dtype(&self) -> &DType { + &self.partial_dtype + } + + fn partial_scalar(&self, group_id: u32) -> VortexResult { + if let Some(partial) = self.partials.get(group_id as usize) { + self.vtable.to_scalar(partial) + } else { + let partial = self.vtable.empty_partial(&self.options, &self.dtype)?; + self.vtable.to_scalar(&partial) + } + } + + fn flush_partials(&mut self, num_groups: usize) -> VortexResult { + vortex_ensure!( + num_groups >= self.partials.len(), + "Cannot flush {} groups after accumulating {} groups", + num_groups, + self.partials.len() + ); + self.ensure_groups(num_groups)?; + + if let Some(states) = self + .vtable + .partials_to_array(&self.partials, &self.partial_dtype)? + { + vortex_ensure!( + states.dtype() == &self.partial_dtype, + "Partial array DType mismatch: expected {}, got {}", + self.partial_dtype, + states.dtype() + ); + self.partials.clear(); + return Ok(states); + } + + let mut states = builder_with_capacity(&self.partial_dtype, num_groups); + for partial in &self.partials { + states.append_scalar(&self.vtable.to_scalar(partial)?)?; + } + self.partials.clear(); + + Ok(states.finish()) + } + + fn finish(&mut self, num_groups: usize) -> VortexResult { + let states = self.flush_partials(num_groups)?; + let results = self.vtable.finalize(states)?; + + vortex_ensure!( + results.dtype() == &self.return_dtype, + "Return DType mismatch: expected {}, got {}", + self.return_dtype, + results.dtype() + ); + + Ok(results) } } diff --git a/vortex-array/src/aggregate_fn/fns/count/grouped.rs b/vortex-array/src/aggregate_fn/fns/count/grouped.rs index fb94489dde0..03e2b1b49ae 100644 --- a/vortex-array/src/aggregate_fn/fns/count/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/count/grouped.rs @@ -1,216 +1,22 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors -use vortex_buffer::Buffer; use vortex_error::VortexResult; -use vortex_mask::Mask; -use super::Count; use crate::ArrayRef; use crate::ExecutionCtx; -use crate::IntoArray; -use crate::aggregate_fn::AggregateFnRef; -use crate::aggregate_fn::GroupRanges; -use crate::aggregate_fn::GroupedArray; -use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; -use crate::arrays::PrimitiveArray; -use crate::validity::Validity; -/// Encoding-independent grouped [`Count`] kernel. -#[derive(Debug)] -pub(crate) struct CountGroupedKernel; - -impl DynGroupedAggregateKernel for CountGroupedKernel { - fn grouped_aggregate( - &self, - aggregate_fn: &AggregateFnRef, - groups: &GroupedArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if !aggregate_fn.is::() { - return Ok(None); - } - try_grouped_count(groups, ctx) - } -} - -/// Count each valid group from the element validity mask. -/// -/// The [`Count`] partial dtype is non-nullable `U64`, so a null outer group cannot be represented -/// as a partial state. If any outer group is invalid, this returns `Ok(None)` and lets the caller -/// use the existing fallback behavior. -pub(super) fn try_grouped_count( - groups: &GroupedArray, - ctx: &mut ExecutionCtx, -) -> VortexResult> { - if !groups.all_groups_valid(ctx)? { - return Ok(None); - } - let group_ranges = groups.group_ranges(ctx)?; - - Ok(Some(grouped_count(groups.elements(), &group_ranges, ctx)?)) -} - -/// Count the valid elements of each group described by `group_ranges` (element `(offset, size)` -/// pairs) into a non-nullable `U64` array, one entry per group. -fn grouped_count( - elements: &ArrayRef, - group_ranges: &GroupRanges, +pub(super) fn try_accumulate_grouped( + states: &mut [u64], + batch: &ArrayRef, + group_ids: &[u32], ctx: &mut ExecutionCtx, -) -> VortexResult { - let elem_mask = elements.validity()?.execute_mask(elements.len(), ctx)?; - - let counts: Buffer = if elem_mask.all_true() { - group_ranges.iter().map(|(_, size)| size as u64).collect() - } else { - group_ranges - .iter() - .map(|(offset, size)| valid_count(&elem_mask, offset, size) as u64) - .collect() - }; - - Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) -} - -/// Number of valid elements in the `[offset, offset + size)` range of the element mask. -fn valid_count(elem_mask: &Mask, offset: usize, size: usize) -> usize { - elem_mask.slice(offset..offset + size).true_count() -} - -#[cfg(test)] -mod tests { - #![allow(clippy::cast_possible_truncation)] - - use vortex_buffer::Buffer; - use vortex_buffer::buffer; - use vortex_error::VortexResult; - - use crate::ArrayRef; - use crate::IntoArray; - use crate::LEGACY_SESSION; - use crate::VortexSessionExecute; - use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; - use crate::aggregate_fn::GroupedAccumulator; - use crate::aggregate_fn::fns::count::Count; - use crate::arrays::FixedSizeListArray; - use crate::arrays::ListViewArray; - use crate::arrays::PrimitiveArray; - use crate::arrays::VarBinViewArray; - use crate::assert_arrays_eq; - use crate::dtype::DType; - use crate::dtype::Nullability::NonNullable; - use crate::dtype::Nullability::Nullable; - use crate::dtype::PType; - use crate::validity::Validity; - - /// Run a grouped count through the accumulator. - fn grouped_count_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, elem_dtype.clone())?; - acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; - acc.finish() - } - - /// Reference valid-counts (non-nullable `U64`), one per group. - fn grouped_count_reference( - elements: &ArrayRef, - ranges: &[(usize, usize)], - ) -> VortexResult { - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let counts: Buffer = ranges - .iter() - .map(|&(offset, size)| { - Ok(elements - .slice(offset..offset + size)? - .valid_count(&mut ctx)? as u64) - }) - .collect::>()?; - Ok(PrimitiveArray::new(counts, Validity::NonNullable).into_array()) - } - - fn listview(elements: ArrayRef, ranges: &[(usize, usize)]) -> VortexResult { - let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); - let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); - Ok(ListViewArray::try_new( - elements, - offsets.into_array(), - sizes.into_array(), - Validity::NonNullable, - )? - .into_array()) - } - - #[test] - fn listview_counts_all_valid() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let elem_dtype = DType::Primitive(PType::I32, NonNullable); - let ranges = [(0, 2), (2, 1), (3, 3), (6, 0)]; - - let groups = listview(elements.clone(), &ranges)?; - let actual = grouped_count_actual(&groups, &elem_dtype)?; - let expected = grouped_count_reference(&elements, &ranges)?; - - let direct = - PrimitiveArray::new(buffer![2u64, 1, 3, 0], Validity::NonNullable).into_array(); - assert_arrays_eq!(&actual, &direct); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } - - #[test] - fn listview_counts_with_nulls() -> VortexResult<()> { - let elements = - PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) - .into_array(); - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let ranges = [(0, 3), (3, 2), (5, 1)]; - - let groups = listview(elements.clone(), &ranges)?; - let actual = grouped_count_actual(&groups, &elem_dtype)?; - let expected = grouped_count_reference(&elements, &ranges)?; - - // Group 0: {1, null, 3} -> 2. Group 1: {null, null} -> 0. Group 2: {9} -> 1. - let direct = PrimitiveArray::new(buffer![2u64, 0, 1], Validity::NonNullable).into_array(); - assert_arrays_eq!(&actual, &direct); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } - - #[test] - fn listview_counts_varbinview_with_nulls() -> VortexResult<()> { - let elements = VarBinViewArray::from_iter_nullable_str([ - Some("a"), - None, - Some("bbb"), - None, - Some("cc"), - ]) - .into_array(); - let elem_dtype = elements.dtype().clone(); - let ranges = [(0, 2), (2, 2), (4, 1)]; - - let groups = listview(elements.clone(), &ranges)?; - let actual = grouped_count_actual(&groups, &elem_dtype)?; - let expected = grouped_count_reference(&elements, &ranges)?; - - let direct = PrimitiveArray::new(buffer![1u64, 1, 1], Validity::NonNullable).into_array(); - assert_arrays_eq!(&actual, &direct); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } - - #[test] - fn fixed_size_counts_with_nulls() -> VortexResult<()> { - let elements = - PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4)]).into_array(); - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let groups = - FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?.into_array(); - - let actual = grouped_count_actual(&groups, &elem_dtype)?; - let direct = PrimitiveArray::new(buffer![1u64, 2], Validity::NonNullable).into_array(); - assert_arrays_eq!(&actual, &direct); - Ok(()) +) -> VortexResult { + let validity = batch.validity()?.execute_mask(batch.len(), ctx)?; + for (&group_id, valid) in group_ids.iter().zip(validity.iter()) { + if valid { + states[group_id as usize] += 1; + } } + Ok(true) } diff --git a/vortex-array/src/aggregate_fn/fns/count/mod.rs b/vortex-array/src/aggregate_fn/fns/count/mod.rs index 1fe984fb099..e53a378b5a9 100644 --- a/vortex-array/src/aggregate_fn/fns/count/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/count/mod.rs @@ -2,16 +2,17 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors mod grouped; -pub(crate) use grouped::CountGroupedKernel; use vortex_error::VortexExpect; use vortex_error::VortexResult; use crate::ArrayRef; use crate::Columnar; use crate::ExecutionCtx; +use crate::IntoArray; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -65,6 +66,16 @@ impl AggregateFnVTable for Count { Ok(Scalar::primitive(*partial, Nullability::NonNullable)) } + fn partials_to_array( + &self, + partials: &[Self::Partial], + _partial_dtype: &DType, + ) -> VortexResult> { + Ok(Some( + PrimitiveArray::from_iter(partials.iter().copied()).into_array(), + )) + } + fn reset(&self, partial: &mut Self::Partial) { *partial = 0; } @@ -84,6 +95,16 @@ impl AggregateFnVTable for Count { Ok(true) } + fn try_accumulate_grouped( + &self, + states: &mut [Self::Partial], + batch: &ArrayRef, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + grouped::try_accumulate_grouped(states, batch, group_ids, ctx) + } + fn accumulate( &self, _partial: &mut Self::Partial, @@ -116,11 +137,15 @@ mod tests { use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; + use crate::aggregate_fn::DynGroupedAccumulator; use crate::aggregate_fn::EmptyOptions; + use crate::aggregate_fn::GroupedAccumulator; use crate::aggregate_fn::fns::count::Count; use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; + use crate::arrays::VarBinViewArray; + use crate::assert_arrays_eq; use crate::dtype::DType; use crate::dtype::Nullability; use crate::dtype::PType; @@ -227,6 +252,89 @@ mod tests { Ok(()) } + fn run_grouped_count( + values: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + values, + group_ids, + num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + acc.finish(num_groups) + } + + #[test] + fn grouped_count_dense_ids() -> VortexResult<()> { + let values = + PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), Some(4), None, Some(6)]) + .into_array(); + let actual = run_grouped_count(&values, &[0, 0, 1, 1, 2, 2], 3)?; + + let expected = PrimitiveArray::from_iter([1u64, 2, 1]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_omitted_group() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); + let actual = run_grouped_count(&values, &[0, 0, 1, 2, 2, 2], 4)?; + + let expected = PrimitiveArray::from_iter([2u64, 1, 3, 0]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_varbinview_with_nulls() -> VortexResult<()> { + let values = VarBinViewArray::from_iter_nullable_str([ + Some("a"), + None, + Some("bbb"), + None, + Some("cc"), + ]) + .into_array(); + let actual = run_grouped_count(&values, &[0, 0, 1, 1, 2], 3)?; + + let expected = PrimitiveArray::from_iter([1u64, 1, 1]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + + #[test] + fn grouped_count_rejects_out_of_range_group_id() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2], Validity::NonNullable).into_array(); + let mut acc = GroupedAccumulator::try_new(Count, EmptyOptions, values.dtype().clone())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + assert!(acc.accumulate(&values, &[0, 2], 2, &mut ctx).is_err()); + Ok(()) + } + + #[test] + fn grouped_count_accumulate_partials_and_merge_group() -> VortexResult<()> { + let dtype = DType::Primitive(PType::I32, Nullability::Nullable); + let partials = PrimitiveArray::from_iter([2u64, 3, 5]).into_array(); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let mut left = GroupedAccumulator::try_new(Count, EmptyOptions, dtype.clone())?; + left.accumulate_partials(&partials, &[0, 1, 1], 2, &mut ctx)?; + + let mut right = GroupedAccumulator::try_new(Count, EmptyOptions, dtype)?; + right.merge_group(0, &left, 1)?; + + let actual = right.finish(1)?; + let expected = PrimitiveArray::from_iter([8u64]).into_array(); + assert_arrays_eq!(&actual, &expected); + Ok(()) + } + #[test] fn count_constant_non_null() -> VortexResult<()> { let array = ConstantArray::new(42i32, 10); diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 6f00cce7fdb..81304f1eb9f 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -1,367 +1,297 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::AsPrimitive; +use num_traits::ToPrimitive; +use vortex_error::VortexExpect; use vortex_error::VortexResult; +use vortex_error::vortex_panic; use vortex_mask::AllOr; use vortex_mask::Mask; -use super::Sum; +use super::SumPartial; +use super::SumState; +use super::checked_add_i64; +use super::checked_add_u64; use super::primitive::sum_float_all; use super::primitive::sum_signed_all; use super::primitive::sum_unsigned_all; -use crate::ArrayRef; use crate::ExecutionCtx; -use crate::IntoArray; -use crate::aggregate_fn::AggregateFnRef; -use crate::aggregate_fn::GroupRanges; -use crate::aggregate_fn::GroupedArray; -use crate::aggregate_fn::kernels::DynGroupedAggregateKernel; -use crate::arrays::Primitive; +use crate::arrays::BoolArray; use crate::arrays::PrimitiveArray; +use crate::arrays::bool::BoolArrayExt; use crate::dtype::NativePType; use crate::match_each_native_ptype; -/// Encoding-specific grouped [`Sum`] kernel for primitive element arrays. -#[derive(Debug)] -pub(crate) struct PrimitiveGroupedSumEncodingKernel; +const MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS: usize = 4; -impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { - fn grouped_aggregate( - &self, - aggregate_fn: &AggregateFnRef, - groups: &GroupedArray, - ctx: &mut ExecutionCtx, - ) -> VortexResult> { - if !aggregate_fn.is::() { - return Ok(None); +fn for_each_valid_idx(validity: &Mask, len: usize, mut f: impl FnMut(usize)) { + match validity.indices() { + AllOr::All => { + for idx in 0..len { + f(idx); + } + } + AllOr::None => {} + AllOr::Some(indices) => { + for &idx in indices { + f(idx); + } } - try_grouped_sum(groups, ctx) - } -} - -/// Grouped [`Sum`] implementation for canonical primitive elements. -/// -/// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ -/// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to -/// a null sum, NaNs are skipped). The element validity mask is materialized once and sliced per -/// group, rather than the per-group accumulator setup of the generic fallback path. -pub(super) fn try_grouped_sum( - groups: &GroupedArray, - ctx: &mut ExecutionCtx, -) -> VortexResult> { - if !groups.elements().is::() { - return Ok(None); } - let elements = groups.elements().clone().downcast::(); - let group_ranges = groups.group_ranges(ctx)?; - let group_validity = groups.group_validity(ctx)?; - - Ok(Some(grouped_sum( - &elements, - &group_ranges, - &group_validity, - ctx, - )?)) } -/// Sum each group described by `group_ranges` (element `(offset, size)` pairs), one sum per group. -fn grouped_sum( - elements: &PrimitiveArray, - group_ranges: &GroupRanges, - group_validity: &Mask, - ctx: &mut ExecutionCtx, -) -> VortexResult { - let elem_mask = elements - .as_ref() - .validity()? - .execute_mask(elements.as_ref().len(), ctx)?; - let all_valid = matches!(elem_mask.slices(), AllOr::All); - - let result = match_each_native_ptype!(elements.ptype(), - unsigned: |T| { - let values = elements.as_slice::(); - collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, - sum_unsigned_all) - }, - signed: |T| { - let values = elements.as_slice::(); - collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, - sum_signed_all) - }, - floating: |T| { - let values = elements.as_slice::(); - collect_sums::(values, group_ranges, group_validity, &elem_mask, all_valid, - |acc, slice| { sum_float_all(acc, slice); false }) +fn should_accumulate_group_runs(group_ids: &[u32]) -> bool { + let Some((&first, rest)) = group_ids.split_first() else { + return false; + }; + + let mut run_count = 1usize; + let mut group_id = first; + for &next_group_id in rest { + if next_group_id != group_id { + run_count += 1; + group_id = next_group_id; } - ); + } - Ok(result.into_array()) + run_count * MIN_AVG_RUN_LENGTH_FOR_GROUPED_SUM_RUNS <= group_ids.len() } -/// Reduce each group's element slice into a nullable sum. A group is null when the group -/// itself is invalid, or when summing it overflows (`sum_run` returns `true`). -fn collect_sums( - values: &[T], - group_ranges: &GroupRanges, - group_validity: &Mask, - elem_mask: &Mask, - all_valid: bool, - sum_run: impl Fn(&mut A, &[T]) -> bool, -) -> PrimitiveArray { - let sums = group_ranges.iter().enumerate().map(|(i, (offset, size))| { - if !group_validity.value(i) { - return None; +fn for_each_group_run(group_ids: &[u32], mut f: impl FnMut(u32, usize, usize)) { + let Some((&first, rest)) = group_ids.split_first() else { + return; + }; + + let mut group_id = first; + let mut start = 0usize; + for (idx, &next_group_id) in rest.iter().enumerate() { + let idx = idx + 1; + if next_group_id != group_id { + f(group_id, start, idx); + group_id = next_group_id; + start = idx; } - let mut acc = A::default(); - let overflow = if all_valid { - sum_run(&mut acc, &values[offset..offset + size]) - } else { - sum_masked_group(&mut acc, values, offset, size, elem_mask, &sum_run) - }; - (!overflow).then_some(acc) - }); - PrimitiveArray::from_option_iter(sums) + } + f(group_id, start, group_ids.len()); } -/// Sum the valid elements of a single group, using the contiguous valid runs of the element mask -/// intersected with the group's `[offset, offset + size)` range. -fn sum_masked_group( - acc: &mut A, - values: &[T], - offset: usize, - size: usize, - elem_mask: &Mask, - sum_run: &impl Fn(&mut A, &[T]) -> bool, -) -> bool { - match elem_mask.slice(offset..offset + size).slices() { - AllOr::All => sum_run(acc, &values[offset..offset + size]), - AllOr::None => false, - AllOr::Some(runs) => { - for &(start, end) in runs { - if sum_run(acc, &values[offset + start..offset + end]) { - return true; - } - } - false - } +fn accumulate_grouped_unsigned(partials: &mut [SumPartial], group_id: u32, value: u64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Unsigned(acc)) => checked_add_u64(acc, value), + Some(_) => vortex_panic!("unsigned sum state with non-unsigned input"), + }; + if saturated { + partial.current = None; } } -#[cfg(test)] -mod tests { - #![allow(clippy::cast_possible_truncation)] - - use vortex_buffer::buffer; - use vortex_error::VortexResult; - - use crate::ArrayRef; - use crate::IntoArray; - use crate::LEGACY_SESSION; - use crate::VortexSessionExecute; - use crate::aggregate_fn::DynGroupedAccumulator; - use crate::aggregate_fn::EmptyOptions; - use crate::aggregate_fn::GroupedAccumulator; - use crate::aggregate_fn::fns::sum::Sum; - use crate::aggregate_fn::fns::sum::sum; - use crate::arrays::FixedSizeListArray; - use crate::arrays::ListViewArray; - use crate::arrays::PrimitiveArray; - use crate::assert_arrays_eq; - use crate::builders::builder_with_capacity; - use crate::dtype::DType; - use crate::dtype::Nullability::NonNullable; - use crate::dtype::Nullability::Nullable; - use crate::dtype::PType; - use crate::validity::Validity; - - /// Run a grouped sum through the accumulator. - fn grouped_sum_actual(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; - acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; - acc.finish() +fn accumulate_grouped_unsigned_run(partials: &mut [SumPartial], group_id: u32, values: &[T]) +where + T: NativePType + AsPrimitive, +{ + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Unsigned(acc)) => sum_unsigned_all(acc, values), + Some(_) => vortex_panic!("unsigned sum state with non-unsigned input"), + }; + if saturated { + partial.current = None; } +} - /// Reference sums computed exactly like the generic slow path: per-group scalar [`sum`] for - /// valid groups, a null sum for invalid groups. - fn grouped_sum_reference( - elements: &ArrayRef, - ranges: &[(usize, usize)], - group_valid: &[bool], - elem_dtype: &DType, - ) -> VortexResult { - use crate::aggregate_fn::AggregateFnVTable; - - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let sum_dtype = Sum - .partial_dtype(&EmptyOptions, elem_dtype) - .expect("sum partial dtype"); - let mut builder = builder_with_capacity(&sum_dtype, ranges.len()); - for (i, &(offset, size)) in ranges.iter().enumerate() { - if group_valid[i] { - let slice = elements.slice(offset..offset + size)?; - builder.append_scalar(&sum(&slice, &mut ctx)?)?; - } else { - builder.append_null(); - } - } - Ok(builder.finish()) +fn accumulate_grouped_signed(partials: &mut [SumPartial], group_id: u32, value: i64) { + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Signed(acc)) => checked_add_i64(acc, value), + Some(_) => vortex_panic!("signed sum state with non-signed input"), + }; + if saturated { + partial.current = None; } +} - fn offsets_sizes(ranges: &[(usize, usize)]) -> (ArrayRef, ArrayRef) { - let offsets = PrimitiveArray::from_iter(ranges.iter().map(|&(o, _)| o as i32)); - let sizes = PrimitiveArray::from_iter(ranges.iter().map(|&(_, s)| s as i32)); - (offsets.into_array(), sizes.into_array()) +fn accumulate_grouped_signed_run(partials: &mut [SumPartial], group_id: u32, values: &[T]) +where + T: NativePType + AsPrimitive, +{ + let partial = &mut partials[group_id as usize]; + let saturated = match partial.current.as_mut() { + None => return, + Some(SumState::Signed(acc)) => sum_signed_all(acc, values), + Some(_) => vortex_panic!("signed sum state with non-signed input"), + }; + if saturated { + partial.current = None; } +} - fn listview( - elements: ArrayRef, - ranges: &[(usize, usize)], - group_valid: &[bool], - ) -> VortexResult { - let (offsets, sizes) = offsets_sizes(ranges); - let validity = if group_valid.iter().all(|&v| v) { - Validity::NonNullable - } else { - Validity::from_iter(group_valid.iter().copied()) - }; - Ok(ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array()) +fn accumulate_grouped_float(partials: &mut [SumPartial], group_id: u32, value: f64) { + if value.is_nan() { + return; } - #[test] - fn listview_matches_reference_unsigned() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![1u32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let elem_dtype = DType::Primitive(PType::U32, NonNullable); - let ranges = [(0, 2), (2, 1), (3, 3)]; - let valid = [true, true, true]; - - let groups = listview(elements.clone(), &ranges, &valid)?; - let actual = grouped_sum_actual(&groups, &elem_dtype)?; - let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; - - // Unsigned input sums to U64. - let direct = PrimitiveArray::from_option_iter([Some(3u64), Some(3u64), Some(15u64)]); - assert_arrays_eq!(&actual, &direct.into_array()); - assert_arrays_eq!(&actual, &expected); - Ok(()) + match partials[group_id as usize].current.as_mut() { + None => {} + Some(SumState::Float(acc)) => *acc += value, + Some(_) => vortex_panic!("float sum state with non-float input"), } +} - #[test] - fn listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { - // Offsets are not in group order and a group is null: the group validity must be indexed by - // group index, not by element offset. - let elements = - PrimitiveArray::new(buffer![10i32, 20, 30, 40, 50, 60], Validity::NonNullable) - .into_array(); - let elem_dtype = DType::Primitive(PType::I32, NonNullable); - let ranges = [(4, 2), (0, 2), (2, 2)]; - let valid = [true, false, true]; - - let groups = listview(elements.clone(), &ranges, &valid)?; - let actual = grouped_sum_actual(&groups, &elem_dtype)?; - let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; - - let direct = PrimitiveArray::from_option_iter([Some(110i64), None, Some(70i64)]); - assert_arrays_eq!(&actual, &direct.into_array()); - assert_arrays_eq!(&actual, &expected); - Ok(()) +fn accumulate_grouped_float_run( + partials: &mut [SumPartial], + group_id: u32, + values: &[T], +) { + match partials[group_id as usize].current.as_mut() { + None => {} + Some(SumState::Float(acc)) => sum_float_all(acc, values), + Some(_) => vortex_panic!("float sum state with non-float input"), } +} - #[test] - fn listview_interior_and_full_nulls() -> VortexResult<()> { - // Group 1 has an interior null, group 2 is entirely null, group 3 is empty. - let elements = - PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, None, Some(9)]) - .into_array(); - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let ranges = [(0, 3), (3, 2), (5, 0), (5, 1)]; - let valid = [true, true, true, true]; - - let groups = listview(elements.clone(), &ranges, &valid)?; - let actual = grouped_sum_actual(&groups, &elem_dtype)?; - let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; - - let direct = - PrimitiveArray::from_option_iter([Some(4i64), Some(0i64), Some(0i64), Some(9i64)]); - assert_arrays_eq!(&actual, &direct.into_array()); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } +pub(super) fn accumulate_grouped_primitive( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = primitive + .as_ref() + .validity()? + .execute_mask(primitive.as_ref().len(), ctx)?; + let use_runs = + matches!(validity.slices(), AllOr::All) && should_accumulate_group_runs(group_ids); - #[test] - fn listview_overflow_group_is_null() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); - let elem_dtype = DType::Primitive(PType::I64, NonNullable); - let ranges = [(0, 2), (2, 2)]; - let valid = [true, true]; + match_each_native_ptype!(primitive.ptype(), + unsigned: |T| { + if use_runs { + accumulate_grouped_primitive_unsigned_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_unsigned::(partials, primitive, group_ids, &validity); + } + Ok(()) + }, + signed: |T| { + if use_runs { + accumulate_grouped_primitive_signed_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_signed::(partials, primitive, group_ids, &validity); + } + Ok(()) + }, + floating: |T| { + if use_runs { + accumulate_grouped_primitive_float_runs::(partials, primitive, group_ids); + } else { + accumulate_grouped_primitive_float::(partials, primitive, group_ids, &validity); + } + Ok(()) + } + ) +} - let groups = listview(elements.clone(), &ranges, &valid)?; - let actual = grouped_sum_actual(&groups, &elem_dtype)?; - let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; +fn accumulate_grouped_primitive_unsigned( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_unsigned(partials, group_ids[idx], values[idx].as_()); + }); +} - // First group overflows -> null sum; second group sums normally. - let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); - assert_arrays_eq!(&actual, &direct.into_array()); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } +fn accumulate_grouped_primitive_unsigned_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_unsigned_run(partials, group_id, &values[start..end]); + }); +} - #[test] - fn listview_float_nan_and_inf() -> VortexResult<()> { - let elements = PrimitiveArray::new( - buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0], - Validity::NonNullable, - ) - .into_array(); - let elem_dtype = DType::Primitive(PType::F64, NonNullable); - let ranges = [(0, 3), (3, 3)]; - let valid = [true, true]; +fn accumulate_grouped_primitive_signed( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + accumulate_grouped_signed(partials, group_ids[idx], values[idx].as_()); + }); +} - let groups = listview(elements.clone(), &ranges, &valid)?; - let actual = grouped_sum_actual(&groups, &elem_dtype)?; +fn accumulate_grouped_primitive_signed_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType + AsPrimitive, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_signed_run(partials, group_id, &values[start..end]); + }); +} - // Group 0: NaN skipped -> 3.0. Group 1: INF + -INF = NaN. (Avoid array equality here since - // NaN != NaN; compare element scalars against the reference path instead.) - let mut ctx = LEGACY_SESSION.create_execution_ctx(); - let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; - let g0 = actual.execute_scalar(0, &mut ctx)?; - assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); - assert_eq!( - g0.as_primitive().typed_value::(), - expected - .execute_scalar(0, &mut ctx)? - .as_primitive() - .typed_value::() - ); - let g1 = actual.execute_scalar(1, &mut ctx)?; - assert!(g1.as_primitive().typed_value::().unwrap().is_nan()); - assert!( - expected - .execute_scalar(1, &mut ctx)? - .as_primitive() - .typed_value::() - .unwrap() - .is_nan() - ); - Ok(()) - } +fn accumulate_grouped_primitive_float( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], + validity: &Mask, +) where + T: NativePType + ToPrimitive, +{ + let values = primitive.as_slice::(); + for_each_valid_idx(validity, values.len(), |idx| { + let value = values[idx].to_f64().vortex_expect("float to f64"); + accumulate_grouped_float(partials, group_ids[idx], value); + }); +} - #[test] - fn fixed_size_overflow_and_nan() -> VortexResult<()> { - // FixedSize path: first group overflows -> null sum, second sums normally. - let elements = - PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); - let elem_dtype = DType::Primitive(PType::I64, NonNullable); - let groups = FixedSizeListArray::try_new(elements.clone(), 2, Validity::NonNullable, 2)? - .into_array(); +fn accumulate_grouped_primitive_float_runs( + partials: &mut [SumPartial], + primitive: &PrimitiveArray, + group_ids: &[u32], +) where + T: NativePType, +{ + let values = primitive.as_slice::(); + for_each_group_run(group_ids, |group_id, start, end| { + accumulate_grouped_float_run(partials, group_id, &values[start..end]); + }); +} - let actual = grouped_sum_actual(&groups, &elem_dtype)?; - let expected = - grouped_sum_reference(&elements, &[(0, 2), (2, 2)], &[true, true], &elem_dtype)?; - let direct = PrimitiveArray::from_option_iter([None, Some(5i64)]); - assert_arrays_eq!(&actual, &direct.into_array()); - assert_arrays_eq!(&actual, &expected); - Ok(()) - } +pub(super) fn accumulate_grouped_bool( + partials: &mut [SumPartial], + bools: &BoolArray, + group_ids: &[u32], + ctx: &mut ExecutionCtx, +) -> VortexResult<()> { + let validity = bools + .as_ref() + .validity()? + .execute_mask(bools.as_ref().len(), ctx)?; + let values = bools.to_bit_buffer(); + for_each_valid_idx(&validity, values.len(), |idx| { + if values.value(idx) { + accumulate_grouped_unsigned(partials, group_ids[idx], 1); + } + }); + Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 9d525bec742..eff487d55e8 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -6,7 +6,8 @@ mod constant; mod decimal; mod grouped; mod primitive; -pub(crate) use grouped::PrimitiveGroupedSumEncodingKernel; + +use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; use vortex_error::vortex_bail; @@ -21,11 +22,13 @@ use crate::ArrayRef; use crate::Canonical; use crate::Columnar; use crate::ExecutionCtx; +use crate::IntoArray; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; use crate::aggregate_fn::DynAccumulator; use crate::aggregate_fn::EmptyOptions; +use crate::arrays::PrimitiveArray; use crate::dtype::DType; use crate::dtype::DecimalDType; use crate::dtype::MAX_PRECISION; @@ -36,6 +39,7 @@ use crate::expr::stats::Stat; use crate::expr::stats::StatsProvider; use crate::scalar::DecimalValue; use crate::scalar::Scalar; +use crate::validity::Validity; /// Return the sum of an array. /// @@ -201,6 +205,29 @@ impl AggregateFnVTable for Sum { }) } + fn partials_to_array( + &self, + partials: &[Self::Partial], + partial_dtype: &DType, + ) -> VortexResult> { + Ok(match partial_dtype { + DType::Primitive(PType::U64, _) => Some(sum_primitive_partials_to_array( + partials, + unsigned_sum_state_value, + )), + DType::Primitive(PType::I64, _) => Some(sum_primitive_partials_to_array( + partials, + signed_sum_state_value, + )), + DType::Primitive(PType::F64, _) => Some(sum_primitive_partials_to_array( + partials, + float_sum_state_value, + )), + DType::Decimal(..) => None, + _ => vortex_bail!("Unsupported sum partial dtype: {}", partial_dtype), + }) + } + fn reset(&self, partial: &mut Self::Partial) { partial.current = Some(make_zero_state(&partial.return_dtype)); } @@ -254,6 +281,30 @@ impl AggregateFnVTable for Sum { Ok(()) } + fn accumulate_grouped( + &self, + partials: &mut [Self::Partial], + batch: &Columnar, + group_ids: &[u32], + ctx: &mut ExecutionCtx, + ) -> VortexResult { + match batch { + Columnar::Canonical(Canonical::Primitive(p)) => { + grouped::accumulate_grouped_primitive(partials, p, group_ids, ctx)?; + Ok(true) + } + Columnar::Canonical(Canonical::Bool(b)) => { + grouped::accumulate_grouped_bool(partials, b, group_ids, ctx)?; + Ok(true) + } + // Decimal and constants still use the universal grouped fallback. + Columnar::Canonical(Canonical::Decimal(_)) | Columnar::Constant(_) => Ok(false), + Columnar::Canonical(_) => { + vortex_bail!("Unsupported canonical type for sum: {}", batch.dtype()) + } + } + } + fn finalize(&self, partials: ArrayRef) -> VortexResult { Ok(partials) } @@ -285,6 +336,54 @@ pub enum SumState { }, } +fn sum_primitive_partials_to_array( + partials: &[SumPartial], + value_from_state: fn(&SumState) -> T, +) -> ArrayRef +where + T: crate::dtype::NativePType, +{ + if partials.iter().all(|partial| partial.current.is_some()) { + let values = Buffer::from_iter(partials.iter().map(|partial| { + value_from_state( + partial + .current + .as_ref() + .vortex_expect("checked non-null partial"), + ) + })); + return PrimitiveArray::new(values, Validity::AllValid).into_array(); + } + + PrimitiveArray::from_option_iter( + partials + .iter() + .map(|partial| partial.current.as_ref().map(value_from_state)), + ) + .into_array() +} + +fn unsigned_sum_state_value(state: &SumState) -> u64 { + match state { + SumState::Unsigned(v) => *v, + _ => vortex_panic!("unsigned sum state with non-unsigned partial dtype"), + } +} + +fn signed_sum_state_value(state: &SumState) -> i64 { + match state { + SumState::Signed(v) => *v, + _ => vortex_panic!("signed sum state with non-signed partial dtype"), + } +} + +fn float_sum_state_value(state: &SumState) -> f64 { + match state { + SumState::Float(v) => *v, + _ => vortex_panic!("float sum state with non-float partial dtype"), + } +} + fn make_zero_state(return_dtype: &DType) -> SumState { match return_dtype { DType::Primitive(ptype, _) => match ptype { @@ -347,8 +446,6 @@ mod tests { use crate::arrays::ChunkedArray; use crate::arrays::ConstantArray; use crate::arrays::DecimalArray; - use crate::arrays::FixedSizeListArray; - use crate::arrays::ListViewArray; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::dtype::DType; @@ -513,20 +610,26 @@ mod tests { // Grouped sum tests - fn run_grouped_sum(groups: &ArrayRef, elem_dtype: &DType) -> VortexResult { - let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype.clone())?; - acc.accumulate_list(groups, &mut LEGACY_SESSION.create_execution_ctx())?; - acc.finish() + fn run_grouped_sum( + values: &ArrayRef, + group_ids: &[u32], + num_groups: usize, + ) -> VortexResult { + let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, values.dtype().clone())?; + acc.accumulate( + values, + group_ids, + num_groups, + &mut LEGACY_SESSION.create_execution_ctx(), + )?; + acc.finish(num_groups) } #[test] - fn grouped_sum_fixed_size_list() -> VortexResult<()> { - let elements = + fn grouped_sum_dense_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6], Validity::NonNullable).into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(6i64), Some(15i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -535,13 +638,10 @@ mod tests { #[test] fn grouped_sum_with_null_elements() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([Some(1i32), None, Some(3), None, Some(5), Some(6)]) .into_array(); - let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(4i64), Some(11i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -549,30 +649,22 @@ mod tests { } #[test] - fn grouped_sum_with_null_group() -> VortexResult<()> { - let elements = - PrimitiveArray::new(buffer![1i32, 2, 3, 4, 5, 6, 7, 8, 9], Validity::NonNullable) - .into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = FixedSizeListArray::try_new(elements, 3, validity, 3)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + fn grouped_sum_empty_group() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![1i32, 2, 3, 7, 8, 9], Validity::NonNullable).into_array(); + let result = run_grouped_sum(&values, &[0, 0, 0, 2, 2, 2], 3)?; let expected = - PrimitiveArray::from_option_iter([Some(6i64), None, Some(24i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(6i64), Some(0i64), Some(24i64)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } #[test] fn grouped_sum_all_null_elements_in_group() -> VortexResult<()> { - let elements = + let values = PrimitiveArray::from_option_iter([None::, None, Some(3), Some(4)]).into_array(); - let groups = FixedSizeListArray::try_new(elements, 2, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Primitive(PType::I32, Nullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let result = run_grouped_sum(&values, &[0, 0, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(0i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -581,12 +673,8 @@ mod tests { #[test] fn grouped_sum_bool() -> VortexResult<()> { - let elements: BoolArray = [true, false, true, true, true, true].into_iter().collect(); - let groups = - FixedSizeListArray::try_new(elements.into_array(), 3, Validity::NonNullable, 2)?; - - let elem_dtype = DType::Bool(Nullability::NonNullable); - let result = run_grouped_sum(&groups.into_array(), &elem_dtype)?; + let values: BoolArray = [true, false, true, true, true, true].into_iter().collect(); + let result = run_grouped_sum(&values.into_array(), &[0, 0, 0, 1, 1, 1], 2)?; let expected = PrimitiveArray::from_option_iter([Some(2u64), Some(3u64)]).into_array(); assert_arrays_eq!(&result, &expected); @@ -599,19 +687,17 @@ mod tests { let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); let mut acc = GroupedAccumulator::try_new(Sum, EmptyOptions, elem_dtype)?; - let elements1 = + let values1 = PrimitiveArray::new(buffer![1i32, 2, 3, 4], Validity::NonNullable).into_array(); - let groups1 = FixedSizeListArray::try_new(elements1, 2, Validity::NonNullable, 2)?; - acc.accumulate_list(&groups1.into_array(), &mut ctx)?; - let result1 = acc.finish()?; + acc.accumulate(&values1, &[0, 0, 1, 1], 2, &mut ctx)?; + let result1 = acc.finish(2)?; let expected1 = PrimitiveArray::from_option_iter([Some(3i64), Some(7i64)]).into_array(); assert_arrays_eq!(&result1, &expected1); - let elements2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); - let groups2 = FixedSizeListArray::try_new(elements2, 2, Validity::NonNullable, 1)?; - acc.accumulate_list(&groups2.into_array(), &mut ctx)?; - let result2 = acc.finish()?; + let values2 = PrimitiveArray::new(buffer![10i32, 20], Validity::NonNullable).into_array(); + acc.accumulate(&values2, &[0, 0], 1, &mut ctx)?; + let result2 = acc.finish(1)?; let expected2 = PrimitiveArray::from_option_iter([Some(30i64)]).into_array(); assert_arrays_eq!(&result2, &expected2); @@ -619,24 +705,64 @@ mod tests { } #[test] - fn grouped_sum_listview_out_of_order_offsets_with_null_group() -> VortexResult<()> { - let elements = + fn grouped_sum_out_of_order_group_ids() -> VortexResult<()> { + let values = PrimitiveArray::new(buffer![100i32, 200, 300], Validity::NonNullable).into_array(); - let offsets = PrimitiveArray::new(buffer![2i32, 0, 1], Validity::NonNullable).into_array(); - let sizes = PrimitiveArray::new(buffer![1i32, 1, 1], Validity::NonNullable).into_array(); - let validity = Validity::from_iter([true, false, true]); - let groups = ListViewArray::try_new(elements, offsets, sizes, validity)?.into_array(); + let result = run_grouped_sum(&values, &[2, 0, 1], 3)?; - let elem_dtype = DType::Primitive(PType::I32, Nullability::NonNullable); - let result = run_grouped_sum(&groups, &elem_dtype)?; - - // group 0 -> elements[2..3] = 300; group 1 -> null; group 2 -> elements[1..2] = 200. let expected = - PrimitiveArray::from_option_iter([Some(300i64), None, Some(200i64)]).into_array(); + PrimitiveArray::from_option_iter([Some(200i64), Some(300), Some(100)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_contiguous_group_runs() -> VortexResult<()> { + let values = PrimitiveArray::new( + buffer![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + Validity::NonNullable, + ) + .into_array(); + let result = run_grouped_sum(&values, &[0, 0, 0, 0, 1, 1, 1, 1], 2)?; + + let expected = PrimitiveArray::from_option_iter([Some(10.0f64), Some(26.0)]).into_array(); + assert_arrays_eq!(&result, &expected); + Ok(()) + } + + #[test] + fn grouped_sum_overflow_group_is_null() -> VortexResult<()> { + let values = + PrimitiveArray::new(buffer![i64::MAX, 1, 2, 3], Validity::NonNullable).into_array(); + let result = run_grouped_sum(&values, &[0, 0, 1, 1], 2)?; + + let expected = PrimitiveArray::from_option_iter([None, Some(5i64)]).into_array(); assert_arrays_eq!(&result, &expected); Ok(()) } + #[test] + fn grouped_sum_float_nan_and_inf() -> VortexResult<()> { + let values = PrimitiveArray::new( + buffer![1.0f64, f64::NAN, 2.0, f64::INFINITY, f64::NEG_INFINITY, 4.0], + Validity::NonNullable, + ) + .into_array(); + let actual = run_grouped_sum(&values, &[0, 0, 0, 1, 1, 1], 2)?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + + let g0 = actual.execute_scalar(0, &mut ctx)?; + assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); + + let g1 = actual.execute_scalar(1, &mut ctx)?; + let g1_value = g1 + .as_primitive() + .typed_value::() + .vortex_expect("group sum should be non-null"); + assert!(g1_value.is_nan()); + Ok(()) + } + // Chunked array tests #[test] diff --git a/vortex-array/src/aggregate_fn/kernels.rs b/vortex-array/src/aggregate_fn/kernels.rs index c5af0902cbb..e0b1d42e41e 100644 --- a/vortex-array/src/aggregate_fn/kernels.rs +++ b/vortex-array/src/aggregate_fn/kernels.rs @@ -6,12 +6,12 @@ use std::fmt::Debug; +use vortex_buffer::Buffer; use vortex_error::VortexResult; use crate::ArrayRef; use crate::ExecutionCtx; use crate::aggregate_fn::AggregateFnRef; -use crate::aggregate_fn::GroupedArray; use crate::scalar::Scalar; /// A pluggable kernel for an aggregate function. @@ -27,26 +27,53 @@ pub trait DynAggregateKernel: 'static + Send + Sync + Debug { ) -> VortexResult>; } +/// Partial grouped aggregate output produced by an encoding-specific grouped kernel. +/// +/// `group_ids` is parallel to `partials`: each row in `partials` is a partial state for the +/// corresponding dense group ordinal. The ids may repeat, omit, and reorder groups, but must be +/// valid slots in the accumulator's `0..num_groups` range. The grouped accumulator merges this +/// batch through `accumulate_partials`. +#[derive(Clone, Debug)] +pub struct GroupedAggregateKernelResult { + group_ids: Buffer, + partials: ArrayRef, +} + +impl GroupedAggregateKernelResult { + pub fn new(group_ids: Buffer, partials: ArrayRef) -> Self { + Self { + group_ids, + partials, + } + } + + pub fn group_ids(&self) -> &[u32] { + self.group_ids.as_ref() + } + + pub fn partials(&self) -> &ArrayRef { + &self.partials + } +} + /// A pluggable kernel for batch aggregation of many groups. /// -/// A kernel can be registered either for an aggregate function regardless of the element encoding, -/// or for a specific aggregate function and element encoding. Element-encoding kernels are matched -/// on the inner array of the provided grouped array, not on the outer list encoding. This is more -/// pragmatic than having every kernel match on the outer list encoding and having to deal with the -/// possibility of multiple list encodings. +/// A grouped kernel can be registered for an aggregate function regardless of input encoding, or +/// for a specific aggregate function and array encoding. Encoding-specific kernels are matched on +/// the values array, not on a pre-grouped list wrapper. /// -/// Each value in the grouped array represents a group and the result of the grouped aggregate -/// should be an array of the same length, where each element is the aggregate state of the -/// corresponding group. +/// Kernels receive the same dense group ordinals that the caller passed to the grouped accumulator +/// and may aggregate directly in the encoded domain. /// /// Return `Ok(None)` if the kernel cannot be applied to the given aggregate function. pub trait DynGroupedAggregateKernel: 'static + Send + Sync + Debug { - /// Aggregate each group in the provided grouped array and return an array of the aggregate - /// states. + /// Aggregate values into a partial-state batch keyed by dense group ordinal. fn grouped_aggregate( &self, aggregate_fn: &AggregateFnRef, - groups: &GroupedArray, + batch: &ArrayRef, + group_ids: &[u32], + num_groups: usize, ctx: &mut ExecutionCtx, - ) -> VortexResult>; + ) -> VortexResult>; } diff --git a/vortex-array/src/aggregate_fn/session.rs b/vortex-array/src/aggregate_fn/session.rs index c6d7542a687..78b139bf36f 100644 --- a/vortex-array/src/aggregate_fn/session.rs +++ b/vortex-array/src/aggregate_fn/session.rs @@ -18,8 +18,6 @@ use crate::aggregate_fn::fns::all_non_null::AllNonNull; use crate::aggregate_fn::fns::all_null::AllNull; use crate::aggregate_fn::fns::bounded_max::BoundedMax; use crate::aggregate_fn::fns::bounded_min::BoundedMin; -use crate::aggregate_fn::fns::count::Count; -use crate::aggregate_fn::fns::count::CountGroupedKernel; use crate::aggregate_fn::fns::first::First; use crate::aggregate_fn::fns::is_constant::IsConstant; use crate::aggregate_fn::fns::is_sorted::IsSorted; @@ -29,7 +27,6 @@ use crate::aggregate_fn::fns::min::Min; use crate::aggregate_fn::fns::min_max::MinMax; use crate::aggregate_fn::fns::nan_count::NanCount; use crate::aggregate_fn::fns::null_count::NullCount; -use crate::aggregate_fn::fns::sum::PrimitiveGroupedSumEncodingKernel; use crate::aggregate_fn::fns::sum::Sum; use crate::aggregate_fn::fns::uncompressed_size_in_bytes::UncompressedSizeInBytes; use crate::aggregate_fn::kernels::DynAggregateKernel; @@ -39,7 +36,6 @@ use crate::array::ArrayId; use crate::array::VTable; use crate::arrays::Chunked; use crate::arrays::Dict; -use crate::arrays::Primitive; use crate::arrays::chunked::compute::aggregate::ChunkedArrayAggregate; use crate::arrays::dict::compute::is_constant::DictIsConstantKernel; use crate::arrays::dict::compute::is_sorted::DictIsSortedKernel; @@ -108,14 +104,6 @@ impl Default for AggregateFnSession { this.register_aggregate_kernel(Dict.id(), Some(IsConstant.id()), &DictIsConstantKernel); this.register_aggregate_kernel(Dict.id(), Some(IsSorted.id()), &DictIsSortedKernel); - // Register the built-in grouped aggregate kernels. - this.register_grouped_kernel(Count.id(), &CountGroupedKernel); - this.register_grouped_encoding_kernel( - Primitive.id(), - Sum.id(), - &PrimitiveGroupedSumEncodingKernel, - ); - this } } diff --git a/vortex-array/src/aggregate_fn/vtable.rs b/vortex-array/src/aggregate_fn/vtable.rs index 28b91d45166..09eab6c5a9c 100644 --- a/vortex-array/src/aggregate_fn/vtable.rs +++ b/vortex-array/src/aggregate_fn/vtable.rs @@ -115,6 +115,17 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { /// options and input dtype used to construct the state. fn to_scalar(&self, partial: &Self::Partial) -> VortexResult; + /// Try to convert dense partial states directly into a partial-state array. + /// + /// Returning `Ok(None)` falls back to scalarizing each partial with [`Self::to_scalar`]. + fn partials_to_array( + &self, + _partials: &[Self::Partial], + _partial_dtype: &DType, + ) -> VortexResult> { + Ok(None) + } + /// Reset the state of the accumulator to an empty group. fn reset(&self, partial: &mut Self::Partial); @@ -146,6 +157,37 @@ pub trait AggregateFnVTable: 'static + Sized + Clone + Send + Sync { ctx: &mut ExecutionCtx, ) -> VortexResult<()>; + /// Try to accumulate a raw values batch into dense per-group states before decompression. + /// + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. + fn try_accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &ArrayRef, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + + /// Accumulate a canonical values batch into dense per-group states. + /// + /// `group_ids` is parallel to `batch` and contains caller-assigned dense ordinals in + /// `0..states.len()`. Ids may repeat, appear out of order, or be absent from the batch. + /// Returns `true` when the batch was fully handled. The provided default preserves universal + /// correctness through [`crate::aggregate_fn::GroupedAccumulator`]'s fallback. + fn accumulate_grouped( + &self, + _states: &mut [Self::Partial], + _batch: &Columnar, + _group_ids: &[u32], + _ctx: &mut ExecutionCtx, + ) -> VortexResult { + Ok(false) + } + /// Finalize an array of accumulator states into an array of aggregate results. /// /// The provides `states` array has dtype as specified by `state_dtype`, the result array