From 2c37139dc805334a461d7b4d0dff26824af73f8c Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Wed, 10 Jun 2026 07:13:16 -0700 Subject: [PATCH 1/2] vortex-array: decimal arithmetic dtypes and widening rescale cast Binary::return_dtype derives Hive-style result types for decimal + - * / (mirroring arrow-arith's execution-time rules, with precision saturation at the physical width), instead of erroring at plan time. The decimal cast kernel gains widening rescale so expression coercion can align decimal operand types. Needed by vortex-engine for TPC-H Q1/Q6 decimal expressions. Co-Authored-By: Claude Fable 5 Signed-off-by: Nicholas Gates --- .../src/arrays/decimal/compute/cast.rs | 111 ++++++++++++++++-- vortex-array/src/scalar_fn/fns/binary/mod.rs | 92 +++++++++++++++ 2 files changed, 194 insertions(+), 9 deletions(-) diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 432313d3cb6..7ed4d287fbb 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -77,8 +77,8 @@ impl CastKernel for Decimal { ); }; - // Scale changes are not yet supported - if from_decimal_dtype.scale() != to_decimal_dtype.scale() { + // Narrowing the scale (dropping fractional digits) is not supported. + if from_decimal_dtype.scale() > to_decimal_dtype.scale() { vortex_bail!( "Casting decimal with scale {} to scale {} not yet implemented", from_decimal_dtype.scale(), @@ -86,8 +86,12 @@ impl CastKernel for Decimal { ); } - // Downcasting precision is not yet supported - if to_decimal_dtype.precision() < from_decimal_dtype.precision() { + // The target must retain at least the source's integer digits. + let from_integer_digits = + i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale()); + let to_integer_digits = + i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale()); + if to_integer_digits < from_integer_digits { vortex_bail!( "Downcasting decimal from precision {} to {} not yet implemented", from_decimal_dtype.precision(), @@ -105,6 +109,12 @@ impl CastKernel for Decimal { .validity()? .cast_nullability(*to_nullability, array.len(), ctx)?; + // Widening the scale multiplies unscaled values by a power of ten. + if from_decimal_dtype.scale() < to_decimal_dtype.scale() { + let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?; + return Ok(Some(rescaled.into_array())); + } + // If the target needs a wider physical type, upcast the values let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype); let array = if target_values_type > array.values_type() { @@ -128,6 +138,73 @@ impl CastKernel for Decimal { } } +/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`), +/// multiplying unscaled values by the corresponding power of ten. The +/// result is stored at the width the target precision requires. +fn rescale_decimal_values( + array: ArrayView<'_, Decimal>, + to: crate::dtype::DecimalDType, + validity: crate::validity::Validity, +) -> VortexResult { + let from = array.decimal_dtype(); + let scale_up = u32::try_from(to.scale() - from.scale()) + .map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?; + let factor = 10i128 + .checked_pow(scale_up) + .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; + + // Gather unscaled values as i128 (i256 sources are unsupported). + let values: Vec = match array.values_type() { + DecimalType::I8 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I16 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I32 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I64 => array + .buffer::() + .iter() + .map(|&v| i128::from(v)) + .collect(), + DecimalType::I128 => array.buffer::().iter().copied().collect(), + DecimalType::I256 => vortex_bail!("rescaling i256 decimals is not supported"), + }; + + let rescaled = values + .into_iter() + .map(|v| { + v.checked_mul(factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows i128")) + }) + .collect::>>()?; + + match DecimalType::smallest_decimal_value_type(&to) { + DecimalType::I256 => vortex_bail!("rescaling into i256 decimals is not supported"), + DecimalType::I128 => Ok(DecimalArray::new(Buffer::from_iter(rescaled), to, validity)), + // Narrow storage targets: the values fit by the precision check. + DecimalType::I64 | DecimalType::I32 | DecimalType::I16 | DecimalType::I8 => { + let narrowed = rescaled + .into_iter() + .map(|v| { + i64::try_from(v).map_err(|_| { + vortex_error::vortex_err!("rescaled decimal exceeds target width") + }) + }) + .collect::>>()?; + Ok(DecimalArray::new(Buffer::from_iter(narrowed), to, validity)) + } + } +} + /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping /// the same precision and scale. /// @@ -262,19 +339,35 @@ mod tests { } #[test] - fn cast_different_scale_fails() { + fn cast_widening_scale_rescales() { + let array = DecimalArray::new( + buffer![100i32, -250], + DecimalDType::new(10, 2), + Validity::NonNullable, + ); + + // 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3. + let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal(); + assert_eq!(casted.dtype(), &wider); + assert_eq!(casted.buffer::().as_ref(), &[1000i64, -2500]); + } + + #[test] + fn cast_narrowing_scale_fails() { let array = DecimalArray::new( buffer![100i32], DecimalDType::new(10, 2), Validity::NonNullable, ); - // Try to cast to different scale - not supported - let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + // Dropping fractional digits is not supported. + let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable); #[expect(deprecated)] let result = array .into_array() - .cast(different_dtype) + .cast(narrower) .and_then(|a| a.to_canonical().map(|c| c.into_array())); assert!(result.is_err()); @@ -282,7 +375,7 @@ mod tests { result .unwrap_err() .to_string() - .contains("Casting decimal with scale 2 to scale 3 not yet implemented") + .contains("Casting decimal with scale 2 to scale 1 not yet implemented") ); } diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..1b0a62a4ac3 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -17,6 +17,7 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::DecimalDType; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -46,6 +47,47 @@ pub(crate) use numeric::*; use crate::scalar::NumericOperator; +/// Output decimal type of an arithmetic `operator` over two operands that +/// have already been coerced to the same decimal type. +/// +/// Mirrors the Hive-style rules `arrow-arith` applies at execution time +/// (see `arrow_arith::numeric::decimal_op`), including precision saturation +/// at the physical width's maximum: vortex lowers precisions `<= 38` to +/// Arrow `Decimal128` and wider decimals to `Decimal256`. +fn decimal_arithmetic_dtype( + operator: Operator, + operand: DecimalDType, +) -> VortexResult { + let p = u16::from(operand.precision()); + let s = i16::from(operand.scale()); + let (max_precision, max_scale): (u16, i16) = if p <= 38 { (38, 38) } else { (76, 76) }; + let (precision, scale) = match operator { + // scale = max(s, s); precision = max(p - s, p - s) + scale + 1 + Operator::Add | Operator::Sub => ((p + 1).min(max_precision), s), + // scale = s + s; precision = p + p + 1 + Operator::Mul => { + let scale = s + s; + if scale > max_scale { + vortex_bail!( + "output scale of {operand} {operator} {operand} exceeds the maximum scale \ + {max_scale}" + ); + } + ((p + p + 1).min(max_precision), scale) + } + // scale = min(s + 4, max); precision = p - s + s + scale + Operator::Div => { + let scale = (s + 4).min(max_scale); + (((p + scale.unsigned_abs()).min(max_precision)), scale) + } + _ => vortex_bail!("operator {operator} is not arithmetic"), + }; + Ok(DecimalDType::new( + u8::try_from(precision).unwrap_or(u8::MAX), + i8::try_from(scale).unwrap_or(i8::MAX), + )) +} + #[derive(Clone)] pub struct Binary; @@ -122,6 +164,15 @@ impl ScalarFnVTable for Binary { if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } + if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) + && l == r + { + let result = decimal_arithmetic_dtype(*operator, *l)?; + return Ok(DType::Decimal( + result, + lhs.nullability() | rhs.nullability(), + )); + } vortex_bail!( "incompatible types for arithmetic operation: {} {}", lhs, @@ -332,6 +383,47 @@ mod tests { use crate::expr::or_collect; use crate::expr::test_harness; use crate::scalar::Scalar; + + /// The decimal arithmetic dtypes derived at plan time must match what + /// arrow produces at execution time (see `decimal_arithmetic_dtype`). + #[test] + fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> { + use vortex_buffer::buffer; + + use crate::Canonical; + use crate::IntoArray; + use crate::arrays::DecimalArray; + use crate::dtype::DecimalDType; + use crate::scalar::DecimalValue; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::validity::Validity; + + let dec = DecimalDType::new(15, 2); + let values = + DecimalArray::new(buffer![100i128, 250, 1099], dec, Validity::NonNullable).into_array(); + let rhs = lit(Scalar::decimal( + DecimalValue::I128(50), + dec, + Nullability::NonNullable, + )); + for op in [Operator::Add, Operator::Sub, Operator::Mul, Operator::Div] { + let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; + let derived = expr.return_dtype(values.dtype())?; + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let executed = values + .clone() + .apply(&expr)? + .execute::(&mut ctx)? + .into_array(); + assert_eq!( + executed.dtype(), + &derived, + "derived dtype diverges from execution for {op}" + ); + } + Ok(()) + } + #[test] fn and_collect_balanced() { let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)]; From 9257307ccdc7e2fd27917f56f2baa14899ec548e Mon Sep 17 00:00:00 2001 From: Nicholas Gates Date: Fri, 12 Jun 2026 14:09:39 -0400 Subject: [PATCH 2/2] vortex-array: preserve decimal arithmetic input dtypes Signed-off-by: "Nicholas Gates" --- .../src/arrays/decimal/compute/cast.rs | 106 ++++++++++-------- vortex-array/src/expr/transform/coerce.rs | 27 +++++ vortex-array/src/scalar_fn/fns/binary/mod.rs | 94 +++++++++++----- .../src/scalar_fn/fns/binary/numeric.rs | 37 +++++- 4 files changed, 183 insertions(+), 81 deletions(-) diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 7ed4d287fbb..a78aa5d57b2 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::CheckedMul; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -153,56 +154,42 @@ fn rescale_decimal_values( .checked_pow(scale_up) .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; - // Gather unscaled values as i128 (i256 sources are unsupported). - let values: Vec = match array.values_type() { - DecimalType::I8 => array - .buffer::() - .iter() - .map(|&v| i128::from(v)) - .collect(), - DecimalType::I16 => array - .buffer::() - .iter() - .map(|&v| i128::from(v)) - .collect(), - DecimalType::I32 => array - .buffer::() - .iter() - .map(|&v| i128::from(v)) - .collect(), - DecimalType::I64 => array - .buffer::() - .iter() - .map(|&v| i128::from(v)) - .collect(), - DecimalType::I128 => array.buffer::().iter().copied().collect(), - DecimalType::I256 => vortex_bail!("rescaling i256 decimals is not supported"), - }; - - let rescaled = values - .into_iter() - .map(|v| { - v.checked_mul(factor) - .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows i128")) - }) - .collect::>>()?; - - match DecimalType::smallest_decimal_value_type(&to) { - DecimalType::I256 => vortex_bail!("rescaling into i256 decimals is not supported"), - DecimalType::I128 => Ok(DecimalArray::new(Buffer::from_iter(rescaled), to, validity)), - // Narrow storage targets: the values fit by the precision check. - DecimalType::I64 | DecimalType::I32 | DecimalType::I16 | DecimalType::I8 => { - let narrowed = rescaled - .into_iter() - .map(|v| { - i64::try_from(v).map_err(|_| { - vortex_error::vortex_err!("rescaled decimal exceeds target width") - }) - }) - .collect::>>()?; - Ok(DecimalArray::new(Buffer::from_iter(narrowed), to, validity)) - } + let from_values_type = array.values_type(); + if from_values_type == DecimalType::I256 { + vortex_bail!("rescaling i256 decimals is not supported"); + } + + let to_values_type = DecimalType::smallest_decimal_value_type(&to); + if to_values_type == DecimalType::I256 { + vortex_bail!("rescaling into i256 decimals is not supported"); } + + match_each_decimal_value_type!(from_values_type, |F| { + let from_buffer = array.buffer::(); + match_each_decimal_value_type!(to_values_type, |T| { + let to_buffer = rescale_decimal_buffer::(from_buffer, factor)?; + Ok(DecimalArray::new(to_buffer, to, validity)) + }) + }) +} + +fn rescale_decimal_buffer(from: Buffer, factor: i128) -> VortexResult> +where + F: NativeDecimalType, + T: NativeDecimalType + CheckedMul, +{ + let factor = ::from(factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale factor exceeds target width"))?; + + from.iter() + .map(|&v| { + let v = ::from(v).ok_or_else(|| { + vortex_error::vortex_err!("decimal rescale input exceeds target width") + })?; + CheckedMul::checked_mul(&v, &factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows target width")) + }) + .collect() } /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping @@ -354,6 +341,27 @@ mod tests { assert_eq!(casted.buffer::().as_ref(), &[1000i64, -2500]); } + #[test] + fn cast_widening_scale_uses_target_width() { + let array = DecimalArray::new( + buffer![9i8, -8], + DecimalDType::new(1, 0), + Validity::NonNullable, + ); + + let wider_scale = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array + .into_array() + .cast(wider_scale.clone()) + .unwrap() + .to_decimal(); + + assert_eq!(casted.dtype(), &wider_scale); + assert_eq!(casted.values_type(), DecimalType::I8); + assert_eq!(casted.buffer::().as_ref(), &[90i8, -80]); + } + #[test] fn cast_narrowing_scale_fails() { let array = DecimalArray::new( diff --git a/vortex-array/src/expr/transform/coerce.rs b/vortex-array/src/expr/transform/coerce.rs index 1b6e9acd661..39d4a29be89 100644 --- a/vortex-array/src/expr/transform/coerce.rs +++ b/vortex-array/src/expr/transform/coerce.rs @@ -70,6 +70,7 @@ mod tests { use vortex_error::VortexResult; use crate::dtype::DType; + use crate::dtype::DecimalDType; use crate::dtype::Nullability::NonNullable; use crate::dtype::PType; use crate::dtype::StructFields; @@ -153,6 +154,32 @@ mod tests { Ok(()) } + #[test] + fn mixed_decimal_arithmetic_preserves_input_types() -> VortexResult<()> { + let lhs = DecimalDType::new(10, 2); + let rhs = DecimalDType::new(5, 1); + let scope = DType::Struct( + StructFields::new( + ["a", "b"].into(), + vec![ + DType::Decimal(lhs, NonNullable), + DType::Decimal(rhs, NonNullable), + ], + ), + NonNullable, + ); + let expr = Binary.new_expr(Operator::Add, [col("a"), col("b")]); + let coerced = coerce_expression(expr, &scope)?; + + assert!(!coerced.child(0).is::()); + assert!(!coerced.child(1).is::()); + assert_eq!( + coerced.return_dtype(&scope)?, + DType::Decimal(DecimalDType::new(11, 2), NonNullable) + ); + Ok(()) + } + #[test] fn boolean_operators_no_coercion() -> VortexResult<()> { let scope = DType::Struct( diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1b0a62a4ac3..0f66ea61da7 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -10,6 +10,7 @@ pub use boolean::or_kleene; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_proto::expr as pb; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -18,6 +19,8 @@ use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; use crate::dtype::DecimalDType; +use crate::dtype::NativeDecimalType; +use crate::dtype::i256; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -47,45 +50,62 @@ pub(crate) use numeric::*; use crate::scalar::NumericOperator; -/// Output decimal type of an arithmetic `operator` over two operands that -/// have already been coerced to the same decimal type. +/// Output decimal type of an arithmetic `operator` over two decimal operands. /// /// Mirrors the Hive-style rules `arrow-arith` applies at execution time /// (see `arrow_arith::numeric::decimal_op`), including precision saturation -/// at the physical width's maximum: vortex lowers precisions `<= 38` to -/// Arrow `Decimal128` and wider decimals to `Decimal256`. +/// at the physical width's maximum: vortex lowers precisions +/// `<= i128::MAX_PRECISION` to Arrow `Decimal128` and wider decimals to +/// `Decimal256`. fn decimal_arithmetic_dtype( operator: Operator, - operand: DecimalDType, + lhs: DecimalDType, + rhs: DecimalDType, ) -> VortexResult { - let p = u16::from(operand.precision()); - let s = i16::from(operand.scale()); - let (max_precision, max_scale): (u16, i16) = if p <= 38 { (38, 38) } else { (76, 76) }; + let p1 = i16::from(lhs.precision()); + let s1 = i16::from(lhs.scale()); + let p2 = i16::from(rhs.precision()); + let s2 = i16::from(rhs.scale()); + let (max_precision, max_scale) = + if lhs.precision() <= i128::MAX_PRECISION && rhs.precision() <= i128::MAX_PRECISION { + (i16::from(i128::MAX_PRECISION), i16::from(i128::MAX_SCALE)) + } else { + (i16::from(i256::MAX_PRECISION), i16::from(i256::MAX_SCALE)) + }; let (precision, scale) = match operator { - // scale = max(s, s); precision = max(p - s, p - s) + scale + 1 - Operator::Add | Operator::Sub => ((p + 1).min(max_precision), s), - // scale = s + s; precision = p + p + 1 + // scale = max(s1, s2); precision = scale + max(p1 - s1, p2 - s2) + 1 + Operator::Add | Operator::Sub => { + let scale = s1.max(s2); + ( + (scale + (p1 - s1).max(p2 - s2) + 1).min(max_precision), + scale, + ) + } + // scale = s1 + s2; precision = p1 + p2 + 1 Operator::Mul => { - let scale = s + s; + let scale = s1 + s2; if scale > max_scale { vortex_bail!( - "output scale of {operand} {operator} {operand} exceeds the maximum scale \ + "output scale of {lhs} {operator} {rhs} exceeds the maximum scale \ {max_scale}" ); } - ((p + p + 1).min(max_precision), scale) + ((p1 + p2 + 1).min(max_precision), scale) } - // scale = min(s + 4, max); precision = p - s + s + scale + // scale = min(s1 + 4, max); precision = p1 - s1 + s2 + scale Operator::Div => { - let scale = (s + 4).min(max_scale); - (((p + scale.unsigned_abs()).min(max_precision)), scale) + let scale = (s1 + 4).min(max_scale); + let mul_pow = scale - s1 + s2; + ((p1 + mul_pow).clamp(1, max_precision), scale) } _ => vortex_bail!("operator {operator} is not arithmetic"), }; - Ok(DecimalDType::new( - u8::try_from(precision).unwrap_or(u8::MAX), - i8::try_from(scale).unwrap_or(i8::MAX), - )) + let precision = u8::try_from(precision) + .map_err(|_| vortex_err!("decimal arithmetic precision exceeds supported range"))?; + let scale = i8::try_from(scale) + .map_err(|_| vortex_err!("decimal arithmetic scale exceeds supported range"))?; + + DecimalDType::try_new(precision, scale) } #[derive(Clone)] @@ -145,6 +165,11 @@ impl ScalarFnVTable for Binary { fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult> { let lhs = &args[0]; let rhs = &args[1]; + if operator.is_arithmetic() + && matches!((lhs, rhs), (DType::Decimal(..), DType::Decimal(..))) + { + return Ok(args.to_vec()); + } if operator.is_arithmetic() || operator.is_comparison() { let supertype = lhs.least_supertype(rhs).ok_or_else(|| { vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs) @@ -164,10 +189,8 @@ impl ScalarFnVTable for Binary { if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } - if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) - && l == r - { - let result = decimal_arithmetic_dtype(*operator, *l)?; + if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) { + let result = decimal_arithmetic_dtype(*operator, *l, *r)?; return Ok(DType::Decimal( result, lhs.nullability() | rhs.nullability(), @@ -398,17 +421,28 @@ mod tests { use crate::scalar_fn::ScalarFnVTableExt; use crate::validity::Validity; - let dec = DecimalDType::new(15, 2); - let values = - DecimalArray::new(buffer![100i128, 250, 1099], dec, Validity::NonNullable).into_array(); + let lhs_dec = DecimalDType::new(10, 2); + let rhs_dec = DecimalDType::new(5, 1); + let values = DecimalArray::new(buffer![100i128, 250, 1099], lhs_dec, Validity::NonNullable) + .into_array(); let rhs = lit(Scalar::decimal( DecimalValue::I128(50), - dec, + rhs_dec, Nullability::NonNullable, )); - for op in [Operator::Add, Operator::Sub, Operator::Mul, Operator::Div] { + for (op, expected) in [ + (Operator::Add, DecimalDType::new(11, 2)), + (Operator::Sub, DecimalDType::new(11, 2)), + (Operator::Mul, DecimalDType::new(16, 3)), + (Operator::Div, DecimalDType::new(15, 6)), + ] { let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; let derived = expr.return_dtype(values.dtype())?; + assert_eq!( + derived, + DType::Decimal(expected, Nullability::NonNullable), + "unexpected derived dtype for {op}" + ); let mut ctx = LEGACY_SESSION.create_execution_ctx(); let executed = values .clone() diff --git a/vortex-array/src/scalar_fn/fns/binary/numeric.rs b/vortex-array/src/scalar_fn/fns/binary/numeric.rs index 2ef58fb5209..a1c6d4cfd05 100644 --- a/vortex-array/src/scalar_fn/fns/binary/numeric.rs +++ b/vortex-array/src/scalar_fn/fns/binary/numeric.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use arrow_schema::DataType; use vortex_error::VortexResult; use crate::ArrayRef; @@ -9,6 +10,8 @@ use crate::arrays::Constant; use crate::arrays::ConstantArray; use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; +use crate::dtype::DType; +use crate::dtype::NativeDecimalType; use crate::executor::ExecutionCtx; use crate::scalar::NumericOperator; @@ -38,8 +41,18 @@ pub(crate) fn arrow_numeric( let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); let len = lhs.len(); - let left = Datum::try_new(lhs, ctx)?; - let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + let (left, right) = if let Some((lhs_data_type, rhs_data_type)) = + decimal_arrow_data_types(lhs.dtype(), rhs.dtype()) + { + ( + Datum::try_new_with_target_datatype(lhs, &lhs_data_type, ctx)?, + Datum::try_new_with_target_datatype(rhs, &rhs_data_type, ctx)?, + ) + } else { + let left = Datum::try_new(lhs, ctx)?; + let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + (left, right) + }; let array = match operator { NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?, @@ -51,6 +64,26 @@ pub(crate) fn arrow_numeric( from_arrow_columnar(array.as_ref(), len, nullable, ctx) } +fn decimal_arrow_data_types(lhs: &DType, rhs: &DType) -> Option<(DataType, DataType)> { + let (DType::Decimal(lhs_decimal, _), DType::Decimal(rhs_decimal, _)) = (lhs, rhs) else { + return None; + }; + + let use_decimal256 = lhs_decimal.precision() > i128::MAX_PRECISION + || rhs_decimal.precision() > i128::MAX_PRECISION; + if use_decimal256 { + Some(( + DataType::Decimal256(lhs_decimal.precision(), lhs_decimal.scale()), + DataType::Decimal256(rhs_decimal.precision(), rhs_decimal.scale()), + )) + } else { + Some(( + DataType::Decimal128(lhs_decimal.precision(), lhs_decimal.scale()), + DataType::Decimal128(rhs_decimal.precision(), rhs_decimal.scale()), + )) + } +} + fn constant_numeric( lhs: &ArrayRef, rhs: &ArrayRef,