diff --git a/vortex-array/src/aggregate_fn/fns/mean/mod.rs b/vortex-array/src/aggregate_fn/fns/mean/mod.rs index 17fb2616d44..ccd1192eb44 100644 --- a/vortex-array/src/aggregate_fn/fns/mean/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/mean/mod.rs @@ -6,6 +6,7 @@ use vortex_error::vortex_bail; use crate::ArrayRef; use crate::ExecutionCtx; +use crate::IntoArray; use crate::aggregate_fn::Accumulator; use crate::aggregate_fn::AggregateFnId; use crate::aggregate_fn::AggregateFnVTable; @@ -17,6 +18,7 @@ use crate::aggregate_fn::combined::CombinedOptions; use crate::aggregate_fn::combined::PairOptions; use crate::aggregate_fn::fns::count::Count; use crate::aggregate_fn::fns::sum::Sum; +use crate::arrays::ConstantArray; use crate::builtins::ArrayBuiltins; use crate::dtype::DType; use crate::dtype::Nullability; @@ -88,8 +90,19 @@ impl BinaryCombined for Mean { _ => DType::Primitive(PType::F64, Nullability::Nullable), }; let sum_cast = sum.cast(target.clone())?; - let count_cast = count.cast(target)?; - sum_cast.binary(count_cast, Operator::Div) + let count_cast = count.cast(target.clone())?; + let mean = sum_cast.binary(count_cast.clone(), Operator::Div)?; + // Nulls are skipped during accumulation, so an all-null group has a count of zero and + // the division produces 0/0 = NaN. The mean of an empty group is null (as in SQL), so + // mask out zero-count entries. This matches `finalize_scalar`. + let non_empty = count_cast + .binary( + ConstantArray::new(Scalar::zero_value(&target), count.len()).into_array(), + Operator::NotEq, + )? + // A null count means a null group; keep it masked out. 
+ .fill_null(false)?; + mean.mask(non_empty) } fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult { @@ -104,9 +117,7 @@ impl BinaryCombined for Mean { let sum = sum_cast.as_primitive().typed_value::(); let count = count_cast.as_primitive().typed_value::(); let value = match (sum, count) { - (None, _) | (_, None) => return Ok(Scalar::null(target)), // Sum overflowed - // A count of zero yields 0/0 = NaN, matching the array `finalize` path: nulls are - // skipped during accumulation, so an all-null input is an empty mean, not null. + (None, _) | (_, None) | (_, Some(0.0)) => return Ok(Scalar::null(target)), // Sum overflowed, or zero count (empty/all-null group) (Some(s), Some(c)) => s / c, }; Ok(Scalar::primitive(value, Nullability::Nullable)) @@ -164,12 +175,13 @@ mod tests { use vortex_error::VortexResult; use super::*; - use crate::IntoArray; use crate::LEGACY_SESSION; use crate::VortexSessionExecute; + use crate::aggregate_fn::DynGroupedAccumulator; + use crate::aggregate_fn::GroupedAccumulator; use crate::arrays::BoolArray; use crate::arrays::ChunkedArray; - use crate::arrays::ConstantArray; + use crate::arrays::FixedSizeListArray; use crate::arrays::PrimitiveArray; use crate::validity::Validity; @@ -232,11 +244,79 @@ mod tests { } #[test] - fn mean_all_null_returns_nan() -> VortexResult<()> { + fn mean_all_null_returns_null() -> VortexResult<()> { let array = PrimitiveArray::from_option_iter::([None, None, None]).into_array(); let mut ctx = LEGACY_SESSION.create_execution_ctx(); let result = mean(&array, &mut ctx)?; - assert!(result.as_primitive().as_::().is_some_and(f64::is_nan)); + assert_eq!(result.as_primitive().as_::(), None); + Ok(()) + } + + fn mean_cases() -> Vec<(Vec>, Option)> { + vec![ + (vec![Some(f64::NAN), Some(1.0), None], Some(f64::NAN)), + (vec![Some(f64::NAN), Some(1.0), Some(1.0)], Some(f64::NAN)), + (vec![None, None, Some(f64::NAN)], Some(f64::NAN)), + (vec![Some(f64::NAN), Some(1.0), Some(1.0)], Some(f64::NAN)), + (vec![None, None, None], 
None), + (vec![Some(1.0), Some(2.0), Some(3.0)], Some(2.0)), + ] + } + + fn assert_mean(actual: Option, expected: Option, case: usize) { + match expected { + Some(e) if e.is_nan() => assert!( + actual.is_some_and(f64::is_nan), + "case {case}: expected NaN, got {actual:?}" + ), + _ => assert_eq!(actual, expected, "case {case}"), + } + } + + #[test] + fn mean_via_combined_partials() -> VortexResult<()> { + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + for (case, (group, expected)) in mean_cases().into_iter().enumerate() { + let mut acc = Accumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + DType::Primitive(PType::F64, Nullability::Nullable), + )?; + // Two batches per group so the result goes through partial combination and + // `finalize_scalar`. + let (head, tail) = group.split_at(2); + let head = PrimitiveArray::from_option_iter(head.iter().copied()).into_array(); + let tail = PrimitiveArray::from_option_iter(tail.iter().copied()).into_array(); + acc.accumulate(&head, &mut ctx)?; + acc.accumulate(&tail, &mut ctx)?; + let result = acc.finish()?; + assert_mean(result.as_primitive().as_::(), expected, case); + } + Ok(()) + } + + #[test] + fn mean_via_grouped_finalize() -> VortexResult<()> { + let cases = mean_cases(); + let elements = PrimitiveArray::from_option_iter( + cases.iter().flat_map(|(group, _)| group.iter().copied()), + ) + .into_array(); + let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, cases.len())?; + + let mut acc = GroupedAccumulator::try_new( + Mean::combined(), + PairOptions(EmptyOptions, EmptyOptions), + DType::Primitive(PType::F64, Nullability::Nullable), + )?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + acc.accumulate_list(&groups.into_array(), &mut ctx)?; + let result = acc.finish()?; + + for (case, (_, expected)) in cases.into_iter().enumerate() { + let actual = result.execute_scalar(case, &mut ctx)?; + assert_mean(actual.as_primitive().as_::(), expected, case); 
+ } Ok(()) } diff --git a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs index 6f00cce7fdb..666da929d89 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/grouped.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/grouped.rs @@ -43,7 +43,7 @@ impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel { /// /// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/ /// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to -/// a null sum, NaNs are skipped). The element validity mask is materialized once and sliced per +/// a null sum, NaNs propagate). The element validity mask is materialized once and sliced per /// group, rather than the per-group accumulator setup of the generic fallback path. pub(super) fn try_grouped_sum( groups: &GroupedArray, @@ -321,18 +321,19 @@ mod tests { let groups = listview(elements.clone(), &ranges, &valid)?; let actual = grouped_sum_actual(&groups, &elem_dtype)?; - // Group 0: NaN skipped -> 3.0. Group 1: INF + -INF = NaN. (Avoid array equality here since - // NaN != NaN; compare element scalars against the reference path instead.) + // Group 0: NaN propagates -> NaN. Group 1: INF + -INF = NaN. (Avoid array equality here + // since NaN != NaN; compare element scalars against the reference path instead.) let mut ctx = LEGACY_SESSION.create_execution_ctx(); let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?; let g0 = actual.execute_scalar(0, &mut ctx)?; - assert_eq!(g0.as_primitive().typed_value::(), Some(3.0)); - assert_eq!( - g0.as_primitive().typed_value::(), + assert!(g0.as_primitive().typed_value::().unwrap().is_nan()); + assert!( expected .execute_scalar(0, &mut ctx)? 
.as_primitive() .typed_value::() + .unwrap() + .is_nan() ); let g1 = actual.execute_scalar(1, &mut ctx)?; assert!(g1.as_primitive().typed_value::().unwrap().is_nan()); diff --git a/vortex-array/src/aggregate_fn/fns/sum/mod.rs b/vortex-array/src/aggregate_fn/fns/sum/mod.rs index 9d525bec742..03f490c05c8 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/mod.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/mod.rs @@ -64,6 +64,7 @@ pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { /// /// If the sum overflows, a null scalar will be returned. /// If the array is all-invalid, the sum will be zero. +/// Float NaN values propagate: any NaN in the input makes the sum NaN. #[derive(Clone, Debug)] pub struct Sum; diff --git a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs index df7d929d896..2d8a927daf5 100644 --- a/vortex-array/src/aggregate_fn/fns/sum/primitive.rs +++ b/vortex-array/src/aggregate_fn/fns/sum/primitive.rs @@ -58,13 +58,11 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR } } -/// Sum the non-NaN values of a float slice into an `f64` accumulator. NaNs are skipped to match the -/// scalar `sum` semantics. Floats cannot overflow the accumulator, so this never reports saturation. +/// Sum a float slice into an `f64` accumulator. NaN values propagate into the sum, as in IEEE 754 +/// addition. Floats cannot overflow the accumulator, so this never reports saturation. 
pub(super) fn sum_float_all(acc: &mut f64, slice: &[T]) { for &v in slice { - if !v.is_nan() { - *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); - } + *acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64"); } } @@ -297,7 +295,7 @@ mod tests { ) .into_array(); let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; - assert_eq!(result.as_primitive().typed_value::(), Some(6.0)); + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); Ok(()) } @@ -306,7 +304,7 @@ mod tests { let arr = PrimitiveArray::new(buffer![1.0f32, f32::NAN, 4.0], Validity::NonNullable).into_array(); let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; - assert_eq!(result.as_primitive().typed_value::(), Some(5.0)); + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); Ok(()) } @@ -315,7 +313,7 @@ mod tests { let arr = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(f64::NAN), Some(3.0)]) .into_array(); let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; - assert_eq!(result.as_primitive().typed_value::(), Some(4.0)); + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); Ok(()) } @@ -324,7 +322,7 @@ mod tests { let arr = PrimitiveArray::new(buffer![f64::NAN, f64::NAN], Validity::NonNullable).into_array(); let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?; - assert_eq!(result.as_primitive().typed_value::(), Some(0.0)); + assert!(result.as_primitive().typed_value::().unwrap().is_nan()); Ok(()) }