-
Notifications
You must be signed in to change notification settings - Fork 169
Decimal arithmetic dtypes and widening rescale cast #8343
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: develop
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,7 @@ | ||
| // SPDX-License-Identifier: Apache-2.0 | ||
| // SPDX-FileCopyrightText: Copyright the Vortex contributors | ||
|
|
||
| use num_traits::CheckedMul; | ||
| use vortex_buffer::Buffer; | ||
| use vortex_error::VortexExpect; | ||
| use vortex_error::VortexResult; | ||
|
|
@@ -77,17 +78,21 @@ impl CastKernel for Decimal { | |
| ); | ||
| }; | ||
|
|
||
| // Scale changes are not yet supported | ||
| if from_decimal_dtype.scale() != to_decimal_dtype.scale() { | ||
| // Narrowing the scale (dropping fractional digits) is not supported. | ||
| if from_decimal_dtype.scale() > to_decimal_dtype.scale() { | ||
| vortex_bail!( | ||
| "Casting decimal with scale {} to scale {} not yet implemented", | ||
| from_decimal_dtype.scale(), | ||
| to_decimal_dtype.scale() | ||
| ); | ||
| } | ||
|
|
||
| // Downcasting precision is not yet supported | ||
| if to_decimal_dtype.precision() < from_decimal_dtype.precision() { | ||
| // The target must retain at least the source's integer digits. | ||
| let from_integer_digits = | ||
| i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale()); | ||
| let to_integer_digits = | ||
| i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale()); | ||
| if to_integer_digits < from_integer_digits { | ||
| vortex_bail!( | ||
| "Downcasting decimal from precision {} to {} not yet implemented", | ||
| from_decimal_dtype.precision(), | ||
|
|
@@ -105,6 +110,12 @@ impl CastKernel for Decimal { | |
| .validity()? | ||
| .cast_nullability(*to_nullability, array.len(), ctx)?; | ||
|
|
||
| // Widening the scale multiplies unscaled values by a power of ten. | ||
| if from_decimal_dtype.scale() < to_decimal_dtype.scale() { | ||
| let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?; | ||
| return Ok(Some(rescaled.into_array())); | ||
| } | ||
|
|
||
| // If the target needs a wider physical type, upcast the values | ||
| let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype); | ||
| let array = if target_values_type > array.values_type() { | ||
|
|
@@ -128,6 +139,59 @@ impl CastKernel for Decimal { | |
| } | ||
| } | ||
|
|
||
| /// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`), | ||
| /// multiplying unscaled values by the corresponding power of ten. The | ||
| /// result is stored at the width the target precision requires. | ||
| fn rescale_decimal_values( | ||
| array: ArrayView<'_, Decimal>, | ||
| to: crate::dtype::DecimalDType, | ||
| validity: crate::validity::Validity, | ||
| ) -> VortexResult<DecimalArray> { | ||
| let from = array.decimal_dtype(); | ||
| let scale_up = u32::try_from(to.scale() - from.scale()) | ||
| .map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?; | ||
| let factor = 10i128 | ||
| .checked_pow(scale_up) | ||
| .ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?; | ||
|
|
||
| let from_values_type = array.values_type(); | ||
| if from_values_type == DecimalType::I256 { | ||
| vortex_bail!("rescaling i256 decimals is not supported"); | ||
| } | ||
|
|
||
| let to_values_type = DecimalType::smallest_decimal_value_type(&to); | ||
| if to_values_type == DecimalType::I256 { | ||
| vortex_bail!("rescaling into i256 decimals is not supported"); | ||
| } | ||
|
Comment on lines
+156
to
+165
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why? given appropriate precision/scale it should be possible. |
||
|
|
||
| match_each_decimal_value_type!(from_values_type, |F| { | ||
| let from_buffer = array.buffer::<F>(); | ||
| match_each_decimal_value_type!(to_values_type, |T| { | ||
| let to_buffer = rescale_decimal_buffer::<F, T>(from_buffer, factor)?; | ||
| Ok(DecimalArray::new(to_buffer, to, validity)) | ||
| }) | ||
| }) | ||
| } | ||
|
|
||
| fn rescale_decimal_buffer<F, T>(from: Buffer<F>, factor: i128) -> VortexResult<Buffer<T>> | ||
| where | ||
| F: NativeDecimalType, | ||
| T: NativeDecimalType + CheckedMul, | ||
| { | ||
| let factor = <T as crate::dtype::BigCast>::from(factor) | ||
| .ok_or_else(|| vortex_error::vortex_err!("decimal rescale factor exceeds target width"))?; | ||
|
|
||
| from.iter() | ||
| .map(|&v| { | ||
| let v = <T as crate::dtype::BigCast>::from(v).ok_or_else(|| { | ||
| vortex_error::vortex_err!("decimal rescale input exceeds target width") | ||
| })?; | ||
| CheckedMul::checked_mul(&v, &factor) | ||
| .ok_or_else(|| vortex_error::vortex_err!("decimal rescale overflows target width")) | ||
| }) | ||
| .collect() | ||
| } | ||
|
|
||
| /// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping | ||
| /// the same precision and scale. | ||
| /// | ||
|
|
@@ -262,27 +326,64 @@ mod tests { | |
| } | ||
|
|
||
| #[test] | ||
| fn cast_different_scale_fails() { | ||
| fn cast_widening_scale_rescales() { | ||
| let array = DecimalArray::new( | ||
| buffer![100i32, -250], | ||
| DecimalDType::new(10, 2), | ||
| Validity::NonNullable, | ||
| ); | ||
|
|
||
| // 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3. | ||
| let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); | ||
| #[expect(deprecated)] | ||
| let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal(); | ||
| assert_eq!(casted.dtype(), &wider); | ||
| assert_eq!(casted.buffer::<i64>().as_ref(), &[1000i64, -2500]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn cast_widening_scale_uses_target_width() { | ||
| let array = DecimalArray::new( | ||
| buffer![9i8, -8], | ||
| DecimalDType::new(1, 0), | ||
| Validity::NonNullable, | ||
| ); | ||
|
|
||
| let wider_scale = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable); | ||
| #[expect(deprecated)] | ||
| let casted = array | ||
| .into_array() | ||
| .cast(wider_scale.clone()) | ||
| .unwrap() | ||
| .to_decimal(); | ||
|
|
||
| assert_eq!(casted.dtype(), &wider_scale); | ||
| assert_eq!(casted.values_type(), DecimalType::I8); | ||
| assert_eq!(casted.buffer::<i8>().as_ref(), &[90i8, -80]); | ||
| } | ||
|
|
||
| #[test] | ||
| fn cast_narrowing_scale_fails() { | ||
| let array = DecimalArray::new( | ||
| buffer![100i32], | ||
| DecimalDType::new(10, 2), | ||
| Validity::NonNullable, | ||
| ); | ||
|
|
||
| // Try to cast to different scale - not supported | ||
| let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable); | ||
| // Dropping fractional digits is not supported. | ||
| let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable); | ||
| #[expect(deprecated)] | ||
| let result = array | ||
| .into_array() | ||
| .cast(different_dtype) | ||
| .cast(narrower) | ||
| .and_then(|a| a.to_canonical().map(|c| c.into_array())); | ||
|
|
||
| assert!(result.is_err()); | ||
| assert!( | ||
| result | ||
| .unwrap_err() | ||
| .to_string() | ||
| .contains("Casting decimal with scale 2 to scale 3 not yet implemented") | ||
| .contains("Casting decimal with scale 2 to scale 1 not yet implemented") | ||
| ); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,13 +10,17 @@ pub use boolean::or_kleene; | |
| use prost::Message; | ||
| use vortex_error::VortexResult; | ||
| use vortex_error::vortex_bail; | ||
| use vortex_error::vortex_err; | ||
| use vortex_proto::expr as pb; | ||
| use vortex_session::VortexSession; | ||
| use vortex_session::registry::CachedId; | ||
|
|
||
| use crate::ArrayRef; | ||
| use crate::ExecutionCtx; | ||
| use crate::dtype::DType; | ||
| use crate::dtype::DecimalDType; | ||
| use crate::dtype::NativeDecimalType; | ||
| use crate::dtype::i256; | ||
| use crate::expr::StatsCatalog; | ||
| use crate::expr::and; | ||
| use crate::expr::and_collect; | ||
|
|
@@ -46,6 +50,64 @@ pub(crate) use numeric::*; | |
|
|
||
| use crate::scalar::NumericOperator; | ||
|
|
||
| /// Output decimal type of an arithmetic `operator` over two decimal operands. | ||
| /// | ||
| /// Mirrors the Hive-style rules `arrow-arith` applies at execution time | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we have widening system, I think its worth spelling it out in the docs or including a permalink or something to a reference |
||
| /// (see `arrow_arith::numeric::decimal_op`), including precision saturation | ||
| /// at the physical width's maximum: vortex lowers precisions | ||
| /// `<= i128::MAX_PRECISION` to Arrow `Decimal128` and wider decimals to | ||
| /// `Decimal256`. | ||
| fn decimal_arithmetic_dtype( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not use the narrower types? |
||
| operator: Operator, | ||
| lhs: DecimalDType, | ||
| rhs: DecimalDType, | ||
| ) -> VortexResult<DecimalDType> { | ||
| let p1 = i16::from(lhs.precision()); | ||
| let s1 = i16::from(lhs.scale()); | ||
| let p2 = i16::from(rhs.precision()); | ||
| let s2 = i16::from(rhs.scale()); | ||
| let (max_precision, max_scale) = | ||
| if lhs.precision() <= i128::MAX_PRECISION && rhs.precision() <= i128::MAX_PRECISION { | ||
| (i16::from(i128::MAX_PRECISION), i16::from(i128::MAX_SCALE)) | ||
| } else { | ||
| (i16::from(i256::MAX_PRECISION), i16::from(i256::MAX_SCALE)) | ||
| }; | ||
| let (precision, scale) = match operator { | ||
| // scale = max(s1, s2); precision = scale + max(p1 - s1, p2 - s2) + 1 | ||
| Operator::Add | Operator::Sub => { | ||
| let scale = s1.max(s2); | ||
| ( | ||
| (scale + (p1 - s1).max(p2 - s2) + 1).min(max_precision), | ||
| scale, | ||
| ) | ||
| } | ||
| // scale = s1 + s2; precision = p1 + p2 + 1 | ||
| Operator::Mul => { | ||
| let scale = s1 + s2; | ||
| if scale > max_scale { | ||
| vortex_bail!( | ||
| "output scale of {lhs} {operator} {rhs} exceeds the maximum scale \ | ||
| {max_scale}" | ||
| ); | ||
| } | ||
| ((p1 + p2 + 1).min(max_precision), scale) | ||
| } | ||
| // scale = min(s1 + 4, max); precision = p1 - s1 + s2 + scale | ||
| Operator::Div => { | ||
| let scale = (s1 + 4).min(max_scale); | ||
| let mul_pow = scale - s1 + s2; | ||
| ((p1 + mul_pow).clamp(1, max_precision), scale) | ||
| } | ||
| _ => vortex_bail!("operator {operator} is not arithmetic"), | ||
| }; | ||
| let precision = u8::try_from(precision) | ||
| .map_err(|_| vortex_err!("decimal arithmetic precision exceeds supported range"))?; | ||
| let scale = i8::try_from(scale) | ||
| .map_err(|_| vortex_err!("decimal arithmetic scale exceeds supported range"))?; | ||
|
|
||
| DecimalDType::try_new(precision, scale) | ||
| } | ||
|
|
||
| #[derive(Clone)] | ||
| pub struct Binary; | ||
|
|
||
|
|
@@ -103,6 +165,11 @@ impl ScalarFnVTable for Binary { | |
| fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> { | ||
| let lhs = &args[0]; | ||
| let rhs = &args[1]; | ||
| if operator.is_arithmetic() | ||
| && matches!((lhs, rhs), (DType::Decimal(..), DType::Decimal(..))) | ||
| { | ||
| return Ok(args.to_vec()); | ||
| } | ||
| if operator.is_arithmetic() || operator.is_comparison() { | ||
| let supertype = lhs.least_supertype(rhs).ok_or_else(|| { | ||
| vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs) | ||
|
|
@@ -122,6 +189,13 @@ impl ScalarFnVTable for Binary { | |
| if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) { | ||
| return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability())); | ||
| } | ||
| if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) { | ||
| let result = decimal_arithmetic_dtype(*operator, *l, *r)?; | ||
| return Ok(DType::Decimal( | ||
| result, | ||
| lhs.nullability() | rhs.nullability(), | ||
| )); | ||
| } | ||
| vortex_bail!( | ||
| "incompatible types for arithmetic operation: {} {}", | ||
| lhs, | ||
|
|
@@ -332,6 +406,58 @@ mod tests { | |
| use crate::expr::or_collect; | ||
| use crate::expr::test_harness; | ||
| use crate::scalar::Scalar; | ||
|
|
||
| /// The decimal arithmetic dtypes derived at plan time must match what | ||
| /// arrow produces at execution time (see `decimal_arithmetic_dtype`). | ||
| #[test] | ||
| fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> { | ||
| use vortex_buffer::buffer; | ||
|
|
||
| use crate::Canonical; | ||
| use crate::IntoArray; | ||
| use crate::arrays::DecimalArray; | ||
| use crate::dtype::DecimalDType; | ||
| use crate::scalar::DecimalValue; | ||
| use crate::scalar_fn::ScalarFnVTableExt; | ||
| use crate::validity::Validity; | ||
|
|
||
| let lhs_dec = DecimalDType::new(10, 2); | ||
| let rhs_dec = DecimalDType::new(5, 1); | ||
| let values = DecimalArray::new(buffer![100i128, 250, 1099], lhs_dec, Validity::NonNullable) | ||
| .into_array(); | ||
| let rhs = lit(Scalar::decimal( | ||
| DecimalValue::I128(50), | ||
| rhs_dec, | ||
| Nullability::NonNullable, | ||
| )); | ||
| for (op, expected) in [ | ||
| (Operator::Add, DecimalDType::new(11, 2)), | ||
| (Operator::Sub, DecimalDType::new(11, 2)), | ||
| (Operator::Mul, DecimalDType::new(16, 3)), | ||
| (Operator::Div, DecimalDType::new(15, 6)), | ||
| ] { | ||
| let expr = Binary.try_new_expr(op, [crate::expr::root(), rhs.clone()])?; | ||
| let derived = expr.return_dtype(values.dtype())?; | ||
| assert_eq!( | ||
| derived, | ||
| DType::Decimal(expected, Nullability::NonNullable), | ||
| "unexpected derived dtype for {op}" | ||
| ); | ||
| let mut ctx = LEGACY_SESSION.create_execution_ctx(); | ||
| let executed = values | ||
| .clone() | ||
| .apply(&expr)? | ||
| .execute::<Canonical>(&mut ctx)? | ||
| .into_array(); | ||
| assert_eq!( | ||
| executed.dtype(), | ||
| &derived, | ||
| "derived dtype diverges from execution for {op}" | ||
| ); | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn and_collect_balanced() { | ||
| let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)]; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Its just a TODO right? Worth opening a ticket or something