Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 89 additions & 9 deletions vortex-array/src/aggregate_fn/fns/mean/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use vortex_error::vortex_bail;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::IntoArray;
use crate::aggregate_fn::Accumulator;
use crate::aggregate_fn::AggregateFnId;
use crate::aggregate_fn::AggregateFnVTable;
Expand All @@ -17,6 +18,7 @@ use crate::aggregate_fn::combined::CombinedOptions;
use crate::aggregate_fn::combined::PairOptions;
use crate::aggregate_fn::fns::count::Count;
use crate::aggregate_fn::fns::sum::Sum;
use crate::arrays::ConstantArray;
use crate::builtins::ArrayBuiltins;
use crate::dtype::DType;
use crate::dtype::Nullability;
Expand Down Expand Up @@ -88,8 +90,19 @@ impl BinaryCombined for Mean {
_ => DType::Primitive(PType::F64, Nullability::Nullable),
};
let sum_cast = sum.cast(target.clone())?;
let count_cast = count.cast(target)?;
sum_cast.binary(count_cast, Operator::Div)
let count_cast = count.cast(target.clone())?;
let mean = sum_cast.binary(count_cast.clone(), Operator::Div)?;
// Nulls are skipped during accumulation, so an all-null group has a count of zero and
// the division produces 0/0 = NaN. The mean of an empty group is null (as in SQL), so
// mask out zero-count entries. This matches `finalize_scalar`.
let non_empty = count_cast
.binary(
ConstantArray::new(Scalar::zero_value(&target), count.len()).into_array(),
Operator::NotEq,
)?
// A null count means a null group; keep it masked out.
.fill_null(false)?;
mean.mask(non_empty)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

}

