diff --git a/vortex-array/src/scalar_fn/fns/binary/numeric.rs b/vortex-array/src/scalar_fn/fns/binary/numeric.rs index 2ef58fb5209..cbad0b8cc85 100644 --- a/vortex-array/src/scalar_fn/fns/binary/numeric.rs +++ b/vortex-array/src/scalar_fn/fns/binary/numeric.rs @@ -1,75 +1,811 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use vortex_buffer::Buffer; +use vortex_buffer::BufferMut; +use vortex_error::VortexError; use vortex_error::VortexResult; +use vortex_error::vortex_bail; +use vortex_error::vortex_err; +use vortex_mask::Mask; use crate::ArrayRef; +use crate::ExecutionCtx; use crate::IntoArray; use crate::arrays::Constant; use crate::arrays::ConstantArray; -use crate::arrow::Datum; -use crate::arrow::from_arrow_columnar; -use crate::executor::ExecutionCtx; +use crate::arrays::PrimitiveArray; +use crate::builtins::ArrayBuiltins; +use crate::dtype::DType; +use crate::dtype::NativePType; +use crate::dtype::PType; +use crate::dtype::half::f16; +use crate::match_each_native_ptype; use crate::scalar::NumericOperator; +use crate::scalar::Scalar; +use crate::validity::Validity; + +struct CheckedAdd; + +struct CheckedSub; + +struct CheckedMul; + +struct CheckedDiv; + +trait CheckedPrimitiveOp { + const ERROR: &'static str; +} + +trait CheckedPrimitiveBinary: CheckedPrimitiveOp { + fn checked(lhs: T, rhs: T) -> Option; +} + +trait CheckedPrimitiveBinaryAll: + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary +{ +} + +impl CheckedPrimitiveBinaryAll for Op where + Op: CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary + + CheckedPrimitiveBinary +{ +} + +impl CheckedPrimitiveOp for CheckedAdd { + const ERROR: &'static str = "integer overflow in checked add"; +} + +impl CheckedPrimitiveOp for CheckedSub { + const ERROR: &'static str = "integer overflow in checked sub"; +} + +impl CheckedPrimitiveOp for CheckedMul { + const ERROR: &'static str = "integer overflow in checked mul"; +} + +impl CheckedPrimitiveOp for CheckedDiv { + const ERROR: &'static str = "integer division by zero or overflow in checked div"; +} + +impl CheckedPrimitiveBinary for CheckedAdd { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_add(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedSub { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_sub(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedMul { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_mul(rhs) + } +} + +impl CheckedPrimitiveBinary for CheckedDiv { + #[inline(always)] + fn checked(lhs: T, rhs: T) -> Option { + lhs.checked_div(rhs) + } +} /// Execute a numeric operation between two arrays. -/// -/// This is the entry point for numeric operations from the binary expression. -/// Handles constant-constant directly, otherwise falls back to Arrow. pub(crate) fn execute_numeric( lhs: &ArrayRef, rhs: &ArrayRef, op: NumericOperator, ctx: &mut ExecutionCtx, ) -> VortexResult { - if let Some(result) = constant_numeric(lhs, rhs, op)? { - return Ok(result); + match op { + NumericOperator::Add => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Sub => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Mul => execute_checked_numeric::(lhs, rhs, ctx), + NumericOperator::Div => execute_checked_numeric::(lhs, rhs, ctx), } - arrow_numeric(lhs, rhs, op, ctx) } -/// Implementation of numeric operations using the Arrow crate. -pub(crate) fn arrow_numeric( +fn execute_checked_numeric( lhs: &ArrayRef, rhs: &ArrayRef, - operator: NumericOperator, ctx: &mut ExecutionCtx, -) -> VortexResult { - let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); +) -> VortexResult +where + Op: CheckedPrimitiveBinaryAll, +{ + let ptype = PType::try_from(lhs.dtype())?; + if !lhs.dtype().eq_ignore_nullability(rhs.dtype()) { + vortex_bail!( + "numeric operator requires matching primitive types, got {} and {}", + lhs.dtype(), + rhs.dtype() + ); + } + + match_each_native_ptype!(ptype, |T| { execute_checked_typed::(lhs, rhs, ctx) }) +} + +fn execute_checked_typed( + lhs: &ArrayRef, + rhs: &ArrayRef, + ctx: &mut ExecutionCtx, +) -> VortexResult +where + T: NativePType, + Op: CheckedPrimitiveBinary, + Scalar: From, + Scalar: From>, +{ + let result_dtype = lhs + .dtype() + .with_nullability(lhs.dtype().nullability() | rhs.dtype().nullability()); + let lhs = PrimitiveOperand::::try_new(lhs, ctx)?; + let rhs = PrimitiveOperand::::try_new(rhs, ctx)?; let len = lhs.len(); + if len != rhs.len() { + vortex_bail!( + "numeric operator requires equal lengths, got {} and {}", + len, + rhs.len() + ); + } - let left = Datum::try_new(lhs, ctx)?; - let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + let validity = lhs.validity().and(rhs.validity())?; + let valid_rows = ValidRows::from_validity(&validity, len, ctx)?; + if valid_rows.is_none() { + return primitive_result_array::(Buffer::::zeroed(len), validity, &result_dtype); + } - let array = match operator { - NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?, - NumericOperator::Sub => arrow_arith::numeric::sub(&left, &right)?, - NumericOperator::Mul => arrow_arith::numeric::mul(&left, &right)?, - NumericOperator::Div => arrow_arith::numeric::div(&left, &right)?, + let values = match (&lhs, &rhs) { + (PrimitiveOperand::Array(lhs), PrimitiveOperand::Array(rhs)) => { + checked_array_array::(lhs.values(), rhs.values(), &valid_rows)? + } + (PrimitiveOperand::Array(lhs), PrimitiveOperand::Constant { value: rhs, .. }) => { + checked_array_constant::(lhs.values(), *rhs, &valid_rows)? + } + (PrimitiveOperand::Constant { value: lhs, .. }, PrimitiveOperand::Array(rhs)) => { + checked_constant_array::(*lhs, rhs.values(), &valid_rows)? + } + ( + PrimitiveOperand::Constant { value: lhs, .. }, + PrimitiveOperand::Constant { value: rhs, .. }, + ) => { + let value = Op::checked(*lhs, *rhs).ok_or_else(|| numeric_error::())?; + return Ok(constant_result_array(value, len, &result_dtype)); + } + (PrimitiveOperand::Null(_), _) | (_, PrimitiveOperand::Null(_)) => Buffer::::zeroed(len), }; - from_arrow_columnar(array.as_ref(), len, nullable, ctx) + primitive_result_array::(values, validity, &result_dtype) } -fn constant_numeric( - lhs: &ArrayRef, - rhs: &ArrayRef, - op: NumericOperator, -) -> VortexResult> { - let (Some(lhs), Some(rhs)) = (lhs.as_opt::(), rhs.as_opt::()) else { - return Ok(None); - }; +fn primitive_result_array( + values: Buffer, + validity: Validity, + dtype: &DType, +) -> VortexResult { + let array = PrimitiveArray::new(values, validity).into_array(); + if array.dtype() == dtype { + return Ok(array); + } + array.cast(dtype.clone()) +} - let Some(result) = lhs - .scalar() - .as_primitive() - .checked_binary_numeric(&rhs.scalar().as_primitive(), op) - else { - // Overflow detected — fall through to arrow_numeric which uses wrapping arithmetic. - return Ok(None); - }; +fn constant_result_array(value: T, len: usize, dtype: &DType) -> ArrayRef +where + T: NativePType, + Scalar: From + From>, +{ + if dtype.is_nullable() { + ConstantArray::new(Some(value), len).into_array() + } else { + ConstantArray::new(value, len).into_array() + } +} + +enum PrimitiveOperand { + Array(TypedPrimitive), + Constant { + value: T, + len: usize, + validity: Validity, + }, + Null(usize), +} + +impl PrimitiveOperand { + fn try_new(array: &ArrayRef, ctx: &mut ExecutionCtx) -> VortexResult { + if let Some(constant) = array.as_opt::() { + return Ok( + match constant.scalar().as_primitive().try_typed_value::()? { + Some(value) => Self::Constant { + value, + len: array.len(), + validity: if constant.scalar().dtype().is_nullable() { + Validity::AllValid + } else { + Validity::NonNullable + }, + }, + None => Self::Null(array.len()), + }, + ); + } + + Ok(Self::Array(TypedPrimitive::new( + array.clone().execute::(ctx)?, + )?)) + } + + fn len(&self) -> usize { + match self { + Self::Array(array) => array.values().len(), + Self::Constant { len, .. } | Self::Null(len) => *len, + } + } - Ok(Some(ConstantArray::new(result, lhs.len()).into_array())) + fn validity(&self) -> Validity { + match self { + Self::Array(array) => array.validity(), + Self::Constant { validity, .. } => validity.clone(), + Self::Null(_) => Validity::AllInvalid, + } + } +} + +struct TypedPrimitive { + values: Buffer, + validity: Validity, +} + +impl TypedPrimitive { + fn new(array: PrimitiveArray) -> VortexResult { + let validity = array.validity()?; + let values = array.into_buffer::(); + Ok(Self { values, validity }) + } + + fn values(&self) -> &[T] { + &self.values + } + + fn validity(&self) -> Validity { + self.validity.clone() + } +} + +enum ValidRows { + All, + Some(Mask), + None, +} + +impl ValidRows { + fn from_validity( + validity: &Validity, + len: usize, + ctx: &mut ExecutionCtx, + ) -> VortexResult { + let mask = validity.execute_mask(len, ctx)?; + Ok(if mask.all_true() { + Self::All + } else if mask.all_false() { + Self::None + } else { + Self::Some(mask) + }) + } + + fn is_none(&self) -> bool { + matches!(self, Self::None) + } +} + +fn checked_array_array( + lhs: &[T], + rhs: &[T], + valid_rows: &ValidRows, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + debug_assert_eq!(lhs.len(), rhs.len()); + + match valid_rows { + ValidRows::All => checked_array_array_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_array_array_masked::(lhs, rhs, mask), + ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), + } +} + +fn checked_array_constant( + lhs: &[T], + rhs: T, + valid_rows: &ValidRows, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + match valid_rows { + ValidRows::All => checked_array_constant_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_array_constant_masked::(lhs, rhs, mask), + ValidRows::None => Ok(Buffer::::zeroed(lhs.len())), + } +} + +fn checked_constant_array( + lhs: T, + rhs: &[T], + valid_rows: &ValidRows, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + match valid_rows { + ValidRows::All => checked_constant_array_all_valid::(lhs, rhs), + ValidRows::Some(mask) => checked_constant_array_masked::(lhs, rhs, mask), + ValidRows::None => Ok(Buffer::::zeroed(rhs.len())), + } +} + +fn checked_array_array_all_valid(lhs: &[T], rhs: &[T]) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for ((dst, &lhs), &rhs) in values.iter_mut().zip(lhs).zip(rhs) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +fn checked_array_array_masked( + lhs: &[T], + rhs: &[T], + valid_rows: &Mask, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for (((dst, &lhs), &rhs), valid) in values.iter_mut().zip(lhs).zip(rhs).zip(valid_rows.iter()) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +fn checked_array_constant_all_valid(lhs: &[T], rhs: T) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for (dst, &lhs) in values.iter_mut().zip(lhs) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +fn checked_array_constant_masked( + lhs: &[T], + rhs: T, + valid_rows: &Mask, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(lhs.len()); + for ((dst, &lhs), valid) in values.iter_mut().zip(lhs).zip(valid_rows.iter()) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +fn checked_constant_array_all_valid(lhs: T, rhs: &[T]) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(rhs.len()); + for (dst, &rhs) in values.iter_mut().zip(rhs) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +fn checked_constant_array_masked( + lhs: T, + rhs: &[T], + valid_rows: &Mask, +) -> VortexResult> +where + T: NativePType, + Op: CheckedPrimitiveBinary, +{ + let mut failed = false; + let mut values = BufferMut::::zeroed(rhs.len()); + for ((dst, &rhs), valid) in values.iter_mut().zip(rhs).zip(valid_rows.iter()) { + let checked = Op::checked(lhs, rhs); + let invalid = checked.is_none(); + *dst = checked.unwrap_or_default(); + failed |= invalid & valid; + } + check_numeric_error::(failed)?; + Ok(values.freeze()) +} + +trait CheckedArithmetic: NativePType { + fn checked_add(self, rhs: Self) -> Option; + fn checked_sub(self, rhs: Self) -> Option; + fn checked_mul(self, rhs: Self) -> Option; + fn checked_div(self, rhs: Self) -> Option; +} + +impl CheckedArithmetic for u8 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u16) * (rhs as u16); + (product <= u8::MAX as u16).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for u16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u32) * (rhs as u32); + (product <= u16::MAX as u32).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for u32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as u64) * (rhs as u64); + (product <= u32::MAX as u64).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for u64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_add(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_sub(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_mul(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let invalid = rhs == 0; + let divisor = if invalid { 1 } else { rhs }; + (!invalid).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for i8 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i16) * (rhs as i16); + (product >= i8::MIN as i16 && product <= i8::MAX as i16).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i8::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for i16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i32) * (rhs as i32); + (product >= i16::MIN as i32 && product <= i16::MAX as i32).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i16::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for i32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + #[allow(clippy::cast_possible_truncation)] + fn checked_mul(self, rhs: Self) -> Option { + let product = (self as i64) * (rhs as i64); + (product >= i32::MIN as i64 && product <= i32::MAX as i64).then_some(product as Self) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i32::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for i64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + let value = self.wrapping_add(rhs); + let overflow = ((self ^ value) & (rhs ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + let value = self.wrapping_sub(rhs); + let overflow = ((self ^ rhs) & (self ^ value)) < 0; + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + let (value, overflow) = self.overflowing_mul(rhs); + (!overflow).then_some(value) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + let div_by_zero = rhs == 0; + let overflow = self == i64::MIN && rhs == -1; + let divisor = if div_by_zero { 1 } else { rhs }; + (!(div_by_zero | overflow)).then_some(self.wrapping_div(divisor)) + } +} + +impl CheckedArithmetic for f16 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +impl CheckedArithmetic for f32 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +impl CheckedArithmetic for f64 { + #[inline(always)] + fn checked_add(self, rhs: Self) -> Option { + Some(self + rhs) + } + + #[inline(always)] + fn checked_sub(self, rhs: Self) -> Option { + Some(self - rhs) + } + + #[inline(always)] + fn checked_mul(self, rhs: Self) -> Option { + Some(self * rhs) + } + + #[inline(always)] + fn checked_div(self, rhs: Self) -> Option { + Some(self / rhs) + } +} + +fn check_numeric_error(failed: bool) -> VortexResult<()> { + if failed { + return Err(numeric_error::()); + } + Ok(()) +} + +fn numeric_error() -> VortexError { + vortex_err!(InvalidArgument: "{}", Op::ERROR) } #[cfg(test)] @@ -82,12 +818,13 @@ mod test { use crate::LEGACY_SESSION; use crate::RecursiveCanonical; use crate::VortexSessionExecute; + use crate::arrays::ConstantArray; use crate::arrays::PrimitiveArray; use crate::assert_arrays_eq; use crate::builtins::ArrayBuiltins; use crate::scalar::Scalar; - use crate::scalar_fn::fns::binary::numeric::ConstantArray; use crate::scalar_fn::fns::operators::Operator; + use crate::validity::Validity; fn sub_scalar(array: &ArrayRef, scalar: impl Into) -> VortexResult { array @@ -138,4 +875,65 @@ mod test { let _results = sub_scalar(&values, 1.0f32).unwrap(); let _results = sub_scalar(&values, f32::MAX).unwrap(); } + + #[test] + fn test_integer_overflow_errors() { + let values = buffer![u8::MAX].into_array(); + let result = values + .binary( + ConstantArray::new(1u8, values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())); + + assert!(result.is_err()); + } + + #[test] + fn test_integer_divide_by_zero_errors() { + let values = buffer![1i32].into_array(); + let result = values + .binary( + ConstantArray::new(0i32, values.len()).into_array(), + Operator::Div, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())); + + assert!(result.is_err()); + } + + #[test] + fn test_integer_errors_ignore_null_lanes() { + let values = PrimitiveArray::new(buffer![u8::MAX, 1], Validity::from_iter([false, true])) + .into_array(); + let result = values + .binary( + ConstantArray::new(1u8, values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| { + a.execute::(&mut LEGACY_SESSION.create_execution_ctx()) + }) + .map(|a| a.0.into_array()) + .unwrap(); + + assert_arrays_eq!(result, PrimitiveArray::from_option_iter([None, Some(2u8)])); + } + + #[test] + fn test_present_nullable_constant_preserves_nullable_output() { + let values = buffer![1u8, 2].into_array(); + let result = values + .binary( + ConstantArray::new(Some(1u8), values.len()).into_array(), + Operator::Add, + ) + .and_then(|a| a.execute::(&mut LEGACY_SESSION.create_execution_ctx())) + .unwrap(); + + assert_arrays_eq!( + result, + PrimitiveArray::from_option_iter([Some(2u8), Some(3)]) + ); + } }