diff --git a/vortex-array/src/arrays/decimal/compute/cast.rs b/vortex-array/src/arrays/decimal/compute/cast.rs index 432313d3cb6..a78aa5d57b2 100644 --- a/vortex-array/src/arrays/decimal/compute/cast.rs +++ b/vortex-array/src/arrays/decimal/compute/cast.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use num_traits::CheckedMul; use vortex_buffer::Buffer; use vortex_error::VortexExpect; use vortex_error::VortexResult; @@ -77,8 +78,8 @@ impl CastKernel for Decimal { ); }; - // Scale changes are not yet supported - if from_decimal_dtype.scale() != to_decimal_dtype.scale() { + // Narrowing the scale (dropping fractional digits) is not supported. + if from_decimal_dtype.scale() > to_decimal_dtype.scale() { vortex_bail!( "Casting decimal with scale {} to scale {} not yet implemented", from_decimal_dtype.scale(), @@ -86,8 +87,12 @@ impl CastKernel for Decimal { ); } - // Downcasting precision is not yet supported - if to_decimal_dtype.precision() < from_decimal_dtype.precision() { + // The target must retain at least the source's integer digits. + let from_integer_digits = + i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale()); + let to_integer_digits = + i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale()); + if to_integer_digits < from_integer_digits { vortex_bail!( "Downcasting decimal from precision {} to {} not yet implemented", from_decimal_dtype.precision(), @@ -105,6 +110,12 @@ impl CastKernel for Decimal { .validity()? .cast_nullability(*to_nullability, array.len(), ctx)?; + // Widening the scale multiplies unscaled values by a power of ten. + if from_decimal_dtype.scale() < to_decimal_dtype.scale() { + let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?; + return Ok(Some(rescaled.into_array())); + } + // If the target needs a wider physical type, upcast the values let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype); let array = if target_values_type > array.values_type() { @@ -128,6 +139,59 @@ impl CastKernel for Decimal { } } +/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`), +/// multiplying unscaled values by the corresponding power of ten. The +/// result is stored at the width the target precision requires. +fn rescale_decimal_values( + array: ArrayView<'_, Decimal>, + to: crate::dtype::DecimalDType, + validity: crate::validity::Validity, +) -> VortexResult { + let from = array.decimal_dtype(); + let scale_up = u32::try_from(to.scale() - from.scale()) + .map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?; + let factor = 10i128 + .checked_pow(scale_up) + .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; + + let from_values_type = array.values_type(); + if from_values_type == DecimalType::I256 { + vortex_bail!("rescaling i256 decimals is not supported"); + } + + let to_values_type = DecimalType::smallest_decimal_value_type(&to); + if to_values_type == DecimalType::I256 { + vortex_bail!("rescaling into i256 decimals is not supported"); + } + + match_each_decimal_value_type!(from_values_type, |F| { + let from_buffer = array.buffer::(); + match_each_decimal_value_type!(to_values_type, |T| { + let to_buffer = rescale_decimal_buffer::(from_buffer, factor)?; + Ok(DecimalArray::new(to_buffer, to, validity)) + }) + }) +} + +fn rescale_decimal_buffer(from: Buffer, factor: i128) -> VortexResult> +where + F: NativeDecimalType, + T: NativeDecimalType + CheckedMul, +{ + let factor = ::from(factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale factor exceeds target width"))?; + + from.iter() + .map(|&v| { + let v = ::from(v).ok_or_else(|| { + vortex_error::vortex_err!("decimal rescale input exceeds target width") + })?; + CheckedMul::checked_mul(&v, &factor) + .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows target width")) + }) + .collect() +} + /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping /// the same precision and scale. /// @@ -262,19 +326,56 @@ mod tests { } #[test] - fn cast_different_scale_fails() { + fn cast_widening_scale_rescales() { + let array = DecimalArray::new( + buffer![100i32, -250], + DecimalDType::new(10, 2), + Validity::NonNullable, + ); + + // 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3. + let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal(); + assert_eq!(casted.dtype(), &wider); + assert_eq!(casted.buffer::().as_ref(), &[1000i64, -2500]); + } + + #[test] + fn cast_widening_scale_uses_target_width() { + let array = DecimalArray::new( + buffer![9i8, -8], + DecimalDType::new(1, 0), + Validity::NonNullable, + ); + + let wider_scale = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable); + #[expect(deprecated)] + let casted = array + .into_array() + .cast(wider_scale.clone()) + .unwrap() + .to_decimal(); + + assert_eq!(casted.dtype(), &wider_scale); + assert_eq!(casted.values_type(), DecimalType::I8); + assert_eq!(casted.buffer::().as_ref(), &[90i8, -80]); + } + + #[test] + fn cast_narrowing_scale_fails() { let array = DecimalArray::new( buffer![100i32], DecimalDType::new(10, 2), Validity::NonNullable, ); - // Try to cast to different scale - not supported - let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); + // Dropping fractional digits is not supported. + let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable); #[expect(deprecated)] let result = array .into_array() - .cast(different_dtype) + .cast(narrower) .and_then(|a| a.to_canonical().map(|c| c.into_array())); assert!(result.is_err()); @@ -282,7 +383,7 @@ mod tests { result .unwrap_err() .to_string() - .contains("Casting decimal with scale 2 to scale 3 not yet implemented") + .contains("Casting decimal with scale 2 to scale 1 not yet implemented") ); } diff --git a/vortex-array/src/expr/transform/coerce.rs b/vortex-array/src/expr/transform/coerce.rs index 1b6e9acd661..39d4a29be89 100644 --- a/vortex-array/src/expr/transform/coerce.rs +++ b/vortex-array/src/expr/transform/coerce.rs @@ -70,6 +70,7 @@ mod tests { use vortex_error::VortexResult; use crate::dtype::DType; + use crate::dtype::DecimalDType; use crate::dtype::Nullability::NonNullable; use crate::dtype::PType; use crate::dtype::StructFields; @@ -153,6 +154,32 @@ mod tests { Ok(()) } + #[test] + fn mixed_decimal_arithmetic_preserves_input_types() -> VortexResult<()> { + let lhs = DecimalDType::new(10, 2); + let rhs = DecimalDType::new(5, 1); + let scope = DType::Struct( + StructFields::new( + ["a", "b"].into(), + vec![ + DType::Decimal(lhs, NonNullable), + DType::Decimal(rhs, NonNullable), + ], + ), + NonNullable, + ); + let expr = Binary.new_expr(Operator::Add, [col("a"), col("b")]); + let coerced = coerce_expression(expr, &scope)?; + + assert!(!coerced.child(0).is::()); + assert!(!coerced.child(1).is::()); + assert_eq!( + coerced.return_dtype(&scope)?, + DType::Decimal(DecimalDType::new(11, 2), NonNullable) + ); + Ok(()) + } + #[test] fn boolean_operators_no_coercion() -> VortexResult<()> { let scope = DType::Struct( diff --git a/vortex-array/src/scalar_fn/fns/binary/mod.rs b/vortex-array/src/scalar_fn/fns/binary/mod.rs index 1c860cb75b5..0f66ea61da7 100644 --- a/vortex-array/src/scalar_fn/fns/binary/mod.rs +++ b/vortex-array/src/scalar_fn/fns/binary/mod.rs @@ -10,6 +10,7 @@ pub use boolean::or_kleene; use prost::Message; use vortex_error::VortexResult; use vortex_error::vortex_bail; +use vortex_error::vortex_err; use vortex_proto::expr as pb; use vortex_session::VortexSession; use vortex_session::registry::CachedId; @@ -17,6 +18,9 @@ use vortex_session::registry::CachedId; use crate::ArrayRef; use crate::ExecutionCtx; use crate::dtype::DType; +use crate::dtype::DecimalDType; +use crate::dtype::NativeDecimalType; +use crate::dtype::i256; use crate::expr::StatsCatalog; use crate::expr::and; use crate::expr::and_collect; @@ -46,6 +50,64 @@ pub(crate) use numeric::*; use crate::scalar::NumericOperator; +/// Output decimal type of an arithmetic `operator` over two decimal operands. +/// +/// Mirrors the Hive-style rules `arrow-arith` applies at execution time +/// (see `arrow_arith::numeric::decimal_op`), including precision saturation +/// at the physical width's maximum: vortex lowers precisions +/// `<= i128::MAX_PRECISION` to Arrow `Decimal128` and wider decimals to +/// `Decimal256`. +fn decimal_arithmetic_dtype( + operator: Operator, + lhs: DecimalDType, + rhs: DecimalDType, +) -> VortexResult { + let p1 = i16::from(lhs.precision()); + let s1 = i16::from(lhs.scale()); + let p2 = i16::from(rhs.precision()); + let s2 = i16::from(rhs.scale()); + let (max_precision, max_scale) = + if lhs.precision() <= i128::MAX_PRECISION && rhs.precision() <= i128::MAX_PRECISION { + (i16::from(i128::MAX_PRECISION), i16::from(i128::MAX_SCALE)) + } else { + (i16::from(i256::MAX_PRECISION), i16::from(i256::MAX_SCALE)) + }; + let (precision, scale) = match operator { + // scale = max(s1, s2); precision = scale + max(p1 - s1, p2 - s2) + 1 + Operator::Add | Operator::Sub => { + let scale = s1.max(s2); + ( + (scale + (p1 - s1).max(p2 - s2) + 1).min(max_precision), + scale, + ) + } + // scale = s1 + s2; precision = p1 + p2 + 1 + Operator::Mul => { + let scale = s1 + s2; + if scale > max_scale { + vortex_bail!( + "output scale of {lhs} {operator} {rhs} exceeds the maximum scale \ + {max_scale}" + ); + } + ((p1 + p2 + 1).min(max_precision), scale) + } + // scale = min(s1 + 4, max); precision = p1 - s1 + s2 + scale + Operator::Div => { + let scale = (s1 + 4).min(max_scale); + let mul_pow = scale - s1 + s2; + ((p1 + mul_pow).clamp(1, max_precision), scale) + } + _ => vortex_bail!("operator {operator} is not arithmetic"), + }; + let precision = u8::try_from(precision) + .map_err(|_| vortex_err!("decimal arithmetic precision exceeds supported range"))?; + let scale = i8::try_from(scale) + .map_err(|_| vortex_err!("decimal arithmetic scale exceeds supported range"))?; + + DecimalDType::try_new(precision, scale) +} + #[derive(Clone)] pub struct Binary; @@ -103,6 +165,11 @@ impl ScalarFnVTable for Binary { fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult> { let lhs = &args[0]; let rhs = &args[1]; + if operator.is_arithmetic() + && matches!((lhs, rhs), (DType::Decimal(..), DType::Decimal(..))) + { + return Ok(args.to_vec()); + } if operator.is_arithmetic() || operator.is_comparison() { let supertype = lhs.least_supertype(rhs).ok_or_else(|| { vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs) @@ -122,6 +189,13 @@ impl ScalarFnVTable for Binary { if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); } + if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) { + let result = decimal_arithmetic_dtype(*operator, *l, *r)?; + return Ok(DType::Decimal( + result, + lhs.nullability() | rhs.nullability(), + )); + } vortex_bail!( "incompatible types for arithmetic operation: {} {}", lhs, @@ -332,6 +406,58 @@ mod tests { use crate::expr::or_collect; use crate::expr::test_harness; use crate::scalar::Scalar; + + /// The decimal arithmetic dtypes derived at plan time must match what + /// arrow produces at execution time (see `decimal_arithmetic_dtype`). + #[test] + fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> { + use vortex_buffer::buffer; + + use crate::Canonical; + use crate::IntoArray; + use crate::arrays::DecimalArray; + use crate::dtype::DecimalDType; + use crate::scalar::DecimalValue; + use crate::scalar_fn::ScalarFnVTableExt; + use crate::validity::Validity; + + let lhs_dec = DecimalDType::new(10, 2); + let rhs_dec = DecimalDType::new(5, 1); + let values = DecimalArray::new(buffer![100i128, 250, 1099], lhs_dec, Validity::NonNullable) + .into_array(); + let rhs = lit(Scalar::decimal( + DecimalValue::I128(50), + rhs_dec, + Nullability::NonNullable, + )); + for (op, expected) in [ + (Operator::Add, DecimalDType::new(11, 2)), + (Operator::Sub, DecimalDType::new(11, 2)), + (Operator::Mul, DecimalDType::new(16, 3)), + (Operator::Div, DecimalDType::new(15, 6)), + ] { + let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; + let derived = expr.return_dtype(values.dtype())?; + assert_eq!( + derived, + DType::Decimal(expected, Nullability::NonNullable), + "unexpected derived dtype for {op}" + ); + let mut ctx = LEGACY_SESSION.create_execution_ctx(); + let executed = values + .clone() + .apply(&expr)? + .execute::(&mut ctx)? + .into_array(); + assert_eq!( + executed.dtype(), + &derived, + "derived dtype diverges from execution for {op}" + ); + } + Ok(()) + } + #[test] fn and_collect_balanced() { let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)]; diff --git a/vortex-array/src/scalar_fn/fns/binary/numeric.rs b/vortex-array/src/scalar_fn/fns/binary/numeric.rs index 2ef58fb5209..a1c6d4cfd05 100644 --- a/vortex-array/src/scalar_fn/fns/binary/numeric.rs +++ b/vortex-array/src/scalar_fn/fns/binary/numeric.rs @@ -1,6 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright the Vortex contributors +use arrow_schema::DataType; use vortex_error::VortexResult; use crate::ArrayRef; @@ -9,6 +10,8 @@ use crate::arrays::Constant; use crate::arrays::ConstantArray; use crate::arrow::Datum; use crate::arrow::from_arrow_columnar; +use crate::dtype::DType; +use crate::dtype::NativeDecimalType; use crate::executor::ExecutionCtx; use crate::scalar::NumericOperator; @@ -38,8 +41,18 @@ pub(crate) fn arrow_numeric( let nullable = lhs.dtype().is_nullable() || rhs.dtype().is_nullable(); let len = lhs.len(); - let left = Datum::try_new(lhs, ctx)?; - let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + let (left, right) = if let Some((lhs_data_type, rhs_data_type)) = + decimal_arrow_data_types(lhs.dtype(), rhs.dtype()) + { + ( + Datum::try_new_with_target_datatype(lhs, &lhs_data_type, ctx)?, + Datum::try_new_with_target_datatype(rhs, &rhs_data_type, ctx)?, + ) + } else { + let left = Datum::try_new(lhs, ctx)?; + let right = Datum::try_new_with_target_datatype(rhs, left.data_type(), ctx)?; + (left, right) + }; let array = match operator { NumericOperator::Add => arrow_arith::numeric::add(&left, &right)?, @@ -51,6 +64,26 @@ pub(crate) fn arrow_numeric( from_arrow_columnar(array.as_ref(), len, nullable, ctx) } +fn decimal_arrow_data_types(lhs: &DType, rhs: &DType) -> Option<(DataType, DataType)> { + let (DType::Decimal(lhs_decimal, _), DType::Decimal(rhs_decimal, _)) = (lhs, rhs) else { + return None; + }; + + let use_decimal256 = lhs_decimal.precision() > i128::MAX_PRECISION + || rhs_decimal.precision() > i128::MAX_PRECISION; + if use_decimal256 { + Some(( + DataType::Decimal256(lhs_decimal.precision(), lhs_decimal.scale()), + DataType::Decimal256(rhs_decimal.precision(), rhs_decimal.scale()), + )) + } else { + Some(( + DataType::Decimal128(lhs_decimal.precision(), lhs_decimal.scale()), + DataType::Decimal128(rhs_decimal.precision(), rhs_decimal.scale()), + )) + } +} + fn constant_numeric( lhs: &ArrayRef, rhs: &ArrayRef,