fn finalize_scalar(&self, left_scalar: Scalar, right_scalar: Scalar) -> VortexResult<Scalar> {
Expand All @@ -104,9 +117,7 @@ impl BinaryCombined for Mean {
let sum = sum_cast.as_primitive().typed_value::<f64>();
let count = count_cast.as_primitive().typed_value::<f64>();
let value = match (sum, count) {
(None, _) | (_, None) => return Ok(Scalar::null(target)), // Sum overflowed
// A count of zero yields 0/0 = NaN, matching the array `finalize` path: nulls are
// skipped during accumulation, so an all-null input is an empty mean, not null.
(None, _) | (_, None) | (_, Some(0.0)) => return Ok(Scalar::null(target)), // Sum overflowed
(Some(s), Some(c)) => s / c,
};
Ok(Scalar::primitive(value, Nullability::Nullable))
Expand Down Expand Up @@ -164,12 +175,13 @@ mod tests {
use vortex_error::VortexResult;

use super::*;
use crate::IntoArray;
use crate::LEGACY_SESSION;
use crate::VortexSessionExecute;
use crate::aggregate_fn::DynGroupedAccumulator;
use crate::aggregate_fn::GroupedAccumulator;
use crate::arrays::BoolArray;
use crate::arrays::ChunkedArray;
use crate::arrays::ConstantArray;
use crate::arrays::FixedSizeListArray;
use crate::arrays::PrimitiveArray;
use crate::validity::Validity;

Expand Down Expand Up @@ -232,11 +244,79 @@ mod tests {
}

#[test]
fn mean_all_null_returns_nan() -> VortexResult<()> {
fn mean_all_null_returns_null() -> VortexResult<()> {
let array = PrimitiveArray::from_option_iter::<f64, _>([None, None, None]).into_array();
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let result = mean(&array, &mut ctx)?;
assert!(result.as_primitive().as_::<f64>().is_some_and(f64::is_nan));
assert_eq!(result.as_primitive().as_::<f64>(), None);
Ok(())
}

/// Shared fixtures for the mean tests: each entry pairs an input group with its expected mean,
/// where an expected value of `None` stands for a null result.
///
/// Every group has exactly three elements. The tests that consume these fixtures depend on that:
/// one splits each group after its second element, the other packs all groups into a fixed-size
/// list of width 3.
fn mean_cases() -> Vec<(Vec<Option<f64>>, Option<f64>)> {
    vec![
        // NaN in the first slot, next to a value and a null.
        (vec![Some(f64::NAN), Some(1.0), None], Some(f64::NAN)),
        // NaN in the first slot, followed only by plain values.
        (vec![Some(f64::NAN), Some(1.0), Some(1.0)], Some(f64::NAN)),
        // NaN as the only non-null element, in the last slot.
        (vec![None, None, Some(f64::NAN)], Some(f64::NAN)),
        // NaN in the last slot after two plain values. This replaces an exact duplicate of the
        // second case, which added no coverage.
        (vec![Some(1.0), Some(1.0), Some(f64::NAN)], Some(f64::NAN)),
        // No non-null element at all: the expected mean is null, not NaN.
        (vec![None, None, None], None),
        // Plain values only.
        (vec![Some(1.0), Some(2.0), Some(3.0)], Some(2.0)),
    ]
}

/// Asserts that `actual` matches `expected`, treating NaN as equal to NaN.
///
/// Plain equality cannot be used when a NaN is expected because `NaN != NaN`, so that case only
/// checks that `actual` is present and is itself NaN. `case` is the fixture index and is included
/// in the failure message.
fn assert_mean(actual: Option<f64>, expected: Option<f64>, case: usize) {
    let wants_nan = expected.is_some_and(f64::is_nan);
    if !wants_nan {
        // Covers both an expected null and an expected ordinary number.
        assert_eq!(actual, expected, "case {case}");
        return;
    }
    assert!(
        actual.is_some_and(f64::is_nan),
        "case {case}: expected NaN, got {actual:?}"
    );
}

#[test]
fn mean_via_combined_partials() -> VortexResult<()> {
    // Each fixture is fed to its own accumulator as two separate batches (the first two elements,
    // then the rest), so the result goes through partial combination and `finalize_scalar`.
    let mut ctx = LEGACY_SESSION.create_execution_ctx();
    for (case, (group, expected)) in mean_cases().into_iter().enumerate() {
        let aggregate = Mean::combined();
        let options = PairOptions(EmptyOptions, EmptyOptions);
        let input_dtype = DType::Primitive(PType::F64, Nullability::Nullable);
        let mut acc = Accumulator::try_new(aggregate, options, input_dtype)?;

        let (first, second) = group.split_at(2);
        let batches = [first, second]
            .map(|batch| PrimitiveArray::from_option_iter(batch.iter().copied()).into_array());
        for batch in &batches {
            acc.accumulate(batch, &mut ctx)?;
        }

        let outcome = acc.finish()?;
        assert_mean(outcome.as_primitive().as_::<f64>(), expected, case);
    }
    Ok(())
}

#[test]
fn mean_via_grouped_finalize() -> VortexResult<()> {
    // All fixtures are packed into one fixed-size list (one list entry of width 3 per fixture) and
    // aggregated in a single grouped pass; each per-group result is then checked in order.
    let cases = mean_cases();
    let group_count = cases.len();
    let flattened = cases.iter().flat_map(|(group, _)| group.iter().copied());
    let elements = PrimitiveArray::from_option_iter(flattened).into_array();
    let groups = FixedSizeListArray::try_new(elements, 3, Validity::NonNullable, group_count)?;

    let aggregate = Mean::combined();
    let options = PairOptions(EmptyOptions, EmptyOptions);
    let input_dtype = DType::Primitive(PType::F64, Nullability::Nullable);
    let mut acc = GroupedAccumulator::try_new(aggregate, options, input_dtype)?;
    let mut ctx = LEGACY_SESSION.create_execution_ctx();
    acc.accumulate_list(&groups.into_array(), &mut ctx)?;
    let means = acc.finish()?;

    let expectations = cases.into_iter().map(|(_, expected)| expected);
    for (case, expected) in expectations.enumerate() {
        let actual = means.execute_scalar(case, &mut ctx)?;
        assert_mean(actual.as_primitive().as_::<f64>(), expected, case);
    }
    Ok(())
}

Expand Down
13 changes: 7 additions & 6 deletions vortex-array/src/aggregate_fn/fns/sum/grouped.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ impl DynGroupedAggregateKernel for PrimitiveGroupedSumEncodingKernel {
///
/// Reuses the scalar primitive-sum reductions ([`sum_unsigned_all`]/[`sum_signed_all`]/
/// [`sum_float_all`]) so the per-group semantics match scalar `sum` exactly (overflow saturates to
/// a null sum, NaNs are skipped). The element validity mask is materialized once and sliced per
/// a null sum, NaNs propagate). The element validity mask is materialized once and sliced per
/// group, rather than the per-group accumulator setup of the generic fallback path.
pub(super) fn try_grouped_sum(
groups: &GroupedArray,
Expand Down Expand Up @@ -321,18 +321,19 @@ mod tests {
let groups = listview(elements.clone(), &ranges, &valid)?;
let actual = grouped_sum_actual(&groups, &elem_dtype)?;

// Group 0: NaN skipped -> 3.0. Group 1: INF + -INF = NaN. (Avoid array equality here since
// NaN != NaN; compare element scalars against the reference path instead.)
// Group 0: NaN propagates -> NaN. Group 1: INF + -INF = NaN. (Avoid array equality here
// since NaN != NaN; compare element scalars against the reference path instead.)
let mut ctx = LEGACY_SESSION.create_execution_ctx();
let expected = grouped_sum_reference(&elements, &ranges, &valid, &elem_dtype)?;
let g0 = actual.execute_scalar(0, &mut ctx)?;
assert_eq!(g0.as_primitive().typed_value::<f64>(), Some(3.0));
assert_eq!(
g0.as_primitive().typed_value::<f64>(),
assert!(g0.as_primitive().typed_value::<f64>().unwrap().is_nan());
assert!(
expected
.execute_scalar(0, &mut ctx)?
.as_primitive()
.typed_value::<f64>()
.unwrap()
.is_nan()
);
let g1 = actual.execute_scalar(1, &mut ctx)?;
assert!(g1.as_primitive().typed_value::<f64>().unwrap().is_nan());
Expand Down
1 change: 1 addition & 0 deletions vortex-array/src/aggregate_fn/fns/sum/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ pub fn sum(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult<Scalar> {
///
/// If the sum overflows, a null scalar will be returned.
/// If the array is all-invalid, the sum will be zero.
/// Float NaN values propagate: any NaN in the input makes the sum NaN.
#[derive(Clone, Debug)]
pub struct Sum;

Expand Down
16 changes: 7 additions & 9 deletions vortex-array/src/aggregate_fn/fns/sum/primitive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,11 @@ fn accumulate_primitive_all(inner: &mut SumState, p: &PrimitiveArray) -> VortexR
}
}

/// Sum the non-NaN values of a float slice into an `f64` accumulator. NaNs are skipped to match the
/// scalar `sum` semantics. Floats cannot overflow the accumulator, so this never reports saturation.
/// Sum a float slice into an `f64` accumulator. NaN values propagate into the sum, as in IEEE 754
/// addition. Floats cannot overflow the accumulator, so this never reports saturation.
pub(super) fn sum_float_all<T: NativePType>(acc: &mut f64, slice: &[T]) {
for &v in slice {
if !v.is_nan() {
*acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
}
*acc += ToPrimitive::to_f64(&v).vortex_expect("float to f64");
}
}

Expand Down Expand Up @@ -297,7 +295,7 @@ mod tests {
)
.into_array();
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(6.0));
assert!(result.as_primitive().typed_value::<f64>().unwrap().is_nan());
Ok(())
}

Expand All @@ -306,7 +304,7 @@ mod tests {
let arr =
PrimitiveArray::new(buffer![1.0f32, f32::NAN, 4.0], Validity::NonNullable).into_array();
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(5.0));
assert!(result.as_primitive().typed_value::<f64>().unwrap().is_nan());
Ok(())
}

Expand All @@ -315,7 +313,7 @@ mod tests {
let arr = PrimitiveArray::from_option_iter([Some(1.0f64), None, Some(f64::NAN), Some(3.0)])
.into_array();
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(4.0));
assert!(result.as_primitive().typed_value::<f64>().unwrap().is_nan());
Ok(())
}

Expand All @@ -324,7 +322,7 @@ mod tests {
let arr =
PrimitiveArray::new(buffer![f64::NAN, f64::NAN], Validity::NonNullable).into_array();
let result = sum(&arr, &mut LEGACY_SESSION.create_execution_ctx())?;
assert_eq!(result.as_primitive().typed_value::<f64>(), Some(0.0));
assert!(result.as_primitive().typed_value::<f64>().unwrap().is_nan());
Ok(())
}

Expand Down
Loading