From e5a7072c3895dad2c9327f3397cdc34c3340b84c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 15:58:41 -0400 Subject: [PATCH 1/3] Add binary operator benchmarks Signed-off-by: Nicholas Gates --- vortex-array/Cargo.toml | 4 + vortex-array/benches/binary_ops.rs | 154 +++++++++++++++++++++++++++++ 2 files changed, 158 insertions(+) create mode 100644 vortex-array/benches/binary_ops.rs diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index 4ec8a83575e..de3bb416985 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -125,6 +125,10 @@ harness = false name = "compare" harness = false +[[bench]] +name = "binary_ops" +harness = false + [[bench]] name = "interleave" harness = false diff --git a/vortex-array/benches/binary_ops.rs b/vortex-array/benches/binary_ops.rs new file mode 100644 index 00000000000..a5e3f9a5594 --- /dev/null +++ b/vortex-array/benches/binary_ops.rs @@ -0,0 +1,154 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +#![expect(clippy::unwrap_used)] +#![expect( + clippy::cast_possible_truncation, + reason = "benchmark fixtures use indices that fit in the chosen widths" +)] + +use std::sync::LazyLock; + +use divan::Bencher; +use divan::counter::ItemsCount; +use vortex_array::ArrayRef; +use vortex_array::Executable; +use vortex_array::IntoArray; +use vortex_array::VortexSessionExecute; +use vortex_array::arrays::BoolArray; +use vortex_array::arrays::ConstantArray; +use vortex_array::arrays::PrimitiveArray; +use vortex_array::builtins::ArrayBuiltins; +use vortex_array::scalar_fn::fns::operators::Operator; +use vortex_array::session::ArraySession; +use vortex_session::VortexSession; + +fn main() { + divan::main(); +} + +static SESSION: LazyLock = + LazyLock::new(|| VortexSession::empty().with::()); + +const LEN: usize = 65_536; + +#[divan::bench] +fn add_i64_nonnull(bencher: Bencher) { + let lhs = primitive_nonnull(0).into_array(); + let rhs = primitive_nonnull(1_000_000).into_array(); + + bench_primitive(bencher, lhs, rhs, Operator::Add); +} + +#[divan::bench] +fn add_i64_nullable(bencher: Bencher) { + let lhs = primitive_nullable(0, 7).into_array(); + let rhs = primitive_nullable(1_000_000, 5).into_array(); + + bench_primitive(bencher, lhs, rhs, Operator::Add); +} + +#[divan::bench] +fn mul_i64_nonnull(bencher: Bencher) { + let lhs = primitive_small_nonnull(1).into_array(); + let rhs = primitive_small_nonnull(17).into_array(); + + bench_primitive(bencher, lhs, rhs, Operator::Mul); +} + +#[divan::bench] +fn div_i64_nonnull(bencher: Bencher) { + let lhs = primitive_nonnull(1_000_000).into_array(); + let rhs = primitive_nonzero().into_array(); + + bench_primitive(bencher, lhs, rhs, Operator::Div); +} + +#[divan::bench] +fn sub_i64_constant(bencher: Bencher) { + let lhs = primitive_nonnull(0).into_array(); + let rhs = ConstantArray::new(37i64, LEN).into_array(); + + bench_primitive(bencher, lhs, rhs, Operator::Sub); +} + +#[divan::bench] +fn eq_i64_constant(bencher: Bencher) { + let lhs = primitive_nonnull(0).into_array(); + let rhs = ConstantArray::new(1024i64, LEN).into_array(); + + bench_bool(bencher, lhs, rhs, Operator::Eq); +} + +#[divan::bench] +fn lt_i64_nullable(bencher: Bencher) { + let lhs = primitive_nullable(0, 7).into_array(); + let rhs = primitive_nullable(1_000_000, 5).into_array(); + + bench_bool(bencher, lhs, rhs, Operator::Lt); +} + +#[divan::bench] +fn and_bool_nullable(bencher: Bencher) { + let lhs = bool_nullable(2, 7).into_array(); + let rhs = bool_nullable(3, 5).into_array(); + + bench_bool(bencher, lhs, rhs, Operator::And); +} + +#[divan::bench] +fn or_bool_constant(bencher: Bencher) { + let lhs = bool_nullable(2, 7).into_array(); + let rhs = ConstantArray::new(true, LEN).into_array(); + + bench_bool(bencher, lhs, rhs, Operator::Or); +} + +fn bench_primitive(bencher: Bencher, lhs: ArrayRef, rhs: ArrayRef, operator: Operator) { + bench_binary::(bencher, lhs, rhs, operator); +} + +fn bench_bool(bencher: Bencher, lhs: ArrayRef, rhs: ArrayRef, operator: Operator) { + bench_binary::(bencher, lhs, rhs, operator); +} + +fn bench_binary( + bencher: Bencher, + lhs: ArrayRef, + rhs: ArrayRef, + operator: Operator, +) { + let mut ctx = SESSION.create_execution_ctx(); + + bencher.counter(ItemsCount::new(LEN)).bench_local(|| { + lhs.clone() + .binary(rhs.clone(), operator) + .unwrap() + .execute::(&mut ctx) + .unwrap() + }); +} + +fn primitive_nonnull(base: i64) -> PrimitiveArray { + PrimitiveArray::from_iter((0..LEN as i64).map(|i| base + i)) +} + +fn primitive_small_nonnull(offset: i64) -> PrimitiveArray { + PrimitiveArray::from_iter((0..LEN as i64).map(|i| ((i + offset) % 1024) + 1)) +} + +fn primitive_nonzero() -> PrimitiveArray { + PrimitiveArray::from_iter((0..LEN as i64).map(|i| (i % 255) + 1)) +} + +fn primitive_nullable(base: i64, null_every: usize) -> PrimitiveArray { + PrimitiveArray::from_option_iter( + (0..LEN as i64).map(|i| (!(i as usize).is_multiple_of(null_every)).then_some(base + i)), + ) +} + +fn bool_nullable(true_every: usize, null_every: usize) -> BoolArray { + BoolArray::from_iter( + (0..LEN).map(|i| (!i.is_multiple_of(null_every)).then_some(i.is_multiple_of(true_every))), + ) +} From ddeb791b2a8a8bb555d526a11c33e8c475fddbc0 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 16:58:16 -0400 Subject: [PATCH 2/3] Reimplement arithmetic scalar functions natively Signed-off-by: "Nicholas Gates" --- .../src/scalar_fn/fns/binary/numeric.rs | 946 +++++++++++++++++- 1 file changed, 923 insertions(+), 23 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/binary/numeric.rs b/vortex-array/src/scalar_fn/fns/binary/numeric.rs index 2ef58fb5209..bb622a03abc 100644 --- a/vortex-array/src/scalar_fn/fns/binary/numeric.rs +++ b/vortex-array/src/scalar_fn/fns/binary/numeric.rs @@ -1,21 +1,31 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_mask::Mask; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::Constant; use crate::arrays::ConstantArray; -use crate::arrow::Datum; -use crate::arrow::from_arrow_columnar; -use crate::executor::ExecutionCtx; +use crate::arrays::PrimitiveArray; +use crate::dtype::NativePType; +use crate::dtype::PType; +use crate::dtype::half::f16; +use crate::match_each_native_ptype; use crate::scalar::NumericOperator; +use crate::validity::Validity; /// Execute a numeric operation between two arrays. /// -/// This is the entry point for numeric operations from the binary expression. -/// Handles constant-constant directly, otherwise falls back to Arrow. +/// This is the entry point for numeric operations from the binary expression. The implementation +/// keeps constants scalar, canonicalizes non-constant inputs to primitive buffers, and accumulates +/// integer arithmetic failures before returning a single operation-level error. pub(crate) fn execute_numeric( lhs: &ArrayRef, rhs: &ArrayRef, @@ -25,30 +35,71 @@ pub(crate) fn execute_numeric( if let Some(result) = constant_numeric(lhs, rhs, op)? { return Ok(result); } - arrow_numeric(lhs, rhs, op, ctx) + + native_numeric(lhs, rhs, op, ctx) } -/// Implementation of numeric operations using the Arrow crate. -pub(crate) fn arrow_numeric( +fn native_numeric( lhs: &ArrayRef, rhs: &ArrayRef, - operator: NumericOperator, + op: NumericOperator, ctx: &mut ExecutionCtx, ) -> VortexResult { - let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); + let ptype = PType::try_from(lhs.dtype())?; + if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) { + vortex_bail!( + "numeric operator {} requires matching primitive types, got {} and {}", + op, + lhs.dtype(), + rhs.dtype() + ); + } + + match_each_native_ptype!(ptype, |T| { execute_numeric_typed::(lhs, rhs, op, ctx) }) +} + +fn execute_numeric_typed( + lhs: &ArrayRef, + rhs: &ArrayRef, + op: NumericOperator, + ctx: &mut ExecutionCtx, +) -> VortexResult { + let lhs = NumericOperand::::try_new(lhs, ctx)?; + let rhs = NumericOperand::::try_new(rhs, ctx)?; let len = lhs.len(); + if len != rhs.len() { + vortex_bail!( + "numeric operator {} requires equal lengths, got {} and {}", + op, + len, + rhs.len() + ); + } - let left = Datum::try_new(lhs, ctx)?; - let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + let validity = lhs.validity().and(rhs.validity())?; + let valid_rows = ValidRows::from_validity(&validity, len, ctx)?; + if valid_rows.is_none() { + return Ok(PrimitiveArray::new(Buffer::::zeroed(len), validity).into_array()); + } - let array = match operator { - NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?, - NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?, - NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?, - NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?, + let values = match (&lhs, &rhs) { + (NumericOperand::Array(lhs), NumericOperand::Array(rhs)) => { + T::apply_array_array(lhs.values(), rhs.values(), op, &valid_rows)? + } + (NumericOperand::Array(lhs), NumericOperand::Constant { value: rhs, .. }) => { + T::apply_array_constant(lhs.values(), *rhs, op, &valid_rows)? + } + (NumericOperand::Constant { value: lhs, .. }, NumericOperand::Array(rhs)) => { + T::apply_constant_array(*lhs, rhs.values(), op, &valid_rows)? + } + ( + NumericOperand::Constant { value: lhs, .. }, + NumericOperand::Constant { value: rhs, .. }, + ) => BufferMut::full(T::apply_scalar(*lhs, *rhs, op)?, len).freeze(), + (NumericOperand::Null(_), _) | (_, NumericOperand::Null(_)) => Buffer::::zeroed(len), }; - from_arrow_columnar(array.as_ref(), len, nullable, ctx) + Ok(PrimitiveArray::new(values, validity).into_array()) } fn constant_numeric( @@ -60,18 +111,805 @@ fn constant_numeric( return Ok(None); }; - let Some(result) = lhs + let result = lhs .scalar() .as_primitive() .checked_binary_numeric(&rhs.scalar().as_primitive(), op) - else { - // Overflow detected — fall through to arrow_numeric which uses wrapping arithmetic. - return Ok(None); - }; + .ok_or_else(|| numeric_error(op))?; Ok(Some(ConstantArray::new(result, lhs.len()).into_array())) } +enum NumericOperand { + Array(TypedPrimitive), + Constant { + value: T, + len: usize, + validity: Validity, + }, + Null(usize), +} + +impl NumericOperand { + fn try_new(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + if let Some(constant) = array.as_opt::() { + return Ok( + match constant.scalar().as_primitive().try_typed_value::()? { + Some(value) => Self::Constant { + value, + len: array.len(), + validity: if constant.scalar().dtype().is_nullable() { + Validity::AllValid + } else { + Validity::NonNullable + }, + }, + None => Self::Null(array.len()), + }, + ); + } + + Ok(Self::Array(TypedPrimitive::new( + array.clone().execute::(ctx)?, + )?)) + } + + fn len(&self) -> usize { + match self { + Self::Array(array) => array.values().len(), + Self::Constant { len, .. } | Self::Null(len) => *len, + } + } + + fn validity(&self) -> Validity { + match self { + Self::Array(array) => array.validity(), + Self::Constant { validity, .. } => validity.clone(), + Self::Null(_) => Validity::AllInvalid, + } + } +} + +struct TypedPrimitive { + values: Buffer, + validity: Validity, +} + +impl TypedPrimitive { + fn new(array: PrimitiveArray) -> VortexResult { + let validity = array.validity()?; + let values = array.into_buffer::(); + Ok(Self { values, validity }) + } + + fn values(&self) -> &[T] { + &self.values + } + + fn validity(&self) -> Validity { + self.validity.clone() + } +} + +enum ValidRows { + All, + Some(Mask), + None, +} + +impl ValidRows { + fn from_validity( + validity: &Validity, + len: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mask = validity.execute_mask(len, ctx)?; + Ok(if mask.all_true() { + Self::All + } else if mask.all_false() { + Self::None + } else { + Self::Some(mask) + }) + } + + fn is_none(&self) -> bool { + matches!(self, Self::None) + } +} + +trait NativeNumeric: NativePType + Sized { + fn apply_array_array( + lhs: &[Self], + rhs: &[Self], + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult>; + + fn apply_array_constant( + lhs: &[Self], + rhs: Self, + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult>; + + fn apply_constant_array( + lhs: Self, + rhs: &[Self], + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult>; + + fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult; +} + +trait OverflowingInteger: NativePType { + fn overflowing_add(self, rhs: Self) -> (Self, bool); + fn overflowing_sub(self, rhs: Self) -> (Self, bool); + fn overflowing_mul(self, rhs: Self) -> (Self, bool); + fn overflowing_div(self, rhs: Self) -> (Self, bool); +} + +trait IntegerOp { + fn apply(lhs: T, rhs: T) -> (T, bool); +} + +struct AddOp; +struct SubOp; +struct MulOp; +struct DivOp; + +impl IntegerOp for AddOp { + #[inline(always)] + fn apply(lhs: T, rhs: T) -> (T, bool) { + lhs.overflowing_add(rhs) + } +} + +impl IntegerOp for SubOp { + #[inline(always)] + fn apply(lhs: T, rhs: T) -> (T, bool) { + lhs.overflowing_sub(rhs) + } +} + +impl IntegerOp for MulOp { + #[inline(always)] + fn apply(lhs: T, rhs: T) -> (T, bool) { + lhs.overflowing_mul(rhs) + } +} + +impl IntegerOp for DivOp { + #[inline(always)] + fn apply(lhs: T, rhs: T) -> (T, bool) { + lhs.overflowing_div(rhs) + } +} + +trait FloatingOp { + fn apply(lhs: T, rhs: T) -> T; +} + +impl FloatingOp for AddOp +where + T: NativePType + std::ops::Add, +{ + #[inline(always)] + fn apply(lhs: T, rhs: T) -> T { + lhs + rhs + } +} + +impl FloatingOp for SubOp +where + T: NativePType + std::ops::Sub, +{ + #[inline(always)] + fn apply(lhs: T, rhs: T) -> T { + lhs - rhs + } +} + +impl FloatingOp for MulOp +where + T: NativePType + std::ops::Mul, +{ + #[inline(always)] + fn apply(lhs: T, rhs: T) -> T { + lhs * rhs + } +} + +impl FloatingOp for DivOp +where + T: NativePType + std::ops::Div, +{ + #[inline(always)] + fn apply(lhs: T, rhs: T) -> T { + lhs / rhs + } +} + +macro_rules! impl_integer_numeric { + ($($ty:ty),* $(,)?) => { + $( + impl NativeNumeric for $ty { + fn apply_array_array( + lhs: &[Self], + rhs: &[Self], + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult> { + integer_array_array(lhs, rhs, op, valid_rows) + } + + fn apply_array_constant( + lhs: &[Self], + rhs: Self, + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult> { + integer_array_constant(lhs, rhs, op, valid_rows) + } + + fn apply_constant_array( + lhs: Self, + rhs: &[Self], + op: NumericOperator, + valid_rows: &ValidRows, + ) -> VortexResult> { + integer_constant_array(lhs, rhs, op, valid_rows) + } + + fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult { + integer_scalar(lhs, rhs, op) + } + } + )* + }; +} + +macro_rules! impl_signed_integer_div { + ($($ty:ty),* $(,)?) => { + $( + impl OverflowingInteger for $ty { + #[inline(always)] + fn overflowing_add(self, rhs: Self) -> (Self, bool) { + let result = self.wrapping_add(rhs); + let overflow = ((self ^ result) & (rhs ^ result)) < 0; + (result, overflow) + } + + #[inline(always)] + fn overflowing_sub(self, rhs: Self) -> (Self, bool) { + let result = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ result)) < 0; + (result, overflow) + } + + #[inline(always)] + fn overflowing_mul(self, rhs: Self) -> (Self, bool) { + self.overflowing_mul(rhs) + } + + #[inline(always)] + fn overflowing_div(self, rhs: Self) -> (Self, bool) { + let div_by_zero = rhs == 0; + let overflow = self == <$ty>::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (self.wrapping_div(divisor), div_by_zero | overflow) + } + } + )* + }; +} + +macro_rules! impl_signed_widening_integer_div { + ($($ty:ty => $wide:ty),* $(,)?) => { + $( + impl OverflowingInteger for $ty { + #[inline(always)] + fn overflowing_add(self, rhs: Self) -> (Self, bool) { + let result = self.wrapping_add(rhs); + let overflow = ((self ^ result) & (rhs ^ result)) < 0; + (result, overflow) + } + + #[inline(always)] + fn overflowing_sub(self, rhs: Self) -> (Self, bool) { + let result = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ result)) < 0; + (result, overflow) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn overflowing_mul(self, rhs: Self) -> (Self, bool) { + let product = (self as $wide) * (rhs as $wide); + ( + product as Self, + product < <$ty>::MIN as $wide || product > <$ty>::MAX as $wide, + ) + } + + #[inline(always)] + fn overflowing_div(self, rhs: Self) -> (Self, bool) { + let div_by_zero = rhs == 0; + let overflow = self == <$ty>::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (self.wrapping_div(divisor), div_by_zero | overflow) + } + } + )* + }; +} + +macro_rules! impl_unsigned_integer_div { + ($($ty:ty),* $(,)?) => { + $( + impl OverflowingInteger for $ty { + #[inline(always)] + fn overflowing_add(self, rhs: Self) -> (Self, bool) { + self.overflowing_add(rhs) + } + + #[inline(always)] + fn overflowing_sub(self, rhs: Self) -> (Self, bool) { + self.overflowing_sub(rhs) + } + + #[inline(always)] + fn overflowing_mul(self, rhs: Self) -> (Self, bool) { + self.overflowing_mul(rhs) + } + + #[inline(always)] + fn overflowing_div(self, rhs: Self) -> (Self, bool) { + let div_by_zero = rhs == 0; + let divisor = if div_by_zero { 1 } else { rhs }; + (self.wrapping_div(divisor), div_by_zero) + } + } + )* + }; +} + +macro_rules! impl_unsigned_widening_integer_div { + ($($ty:ty => $wide:ty),* $(,)?) => { + $( + impl OverflowingInteger for $ty { + #[inline(always)] + fn overflowing_add(self, rhs: Self) -> (Self, bool) { + self.overflowing_add(rhs) + } + + #[inline(always)] + fn overflowing_sub(self, rhs: Self) -> (Self, bool) { + self.overflowing_sub(rhs) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn overflowing_mul(self, rhs: Self) -> (Self, bool) { + let product = (self as $wide) * (rhs as $wide); + (product as Self, product > <$ty>::MAX as $wide) + } + + #[inline(always)] + fn overflowing_div(self, rhs: Self) -> (Self, bool) { + let div_by_zero = rhs == 0; + let divisor = if div_by_zero { 1 } else { rhs }; + (self.wrapping_div(divisor), div_by_zero) + } + } + )* + }; +} + +macro_rules! impl_floating_numeric { + ($($ty:ty),* $(,)?) => { + $( + impl NativeNumeric for $ty { + fn apply_array_array( + lhs: &[Self], + rhs: &[Self], + op: NumericOperator, + _valid_rows: &ValidRows, + ) -> VortexResult> { + Ok(floating_array_array(lhs, rhs, op)) + } + + fn apply_array_constant( + lhs: &[Self], + rhs: Self, + op: NumericOperator, + _valid_rows: &ValidRows, + ) -> VortexResult> { + Ok(floating_array_constant(lhs, rhs, op)) + } + + fn apply_constant_array( + lhs: Self, + rhs: &[Self], + op: NumericOperator, + _valid_rows: &ValidRows, + ) -> VortexResult> { + Ok(floating_constant_array(lhs, rhs, op)) + } + + fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult { + Ok(floating_scalar(lhs, rhs, op)) + } + } + )* + }; +} + +impl_unsigned_widening_integer_div!(u8 => u16, u16 => u32, u32 => u64); +impl_unsigned_integer_div!(u64); +impl_signed_widening_integer_div!(i8 => i16, i16 => i32, i32 => i64); +impl_signed_integer_div!(i64); +impl_integer_numeric!(u8, u16, u32, u64, i8, i16, i32, i64); +impl_floating_numeric!(f16, f32, f64); + +fn integer_array_array( + lhs: &[T], + rhs: &[T], + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> { + match op { + NumericOperator::Add => integer_array_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Sub => integer_array_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Mul => integer_array_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Div => integer_array_array_op::(lhs, rhs, op, valid_rows), + } +} + +fn integer_array_constant( + lhs: &[T], + rhs: T, + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> { + match op { + NumericOperator::Add => integer_array_constant_op::(lhs, rhs, op, valid_rows), + NumericOperator::Sub => integer_array_constant_op::(lhs, rhs, op, valid_rows), + NumericOperator::Mul => integer_array_constant_op::(lhs, rhs, op, valid_rows), + NumericOperator::Div => integer_array_constant_op::(lhs, rhs, op, valid_rows), + } +} + +fn integer_constant_array( + lhs: T, + rhs: &[T], + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> { + match op { + NumericOperator::Add => integer_constant_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Sub => integer_constant_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Mul => integer_constant_array_op::(lhs, rhs, op, valid_rows), + NumericOperator::Div => integer_constant_array_op::(lhs, rhs, op, valid_rows), + } +} + +fn integer_scalar(lhs: T, rhs: T, op: NumericOperator) -> VortexResult { + match op { + NumericOperator::Add => integer_scalar_op::(lhs, rhs, op), + NumericOperator::Sub => integer_scalar_op::(lhs, rhs, op), + NumericOperator::Mul => integer_scalar_op::(lhs, rhs, op), + NumericOperator::Div => integer_scalar_op::(lhs, rhs, op), + } +} + +fn integer_array_array_op( + lhs: &[T], + rhs: &[T], + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + debug_assert_eq!(lhs.len(), rhs.len()); + + match valid_rows { + ValidRows::All => integer_array_array_all_valid::(lhs, rhs, op), + ValidRows::Some(mask) => integer_array_array_masked::(lhs, rhs, op, mask), + ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), + } +} + +fn integer_array_constant_op( + lhs: &[T], + rhs: T, + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + match valid_rows { + ValidRows::All => integer_array_constant_all_valid::(lhs, rhs, op), + ValidRows::Some(mask) => integer_array_constant_masked::(lhs, rhs, op, mask), + ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), + } +} + +fn integer_constant_array_op( + lhs: T, + rhs: &[T], + op: NumericOperator, + valid_rows: &ValidRows, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + match valid_rows { + ValidRows::All => integer_constant_array_all_valid::(lhs, rhs, op), + ValidRows::Some(mask) => integer_constant_array_masked::(lhs, rhs, op, mask), + ValidRows::None => Ok(Buffer::::zeroed(rhs.len())), + } +} + +fn integer_scalar_op(lhs: T, rhs: T, op: NumericOperator) -> VortexResult +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let (value, failed) = Op::apply(lhs, rhs); + check_numeric_error(op, failed)?; + Ok(value) +} + +fn integer_array_array_all_valid( + lhs: &[T], + rhs: &[T], + op: NumericOperator, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for ((dst, &lhs), &rhs) in values.iter_mut().zip(lhs).zip(rhs) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn integer_array_array_masked( + lhs: &[T], + rhs: &[T], + op: NumericOperator, + valid_rows: &Mask, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for (((dst, &lhs), &rhs), valid) in values.iter_mut().zip(lhs).zip(rhs).zip(valid_rows.iter()) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error & valid; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn integer_array_constant_all_valid( + lhs: &[T], + rhs: T, + op: NumericOperator, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for (dst, &lhs) in values.iter_mut().zip(lhs) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn integer_array_constant_masked( + lhs: &[T], + rhs: T, + op: NumericOperator, + valid_rows: &Mask, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for ((dst, &lhs), valid) in values.iter_mut().zip(lhs).zip(valid_rows.iter()) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error & valid; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn integer_constant_array_all_valid( + lhs: T, + rhs: &[T], + op: NumericOperator, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(rhs.len()); + for (dst, &rhs) in values.iter_mut().zip(rhs) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn integer_constant_array_masked( + lhs: T, + rhs: &[T], + op: NumericOperator, + valid_rows: &Mask, +) -> VortexResult> +where + T: OverflowingInteger, + Op: IntegerOp, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(rhs.len()); + for ((dst, &rhs), valid) in values.iter_mut().zip(rhs).zip(valid_rows.iter()) { + let (value, error) = Op::apply(lhs, rhs); + *dst = value; + failed |= error & valid; + } + check_numeric_error(op, failed)?; + Ok(values.freeze()) +} + +fn floating_array_array(lhs: &[T], rhs: &[T], op: NumericOperator) -> Buffer +where + T: NativePType + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +{ + match op { + NumericOperator::Add => floating_array_array_op::(lhs, rhs), + NumericOperator::Sub => floating_array_array_op::(lhs, rhs), + NumericOperator::Mul => floating_array_array_op::(lhs, rhs), + NumericOperator::Div => floating_array_array_op::(lhs, rhs), + } +} + +fn floating_array_constant(lhs: &[T], rhs: T, op: NumericOperator) -> Buffer +where + T: NativePType + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +{ + match op { + NumericOperator::Add => floating_array_constant_op::(lhs, rhs), + NumericOperator::Sub => floating_array_constant_op::(lhs, rhs), + NumericOperator::Mul => floating_array_constant_op::(lhs, rhs), + NumericOperator::Div => floating_array_constant_op::(lhs, rhs), + } +} + +fn floating_constant_array(lhs: T, rhs: &[T], op: NumericOperator) -> Buffer +where + T: NativePType + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +{ + match op { + NumericOperator::Add => floating_constant_array_op::(lhs, rhs), + NumericOperator::Sub => floating_constant_array_op::(lhs, rhs), + NumericOperator::Mul => floating_constant_array_op::(lhs, rhs), + NumericOperator::Div => floating_constant_array_op::(lhs, rhs), + } +} + +fn floating_scalar(lhs: T, rhs: T, op: NumericOperator) -> T +where + T: NativePType + + std::ops::Add + + std::ops::Sub + + std::ops::Mul + + std::ops::Div, +{ + match op { + NumericOperator::Add => >::apply(lhs, rhs), + NumericOperator::Sub => >::apply(lhs, rhs), + NumericOperator::Mul => >::apply(lhs, rhs), + NumericOperator::Div => >::apply(lhs, rhs), + } +} + +fn floating_array_array_op(lhs: &[T], rhs: &[T]) -> Buffer +where + T: NativePType, + Op: FloatingOp, +{ + debug_assert_eq!(lhs.len(), rhs.len()); + + let mut values = BufferMut::::zeroed(lhs.len()); + for ((dst, &lhs), &rhs) in values.iter_mut().zip(lhs).zip(rhs) { + *dst = Op::apply(lhs, rhs); + } + values.freeze() +} + +fn floating_array_constant_op(lhs: &[T], rhs: T) -> Buffer +where + T: NativePType, + Op: FloatingOp, +{ + let mut values = BufferMut::::zeroed(lhs.len()); + for (dst, &lhs) in values.iter_mut().zip(lhs) { + *dst = Op::apply(lhs, rhs); + } + values.freeze() +} + +fn floating_constant_array_op(lhs: T, rhs: &[T]) -> Buffer +where + T: NativePType, + Op: FloatingOp, +{ + let mut values = BufferMut::::zeroed(rhs.len()); + for (dst, &rhs) in values.iter_mut().zip(rhs) { + *dst = Op::apply(lhs, rhs); + } + values.freeze() +} + +fn check_numeric_error(op: NumericOperator, failed: bool) -> VortexResult<()> { + if failed { + return Err(numeric_error(op)); + } + Ok(()) +} + +fn numeric_error(op: NumericOperator) -> vortex_error::VortexError { + match op { + NumericOperator::Add | NumericOperator::Sub | NumericOperator::Mul => { + vortex_err!(InvalidArgument: "integer overflow in numeric {} operation", op) + } + NumericOperator::Div => { + vortex_err!(InvalidArgument: "integer division by zero or overflow in numeric / operation") + } + } +} + #[cfg(test)] mod test { use vortex_buffer::buffer; @@ -88,6 +926,7 @@ mod test { use crate::scalar::Scalar; use crate::scalar_fn::fns::binary::numeric::ConstantArray; use crate::scalar_fn::fns::operators::Operator; + use crate::validity::Validity; fn sub_scalar(array: &ArrayRef, scalar: impl Into) -> VortexResult { array @@ -138,4 +977,65 @@ mod test { let _results = sub_scalar(&values, 1.0f32).unwrap(); let _results = sub_scalar(&values, f32::MAX).unwrap(); } + + #[test] + fn test_integer_overflow_errors() { + let values = buffer![u8::MAX].into_array(); + let result = values + .binary( + ConstantArray::new(1u8, values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())); + + assert!(result.is_err()); + } + + #[test] + fn test_integer_divide_by_zero_errors() { + let values = buffer![1i32].into_array(); + let result = values + .binary( + ConstantArray::new(0i32, values.len()).into_array(), + Operator::Div, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())); + + assert!(result.is_err()); + } + + #[test] + fn test_integer_errors_ignore_null_lanes() { + let values = PrimitiveArray::new(buffer![u8::MAX, 1], Validity::from_iter([false, true])) + .into_array(); + let result = values + .binary( + ConstantArray::new(1u8, values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| { + a.execute::(&mut LEGACY_SESSION.create_execution_ctx()) + }) + .map(|a| a.0.into_array()) + .unwrap(); + + assert_arrays_eq!(result, PrimitiveArray::from_option_iter([None, Some(2u8)])); + } + + #[test] + fn test_present_nullable_constant_preserves_nullable_output() { + let values = buffer![1u8, 2].into_array(); + let result = values + .binary( + ConstantArray::new(Some(1u8), values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())) + .unwrap(); + + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(2u8), Some(3)]) + ); + } } From ffb50d5dfc5f3e177bcabe3877b2bdfb89239509 Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 18:40:20 -0400 Subject: [PATCH 3/3] Refactor binary arithmetic lifting Signed-off-by: Nicholas Gates --- .../src/scalar_fn/fns/binary/numeric.rs | 1124 ++++++++--------- 1 file changed, 511 insertions(+), 613 deletions(-) diff --git a/vortex-array/src/scalar_fn/fns/binary/numeric.rs b/vortex-array/src/scalar_fn/fns/binary/numeric.rs index bb622a03abc..cbad0b8cc85 100644 --- a/vortex-array/src/scalar_fn/fns/binary/numeric.rs +++ b/vortex-array/src/scalar_fn/fns/binary/numeric.rs @@ -3,6 +3,7 @@ use vortex_buffer::Buffer; use vortex_buffer::BufferMut; +use vortex_error::VortexError; use vortex_error::VortexResult; use vortex_error::vortex_bail; use vortex_error::vortex_err; @@ -14,63 +15,161 @@ use crate::IntoArray; use crate::arrays::Constant; use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; use crate::dtype::NativePType; use crate::dtype::PType; use crate::dtype::half::f16; use crate::match_each_native_ptype; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; use crate::validity::Validity; +struct CheckedAdd; + +struct CheckedSub; + +struct CheckedMul; + +struct CheckedDiv; + +trait CheckedPrimitiveOp { + const ERROR: &'static str; +} + +trait CheckedPrimitiveBinary: CheckedPrimitiveOp { + fn checked(lhs: T, rhs: T) -> Option; +} + +trait CheckedPrimitiveBinaryAll: + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary +{ +} + +impl CheckedPrimitiveBinaryAll for Op where + Op: CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary +{ +} + +impl CheckedPrimitiveOp for CheckedAdd { + const ERROR: &'static str = "integer overflow in checked add"; +} + +impl CheckedPrimitiveOp for CheckedSub { + const ERROR: &'static str = "integer overflow in checked sub"; +} + +impl CheckedPrimitiveOp for CheckedMul { + const ERROR: &'static str = "integer overflow in checked mul"; +} + +impl CheckedPrimitiveOp for CheckedDiv { + const ERROR: &'static str = "integer division by zero or overflow in checked div"; +} + +impl CheckedPrimitiveBinary for CheckedAdd { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_add(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedSub { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_sub(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedMul { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_mul(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedDiv { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_div(rhs) + } +} + /// Execute a numeric operation between two arrays. -/// -/// This is the entry point for numeric operations from the binary expression. The implementation -/// keeps constants scalar, canonicalizes non-constant inputs to primitive buffers, and accumulates -/// integer arithmetic failures before returning a single operation-level error. pub(crate) fn execute_numeric( lhs: &ArrayRef, rhs: &ArrayRef, op: NumericOperator, ctx: &mut ExecutionCtx, ) -> VortexResult { - if let Some(result) = constant_numeric(lhs, rhs, op)? { - return Ok(result); + match op { + NumericOperator::Add => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Sub => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Mul => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Div => execute_checked_numeric::(lhs, rhs, ctx), } - - native_numeric(lhs, rhs, op, ctx) } -fn native_numeric( +fn execute_checked_numeric( lhs: &ArrayRef, rhs: &ArrayRef, - op: NumericOperator, ctx: &mut ExecutionCtx, -) -> VortexResult { +) -> VortexResult +where + Op: CheckedPrimitiveBinaryAll, +{ let ptype = PType::try_from(lhs.dtype())?; if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) { vortex_bail!( - "numeric operator {} requires matching primitive types, got {} and {}", - op, + "numeric operator requires matching primitive types, got {} and {}", lhs.dtype(), rhs.dtype() ); } - match_each_native_ptype!(ptype, |T| { execute_numeric_typed::(lhs, rhs, op, ctx) }) + match_each_native_ptype!(ptype, |T| { execute_checked_typed::(lhs, rhs, ctx) }) } -fn execute_numeric_typed( +fn execute_checked_typed( lhs: &ArrayRef, rhs: &ArrayRef, - op: NumericOperator, ctx: &mut ExecutionCtx, -) -> VortexResult { - let lhs = NumericOperand::::try_new(lhs, ctx)?; - let rhs = NumericOperand::::try_new(rhs, ctx)?; +) -> VortexResult +where + T: NativePType, + Op: CheckedPrimitiveBinary, + Scalar: From, + Scalar: From>, +{ + let result_dtype = lhs + .dtype() + .with_nullability(lhs.dtype().nullability() | rhs.dtype().nullability()); + let lhs = PrimitiveOperand::::try_new(lhs, ctx)?; + let rhs = PrimitiveOperand::::try_new(rhs, ctx)?; let len = lhs.len(); if len != rhs.len() { vortex_bail!( - "numeric operator {} requires equal lengths, got {} and {}", - op, + "numeric operator requires equal lengths, got {} and {}", len, rhs.len() ); @@ -79,48 +178,57 @@ fn execute_numeric_typed( let validity = lhs.validity().and(rhs.validity())?; let valid_rows = ValidRows::from_validity(&validity, len, ctx)?; if valid_rows.is_none() { - return Ok(PrimitiveArray::new(Buffer::::zeroed(len), validity).into_array()); + return primitive_result_array::(Buffer::::zeroed(len), validity, &result_dtype); } let values = match (&lhs, &rhs) { - (NumericOperand::Array(lhs), NumericOperand::Array(rhs)) => { - T::apply_array_array(lhs.values(), rhs.values(), op, &valid_rows)? + (PrimitiveOperand::Array(lhs), PrimitiveOperand::Array(rhs)) => { + checked_array_array::(lhs.values(), rhs.values(), &valid_rows)? } - (NumericOperand::Array(lhs), NumericOperand::Constant { value: rhs, .. }) => { - T::apply_array_constant(lhs.values(), *rhs, op, &valid_rows)? + (PrimitiveOperand::Array(lhs), PrimitiveOperand::Constant { value: rhs, .. }) => { + checked_array_constant::(lhs.values(), *rhs, &valid_rows)? } - (NumericOperand::Constant { value: lhs, .. }, NumericOperand::Array(rhs)) => { - T::apply_constant_array(*lhs, rhs.values(), op, &valid_rows)? + (PrimitiveOperand::Constant { value: lhs, .. }, PrimitiveOperand::Array(rhs)) => { + checked_constant_array::(*lhs, rhs.values(), &valid_rows)? } ( - NumericOperand::Constant { value: lhs, .. }, - NumericOperand::Constant { value: rhs, .. }, - ) => BufferMut::full(T::apply_scalar(*lhs, *rhs, op)?, len).freeze(), - (NumericOperand::Null(_), _) | (_, NumericOperand::Null(_)) => Buffer::::zeroed(len), + PrimitiveOperand::Constant { value: lhs, .. }, + PrimitiveOperand::Constant { value: rhs, .. }, + ) => { + let value = Op::checked(*lhs, *rhs).ok_or_else(|| numeric_error::())?; + return Ok(constant_result_array(value, len, &result_dtype)); + } + (PrimitiveOperand::Null(_), _) | (_, PrimitiveOperand::Null(_)) => Buffer::::zeroed(len), }; - Ok(PrimitiveArray::new(values, validity).into_array()) + primitive_result_array::(values, validity, &result_dtype) } -fn constant_numeric( - lhs: &ArrayRef, - rhs: &ArrayRef, - op: NumericOperator, -) -> VortexResult> { - let (Some(lhs), Some(rhs)) = (lhs.as_opt::(), rhs.as_opt::()) else { - return Ok(None); - }; - - let result = lhs - .scalar() - .as_primitive() - .checked_binary_numeric(&rhs.scalar().as_primitive(), op) - .ok_or_else(|| numeric_error(op))?; +fn primitive_result_array( + values: Buffer, + validity: Validity, + dtype: &DType, +) -> VortexResult { + let array = PrimitiveArray::new(values, validity).into_array(); + if array.dtype() == dtype { + return Ok(array); + } + array.cast(dtype.clone()) +} - Ok(Some(ConstantArray::new(result, lhs.len()).into_array())) +fn constant_result_array(value: T, len: usize, dtype: &DType) -> ArrayRef +where + T: NativePType, + Scalar: From + From>, +{ + if dtype.is_nullable() { + ConstantArray::new(Some(value), len).into_array() + } else { + ConstantArray::new(value, len).into_array() + } } -enum NumericOperand { +enum PrimitiveOperand { Array(TypedPrimitive), Constant { value: T, @@ -130,7 +238,7 @@ enum NumericOperand { Null(usize), } -impl NumericOperand { +impl PrimitiveOperand { fn try_new(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { if let Some(constant) = array.as_opt::() { return Ok( @@ -218,696 +326,486 @@ impl ValidRows { } } -trait NativeNumeric: NativePType + Sized { - fn apply_array_array( - lhs: &[Self], - rhs: &[Self], - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult>; - - fn apply_array_constant( - lhs: &[Self], - rhs: Self, - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult>; - - fn apply_constant_array( - lhs: Self, - rhs: &[Self], - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult>; - - fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult; -} - -trait OverflowingInteger: NativePType { - fn overflowing_add(self, rhs: Self) -> (Self, bool); - fn overflowing_sub(self, rhs: Self) -> (Self, bool); - fn overflowing_mul(self, rhs: Self) -> (Self, bool); - fn overflowing_div(self, rhs: Self) -> (Self, bool); -} - -trait IntegerOp { - fn apply(lhs: T, rhs: T) -> (T, bool); -} - -struct AddOp; -struct SubOp; -struct MulOp; -struct DivOp; - -impl IntegerOp for AddOp { - #[inline(always)] - fn apply(lhs: T, rhs: T) -> (T, bool) { - lhs.overflowing_add(rhs) - } -} - -impl IntegerOp for SubOp { - #[inline(always)] - fn apply(lhs: T, rhs: T) -> (T, bool) { - lhs.overflowing_sub(rhs) - } -} - -impl IntegerOp for MulOp { - #[inline(always)] - fn apply(lhs: T, rhs: T) -> (T, bool) { - lhs.overflowing_mul(rhs) - } -} - -impl IntegerOp for DivOp { - #[inline(always)] - fn apply(lhs: T, rhs: T) -> (T, bool) { - lhs.overflowing_div(rhs) - } -} - -trait FloatingOp { - fn apply(lhs: T, rhs: T) -> T; -} - -impl FloatingOp for AddOp -where - T: NativePType + std::ops::Add, -{ - #[inline(always)] - fn apply(lhs: T, rhs: T) -> T { - lhs + rhs - } -} - -impl FloatingOp for SubOp -where - T: NativePType + std::ops::Sub, -{ - #[inline(always)] - fn apply(lhs: T, rhs: T) -> T { - lhs - rhs - } -} - -impl FloatingOp for MulOp -where - T: NativePType + std::ops::Mul, -{ - #[inline(always)] - fn apply(lhs: T, rhs: T) -> T { - lhs * rhs - } -} - -impl FloatingOp for DivOp -where - T: NativePType + std::ops::Div, -{ - #[inline(always)] - fn apply(lhs: T, rhs: T) -> T { - lhs / rhs - } -} - -macro_rules! impl_integer_numeric { - ($($ty:ty),* $(,)?) => { - $( - impl NativeNumeric for $ty { - fn apply_array_array( - lhs: &[Self], - rhs: &[Self], - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult> { - integer_array_array(lhs, rhs, op, valid_rows) - } - - fn apply_array_constant( - lhs: &[Self], - rhs: Self, - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult> { - integer_array_constant(lhs, rhs, op, valid_rows) - } - - fn apply_constant_array( - lhs: Self, - rhs: &[Self], - op: NumericOperator, - valid_rows: &ValidRows, - ) -> VortexResult> { - integer_constant_array(lhs, rhs, op, valid_rows) - } - - fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult { - integer_scalar(lhs, rhs, op) - } - } - )* - }; -} - -macro_rules! impl_signed_integer_div { - ($($ty:ty),* $(,)?) => { - $( - impl OverflowingInteger for $ty { - #[inline(always)] - fn overflowing_add(self, rhs: Self) -> (Self, bool) { - let result = self.wrapping_add(rhs); - let overflow = ((self ^ result) & (rhs ^ result)) < 0; - (result, overflow) - } - - #[inline(always)] - fn overflowing_sub(self, rhs: Self) -> (Self, bool) { - let result = self.wrapping_sub(rhs); - let overflow = ((self ^ rhs) & (self ^ result)) < 0; - (result, overflow) - } - - #[inline(always)] - fn overflowing_mul(self, rhs: Self) -> (Self, bool) { - self.overflowing_mul(rhs) - } - - #[inline(always)] - fn overflowing_div(self, rhs: Self) -> (Self, bool) { - let div_by_zero = rhs == 0; - let overflow = self == <$ty>::MIN && rhs == -1; - let divisor = if div_by_zero { 1 } else { rhs }; - (self.wrapping_div(divisor), div_by_zero | overflow) - } - } - )* - }; -} - -macro_rules! impl_signed_widening_integer_div { - ($($ty:ty => $wide:ty),* $(,)?) => { - $( - impl OverflowingInteger for $ty { - #[inline(always)] - fn overflowing_add(self, rhs: Self) -> (Self, bool) { - let result = self.wrapping_add(rhs); - let overflow = ((self ^ result) & (rhs ^ result)) < 0; - (result, overflow) - } - - #[inline(always)] - fn overflowing_sub(self, rhs: Self) -> (Self, bool) { - let result = self.wrapping_sub(rhs); - let overflow = ((self ^ rhs) & (self ^ result)) < 0; - (result, overflow) - } - - #[inline(always)] - #[allow(clippy::cast_possible_truncation)] - fn overflowing_mul(self, rhs: Self) -> (Self, bool) { - let product = (self as $wide) * (rhs as $wide); - ( - product as Self, - product < <$ty>::MIN as $wide || product > <$ty>::MAX as $wide, - ) - } - - #[inline(always)] - fn overflowing_div(self, rhs: Self) -> (Self, bool) { - let div_by_zero = rhs == 0; - let overflow = self == <$ty>::MIN && rhs == -1; - let divisor = if div_by_zero { 1 } else { rhs }; - (self.wrapping_div(divisor), div_by_zero | overflow) - } - } - )* - }; -} - -macro_rules! impl_unsigned_integer_div { - ($($ty:ty),* $(,)?) => { - $( - impl OverflowingInteger for $ty { - #[inline(always)] - fn overflowing_add(self, rhs: Self) -> (Self, bool) { - self.overflowing_add(rhs) - } - - #[inline(always)] - fn overflowing_sub(self, rhs: Self) -> (Self, bool) { - self.overflowing_sub(rhs) - } - - #[inline(always)] - fn overflowing_mul(self, rhs: Self) -> (Self, bool) { - self.overflowing_mul(rhs) - } - - #[inline(always)] - fn overflowing_div(self, rhs: Self) -> (Self, bool) { - let div_by_zero = rhs == 0; - let divisor = if div_by_zero { 1 } else { rhs }; - (self.wrapping_div(divisor), div_by_zero) - } - } - )* - }; -} - -macro_rules! impl_unsigned_widening_integer_div { - ($($ty:ty => $wide:ty),* $(,)?) => { - $( - impl OverflowingInteger for $ty { - #[inline(always)] - fn overflowing_add(self, rhs: Self) -> (Self, bool) { - self.overflowing_add(rhs) - } - - #[inline(always)] - fn overflowing_sub(self, rhs: Self) -> (Self, bool) { - self.overflowing_sub(rhs) - } - - #[inline(always)] - #[allow(clippy::cast_possible_truncation)] - fn overflowing_mul(self, rhs: Self) -> (Self, bool) { - let product = (self as $wide) * (rhs as $wide); - (product as Self, product > <$ty>::MAX as $wide) - } - - #[inline(always)] - fn overflowing_div(self, rhs: Self) -> (Self, bool) { - let div_by_zero = rhs == 0; - let divisor = if div_by_zero { 1 } else { rhs }; - (self.wrapping_div(divisor), div_by_zero) - } - } - )* - }; -} - -macro_rules! impl_floating_numeric { - ($($ty:ty),* $(,)?) => { - $( - impl NativeNumeric for $ty { - fn apply_array_array( - lhs: &[Self], - rhs: &[Self], - op: NumericOperator, - _valid_rows: &ValidRows, - ) -> VortexResult> { - Ok(floating_array_array(lhs, rhs, op)) - } - - fn apply_array_constant( - lhs: &[Self], - rhs: Self, - op: NumericOperator, - _valid_rows: &ValidRows, - ) -> VortexResult> { - Ok(floating_array_constant(lhs, rhs, op)) - } - - fn apply_constant_array( - lhs: Self, - rhs: &[Self], - op: NumericOperator, - _valid_rows: &ValidRows, - ) -> VortexResult> { - Ok(floating_constant_array(lhs, rhs, op)) - } - - fn apply_scalar(lhs: Self, rhs: Self, op: NumericOperator) -> VortexResult { - Ok(floating_scalar(lhs, rhs, op)) - } - } - )* - }; -} - -impl_unsigned_widening_integer_div!(u8 => u16, u16 => u32, u32 => u64); -impl_unsigned_integer_div!(u64); -impl_signed_widening_integer_div!(i8 => i16, i16 => i32, i32 => i64); -impl_signed_integer_div!(i64); -impl_integer_numeric!(u8, u16, u32, u64, i8, i16, i32, i64); -impl_floating_numeric!(f16, f32, f64); - -fn integer_array_array( +fn checked_array_array( lhs: &[T], rhs: &[T], - op: NumericOperator, - valid_rows: &ValidRows, -) -> VortexResult> { - match op { - NumericOperator::Add => integer_array_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Sub => integer_array_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Mul => integer_array_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Div => integer_array_array_op::(lhs, rhs, op, valid_rows), - } -} - -fn integer_array_constant( - lhs: &[T], - rhs: T, - op: NumericOperator, - valid_rows: &ValidRows, -) -> VortexResult> { - match op { - NumericOperator::Add => integer_array_constant_op::(lhs, rhs, op, valid_rows), - NumericOperator::Sub => integer_array_constant_op::(lhs, rhs, op, valid_rows), - NumericOperator::Mul => integer_array_constant_op::(lhs, rhs, op, valid_rows), - NumericOperator::Div => integer_array_constant_op::(lhs, rhs, op, valid_rows), - } -} - -fn integer_constant_array( - lhs: T, - rhs: &[T], - op: NumericOperator, - valid_rows: &ValidRows, -) -> VortexResult> { - match op { - NumericOperator::Add => integer_constant_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Sub => integer_constant_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Mul => integer_constant_array_op::(lhs, rhs, op, valid_rows), - NumericOperator::Div => integer_constant_array_op::(lhs, rhs, op, valid_rows), - } -} - -fn integer_scalar(lhs: T, rhs: T, op: NumericOperator) -> VortexResult { - match op { - NumericOperator::Add => integer_scalar_op::(lhs, rhs, op), - NumericOperator::Sub => integer_scalar_op::(lhs, rhs, op), - NumericOperator::Mul => integer_scalar_op::(lhs, rhs, op), - NumericOperator::Div => integer_scalar_op::(lhs, rhs, op), - } -} - -fn integer_array_array_op( - lhs: &[T], - rhs: &[T], - op: NumericOperator, valid_rows: &ValidRows, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { debug_assert_eq!(lhs.len(), rhs.len()); match valid_rows { - ValidRows::All => integer_array_array_all_valid::(lhs, rhs, op), - ValidRows::Some(mask) => integer_array_array_masked::(lhs, rhs, op, mask), + ValidRows::All => checked_array_array_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_array_array_masked::(lhs, rhs, mask), ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), } } -fn integer_array_constant_op( +fn checked_array_constant( lhs: &[T], rhs: T, - op: NumericOperator, valid_rows: &ValidRows, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { match valid_rows { - ValidRows::All => integer_array_constant_all_valid::(lhs, rhs, op), - ValidRows::Some(mask) => integer_array_constant_masked::(lhs, rhs, op, mask), + ValidRows::All => checked_array_constant_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_array_constant_masked::(lhs, rhs, mask), ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), } } -fn integer_constant_array_op( +fn checked_constant_array( lhs: T, rhs: &[T], - op: NumericOperator, valid_rows: &ValidRows, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { match valid_rows { - ValidRows::All => integer_constant_array_all_valid::(lhs, rhs, op), - ValidRows::Some(mask) => integer_constant_array_masked::(lhs, rhs, op, mask), + ValidRows::All => checked_constant_array_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_constant_array_masked::(lhs, rhs, mask), ValidRows::None => Ok(Buffer::::zeroed(rhs.len())), } } -fn integer_scalar_op(lhs: T, rhs: T, op: NumericOperator) -> VortexResult -where - T: OverflowingInteger, - Op: IntegerOp, -{ - let (value, failed) = Op::apply(lhs, rhs); - check_numeric_error(op, failed)?; - Ok(value) -} - -fn integer_array_array_all_valid( - lhs: &[T], - rhs: &[T], - op: NumericOperator, -) -> VortexResult> +fn checked_array_array_all_valid(lhs: &[T], rhs: &[T]) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(lhs.len()); for ((dst, &lhs), &rhs) in values.iter_mut().zip(lhs).zip(rhs) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn integer_array_array_masked( +fn checked_array_array_masked( lhs: &[T], rhs: &[T], - op: NumericOperator, valid_rows: &Mask, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(lhs.len()); for (((dst, &lhs), &rhs), valid) in values.iter_mut().zip(lhs).zip(rhs).zip(valid_rows.iter()) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error & valid; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn integer_array_constant_all_valid( - lhs: &[T], - rhs: T, - op: NumericOperator, -) -> VortexResult> +fn checked_array_constant_all_valid(lhs: &[T], rhs: T) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(lhs.len()); for (dst, &lhs) in values.iter_mut().zip(lhs) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn integer_array_constant_masked( +fn checked_array_constant_masked( lhs: &[T], rhs: T, - op: NumericOperator, valid_rows: &Mask, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(lhs.len()); for ((dst, &lhs), valid) in values.iter_mut().zip(lhs).zip(valid_rows.iter()) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error & valid; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn integer_constant_array_all_valid( - lhs: T, - rhs: &[T], - op: NumericOperator, -) -> VortexResult> +fn checked_constant_array_all_valid(lhs: T, rhs: &[T]) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(rhs.len()); for (dst, &rhs) in values.iter_mut().zip(rhs) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn integer_constant_array_masked( +fn checked_constant_array_masked( lhs: T, rhs: &[T], - op: NumericOperator, valid_rows: &Mask, ) -> VortexResult> where - T: OverflowingInteger, - Op: IntegerOp, + T: NativePType, + Op: CheckedPrimitiveBinary, { let mut failed = false; let mut values = BufferMut::::zeroed(rhs.len()); for ((dst, &rhs), valid) in values.iter_mut().zip(rhs).zip(valid_rows.iter()) { - let (value, error) = Op::apply(lhs, rhs); - *dst = value; - failed |= error & valid; + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; } - check_numeric_error(op, failed)?; + check_numeric_error::(failed)?; Ok(values.freeze()) } -fn floating_array_array(lhs: &[T], rhs: &[T], op: NumericOperator) -> Buffer -where - T: NativePType - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div, -{ - match op { - NumericOperator::Add => floating_array_array_op::(lhs, rhs), - NumericOperator::Sub => floating_array_array_op::(lhs, rhs), - NumericOperator::Mul => floating_array_array_op::(lhs, rhs), - NumericOperator::Div => floating_array_array_op::(lhs, rhs), +trait CheckedArithmetic: NativePType { + fn checked_add(self, rhs: Self) -> Option; + fn checked_sub(self, rhs: Self) -> Option; + fn checked_mul(self, rhs: Self) -> Option; + fn checked_div(self, rhs: Self) -> Option; +} + +impl CheckedArithmetic for u8 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u16) * (rhs as u16); + (product <= u8::MAX as u16).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) } } -fn floating_array_constant(lhs: &[T], rhs: T, op: NumericOperator) -> Buffer -where - T: NativePType - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div, -{ - match op { - NumericOperator::Add => floating_array_constant_op::(lhs, rhs), - NumericOperator::Sub => floating_array_constant_op::(lhs, rhs), - NumericOperator::Mul => floating_array_constant_op::(lhs, rhs), - NumericOperator::Div => floating_array_constant_op::(lhs, rhs), +impl CheckedArithmetic for u16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u32) * (rhs as u32); + (product <= u16::MAX as u32).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) } } -fn floating_constant_array(lhs: T, rhs: &[T], op: NumericOperator) -> Buffer -where - T: NativePType - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div, -{ - match op { - NumericOperator::Add => floating_constant_array_op::(lhs, rhs), - NumericOperator::Sub => floating_constant_array_op::(lhs, rhs), - NumericOperator::Mul => floating_constant_array_op::(lhs, rhs), - NumericOperator::Div => floating_constant_array_op::(lhs, rhs), +impl CheckedArithmetic for u32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u64) * (rhs as u64); + (product <= u32::MAX as u64).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) } } -fn floating_scalar(lhs: T, rhs: T, op: NumericOperator) -> T -where - T: NativePType - + std::ops::Add - + std::ops::Sub - + std::ops::Mul - + std::ops::Div, -{ - match op { - NumericOperator::Add => >::apply(lhs, rhs), - NumericOperator::Sub => >::apply(lhs, rhs), - NumericOperator::Mul => >::apply(lhs, rhs), - NumericOperator::Div => >::apply(lhs, rhs), +impl CheckedArithmetic for u64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_mul(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) } } -fn floating_array_array_op(lhs: &[T], rhs: &[T]) -> Buffer -where - T: NativePType, - Op: FloatingOp, -{ - debug_assert_eq!(lhs.len(), rhs.len()); +impl CheckedArithmetic for i8 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } - let mut values = BufferMut::::zeroed(lhs.len()); - for ((dst, &lhs), &rhs) in values.iter_mut().zip(lhs).zip(rhs) { - *dst = Op::apply(lhs, rhs); + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i16) * (rhs as i16); + (product >= i8::MIN as i16 && product <= i8::MAX as i16).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i8::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) } - values.freeze() } -fn floating_array_constant_op(lhs: &[T], rhs: T) -> Buffer -where - T: NativePType, - Op: FloatingOp, -{ - let mut values = BufferMut::::zeroed(lhs.len()); - for (dst, &lhs) in values.iter_mut().zip(lhs) { - *dst = Op::apply(lhs, rhs); +impl CheckedArithmetic for i16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i32) * (rhs as i32); + (product >= i16::MIN as i32 && product <= i16::MAX as i32).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i16::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) } - values.freeze() } -fn floating_constant_array_op(lhs: T, rhs: &[T]) -> Buffer -where - T: NativePType, - Op: FloatingOp, -{ - let mut values = BufferMut::::zeroed(rhs.len()); - for (dst, &rhs) in values.iter_mut().zip(rhs) { - *dst = Op::apply(lhs, rhs); +impl CheckedArithmetic for i32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i64) * (rhs as i64); + (product >= i32::MIN as i64 && product <= i32::MAX as i64).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i32::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for i64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_mul(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i64::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) } - values.freeze() } -fn check_numeric_error(op: NumericOperator, failed: bool) -> VortexResult<()> { +impl CheckedArithmetic for f16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +impl CheckedArithmetic for f32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +impl CheckedArithmetic for f64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +fn check_numeric_error(failed: bool) -> VortexResult<()> { if failed { - return Err(numeric_error(op)); + return Err(numeric_error::()); } Ok(()) } -fn numeric_error(op: NumericOperator) -> vortex_error::VortexError { - match op { - NumericOperator::Add | NumericOperator::Sub | NumericOperator::Mul => { - vortex_err!(InvalidArgument: "integer overflow in numeric {} operation", op) - } - NumericOperator::Div => { - vortex_err!(InvalidArgument: "integer division by zero or overflow in numeric / operation") - } - } +fn numeric_error() -> VortexError { + vortex_err!(InvalidArgument: "{}", Op::ERROR) } #[cfg(test)] @@ -920,11 +818,11 @@ mod test { use crate::LEGACY_SESSION; use crate::RecursiveCanonical; use crate::VortexSessionExecute; + use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::builtins::ArrayBuiltins; use crate::scalar::Scalar; - use crate::scalar_fn::fns::binary::numeric::ConstantArray; use crate::scalar_fn::fns::operators::Operator; use crate::validity::Validity;