Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 110 additions & 9 deletions vortex-array/src/arrays/decimal/compute/cast.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::CheckedMul;
use vortex_buffer::Buffer;
use vortex_error::VortexExpect;
use vortex_error::VortexResult;
Expand Down Expand Up @@ -77,17 +78,21 @@ impl CastKernel for Decimal {
);
};

// Scale changes are not yet supported
if from_decimal_dtype.scale() != to_decimal_dtype.scale() {
// Narrowing the scale (dropping fractional digits) is not supported.
if from_decimal_dtype.scale() > to_decimal_dtype.scale() {
vortex_bail!(
"Casting decimal with scale {} to scale {} not yet implemented",
from_decimal_dtype.scale(),
to_decimal_dtype.scale()
);
}

// Downcasting precision is not yet supported
if to_decimal_dtype.precision() < from_decimal_dtype.precision() {
// The target must retain at least the source's integer digits.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's just a TODO, right? Worth opening a ticket or something.

let from_integer_digits =
i16::from(from_decimal_dtype.precision()) - i16::from(from_decimal_dtype.scale());
let to_integer_digits =
i16::from(to_decimal_dtype.precision()) - i16::from(to_decimal_dtype.scale());
if to_integer_digits < from_integer_digits {
vortex_bail!(
"Downcasting decimal from precision {} to {} not yet implemented",
from_decimal_dtype.precision(),
Expand All @@ -105,6 +110,12 @@ impl CastKernel for Decimal {
.validity()?
.cast_nullability(*to_nullability, array.len(), ctx)?;

// Widening the scale multiplies unscaled values by a power of ten.
if from_decimal_dtype.scale() < to_decimal_dtype.scale() {
let rescaled = rescale_decimal_values(array, *to_decimal_dtype, new_validity)?;
return Ok(Some(rescaled.into_array()));
}

// If the target needs a wider physical type, upcast the values
let target_values_type = DecimalType::smallest_decimal_value_type(to_decimal_dtype);
let array = if target_values_type > array.values_type() {
Expand All @@ -128,6 +139,59 @@ impl CastKernel for Decimal {
}
}

/// Rescale a DecimalArray to a wider scale (e.g. `(16,2)` → `(31,4)`),
/// multiplying unscaled values by the corresponding power of ten. The
/// result is stored at the width the target precision requires.
fn rescale_decimal_values(
array: ArrayView<'_, Decimal>,
to: crate::dtype::DecimalDType,
validity: crate::validity::Validity,
) -> VortexResult<DecimalArray> {
let from = array.decimal_dtype();
let scale_up = u32::try_from(to.scale() - from.scale())
.map_err(|_| vortex_error::vortex_err!("rescale requires a widening scale"))?;
let factor = 10i128
.checked_pow(scale_up)
.ok_or_else(|| vortex_error::vortex_err!("rescale factor overflows i128"))?;

let from_values_type = array.values_type();
if from_values_type == DecimalType::I256 {
vortex_bail!("rescaling i256 decimals is not supported");
}

let to_values_type = DecimalType::smallest_decimal_value_type(&to);
if to_values_type == DecimalType::I256 {
vortex_bail!("rescaling into i256 decimals is not supported");
}
Comment on lines +156 to +165

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Given an appropriate precision/scale, it should be possible.


match_each_decimal_value_type!(from_values_type, |F| {
let from_buffer = array.buffer::<F>();
match_each_decimal_value_type!(to_values_type, |T| {
let to_buffer = rescale_decimal_buffer::<F, T>(from_buffer, factor)?;
Ok(DecimalArray::new(to_buffer, to, validity))
})
})
}

/// Multiplies every unscaled value in `from` by `factor`, producing a buffer of the target
/// native type `T`.
///
/// `factor` and each source value are first converted into `T`; the multiplication itself is
/// checked.
///
/// # Errors
///
/// Returns an error if `factor` or any source value cannot be represented as `T`, or if a
/// product overflows `T`.
// NOTE(review): every element is rescaled and no validity is passed in, so values sitting in
// null slots are converted and multiplied too — confirm they cannot trigger a spurious error.
fn rescale_decimal_buffer<F, T>(from: Buffer<F>, factor: i128) -> VortexResult<Buffer<T>>
where
    F: NativeDecimalType,
    T: NativeDecimalType + CheckedMul,
{
    // Convert the power-of-ten multiplier once, up front, rather than per element.
    let multiplier = <T as crate::dtype::BigCast>::from(factor)
        .ok_or_else(|| vortex_error::vortex_err!("decimal rescale factor exceeds target width"))?;

    from.iter()
        .copied()
        .map(|unscaled| {
            <T as crate::dtype::BigCast>::from(unscaled)
                .ok_or_else(|| {
                    vortex_error::vortex_err!("decimal rescale input exceeds target width")
                })
                .and_then(|widened| {
                    CheckedMul::checked_mul(&widened, &multiplier).ok_or_else(|| {
                        vortex_error::vortex_err!("decimal rescale overflows target width")
                    })
                })
        })
        .collect()
}

/// Upcast a DecimalArray to a wider physical representation (e.g., i32 -> i64) while keeping
/// the same precision and scale.
///
Expand Down Expand Up @@ -262,27 +326,64 @@ mod tests {
}

#[test]
fn cast_different_scale_fails() {
fn cast_widening_scale_rescales() {
let array = DecimalArray::new(
buffer![100i32, -250],
DecimalDType::new(10, 2),
Validity::NonNullable,
);

// 1.00 and -2.50 at scale 2 become 1.000 and -2.500 at scale 3.
let wider = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
#[expect(deprecated)]
let casted = array.into_array().cast(wider.clone()).unwrap().to_decimal();
assert_eq!(casted.dtype(), &wider);
assert_eq!(casted.buffer::<i64>().as_ref(), &[1000i64, -2500]);
}

/// Widening `decimal(1, 0)` to `decimal(2, 1)` multiplies the unscaled values by ten and the
/// result is still stored as `i8`.
#[test]
fn cast_widening_scale_uses_target_width() {
    let target = DType::Decimal(DecimalDType::new(2, 1), Nullability::NonNullable);
    let source = DecimalArray::new(
        buffer![9i8, -8],
        DecimalDType::new(1, 0),
        Validity::NonNullable,
    );

    #[expect(deprecated)]
    let rescaled = source
        .into_array()
        .cast(target.clone())
        .unwrap()
        .to_decimal();

    assert_eq!(rescaled.dtype(), &target);
    assert_eq!(rescaled.values_type(), DecimalType::I8);
    // 9 and -8 at scale 0 are 9.0 and -8.0, i.e. 90 and -80 unscaled at scale 1.
    assert_eq!(rescaled.buffer::<i8>().as_ref(), &[90i8, -80]);
}

#[test]
fn cast_narrowing_scale_fails() {
let array = DecimalArray::new(
buffer![100i32],
DecimalDType::new(10, 2),
Validity::NonNullable,
);

// Try to cast to different scale - not supported
let different_dtype = DType::Decimal(DecimalDType::new(15, 3), Nullability::NonNullable);
// Dropping fractional digits is not supported.
let narrower = DType::Decimal(DecimalDType::new(15, 1), Nullability::NonNullable);
#[expect(deprecated)]
let result = array
.into_array()
.cast(different_dtype)
.cast(narrower)
.and_then(|a| a.to_canonical().map(|c| c.into_array()));

assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Casting decimal with scale 2 to scale 3 not yet implemented")
.contains("Casting decimal with scale 2 to scale 1 not yet implemented")
);
}

Expand Down
27 changes: 27 additions & 0 deletions vortex-array/src/expr/transform/coerce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ mod tests {
use vortex_error::VortexResult;

use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::dtype::Nullability::NonNullable;
use crate::dtype::PType;
use crate::dtype::StructFields;
Expand Down Expand Up @@ -153,6 +154,32 @@ mod tests {
Ok(())
}

/// Adding two decimal columns with different precision and scale must leave both operands
/// uncast and report `decimal(11, 2)` as the result type.
#[test]
fn mixed_decimal_arithmetic_preserves_input_types() -> VortexResult<()> {
    let field_dtypes = vec![
        DType::Decimal(DecimalDType::new(10, 2), NonNullable),
        DType::Decimal(DecimalDType::new(5, 1), NonNullable),
    ];
    let scope = DType::Struct(
        StructFields::new(["a", "b"].into(), field_dtypes),
        NonNullable,
    );

    let sum = Binary.new_expr(Operator::Add, [col("a"), col("b")]);
    let coerced = coerce_expression(sum, &scope)?;

    // Neither operand may have been wrapped in a cast.
    for operand in 0..2 {
        assert!(!coerced.child(operand).is::<Cast>());
    }

    let expected = DType::Decimal(DecimalDType::new(11, 2), NonNullable);
    assert_eq!(coerced.return_dtype(&scope)?, expected);
    Ok(())
}

#[test]
fn boolean_operators_no_coercion() -> VortexResult<()> {
let scope = DType::Struct(
Expand Down
126 changes: 126 additions & 0 deletions vortex-array/src/scalar_fn/fns/binary/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,17 @@ pub use boolean::or_kleene;
use prost::Message;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_error::vortex_err;
use vortex_proto::expr as pb;
use vortex_session::VortexSession;
use vortex_session::registry::CachedId;

use crate::ArrayRef;
use crate::ExecutionCtx;
use crate::dtype::DType;
use crate::dtype::DecimalDType;
use crate::dtype::NativeDecimalType;
use crate::dtype::i256;
use crate::expr::StatsCatalog;
use crate::expr::and;
use crate::expr::and_collect;
Expand Down Expand Up @@ -46,6 +50,64 @@ pub(crate) use numeric::*;

use crate::scalar::NumericOperator;

/// Output decimal type of an arithmetic `operator` over two decimal operands.
///
/// Mirrors the Hive-style rules `arrow-arith` applies at execution time

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have a widening system, I think it's worth spelling it out in the docs or including a permalink or something to a reference.

/// (see `arrow_arith::numeric::decimal_op`), including precision saturation
/// at the physical width's maximum: vortex lowers precisions
/// `<= i128::MAX_PRECISION` to Arrow `Decimal128` and wider decimals to
/// `Decimal256`.
fn decimal_arithmetic_dtype(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use the narrower types?

operator: Operator,
lhs: DecimalDType,
rhs: DecimalDType,
) -> VortexResult<DecimalDType> {
let p1 = i16::from(lhs.precision());
let s1 = i16::from(lhs.scale());
let p2 = i16::from(rhs.precision());
let s2 = i16::from(rhs.scale());
let (max_precision, max_scale) =
if lhs.precision() <= i128::MAX_PRECISION && rhs.precision() <= i128::MAX_PRECISION {
(i16::from(i128::MAX_PRECISION), i16::from(i128::MAX_SCALE))
} else {
(i16::from(i256::MAX_PRECISION), i16::from(i256::MAX_SCALE))
};
let (precision, scale) = match operator {
// scale = max(s1, s2); precision = scale + max(p1 - s1, p2 - s2) + 1
Operator::Add | Operator::Sub => {
let scale = s1.max(s2);
(
(scale + (p1 - s1).max(p2 - s2) + 1).min(max_precision),
scale,
)
}
// scale = s1 + s2; precision = p1 + p2 + 1
Operator::Mul => {
let scale = s1 + s2;
if scale > max_scale {
vortex_bail!(
"output scale of {lhs} {operator} {rhs} exceeds the maximum scale \
{max_scale}"
);
}
((p1 + p2 + 1).min(max_precision), scale)
}
// scale = min(s1 + 4, max); precision = p1 - s1 + s2 + scale
Operator::Div => {
let scale = (s1 + 4).min(max_scale);
let mul_pow = scale - s1 + s2;
((p1 + mul_pow).clamp(1, max_precision), scale)
}
_ => vortex_bail!("operator {operator} is not arithmetic"),
};
let precision = u8::try_from(precision)
.map_err(|_| vortex_err!("decimal arithmetic precision exceeds supported range"))?;
let scale = i8::try_from(scale)
.map_err(|_| vortex_err!("decimal arithmetic scale exceeds supported range"))?;

DecimalDType::try_new(precision, scale)
}

#[derive(Clone)]
pub struct Binary;

Expand Down Expand Up @@ -103,6 +165,11 @@ impl ScalarFnVTable for Binary {
fn coerce_args(&self, operator: &Self::Options, args: &[DType]) -> VortexResult<Vec<DType>> {
let lhs = &args[0];
let rhs = &args[1];
if operator.is_arithmetic()
&& matches!((lhs, rhs), (DType::Decimal(..), DType::Decimal(..)))
{
return Ok(args.to_vec());
}
if operator.is_arithmetic() || operator.is_comparison() {
let supertype = lhs.least_supertype(rhs).ok_or_else(|| {
vortex_error::vortex_err!("No common supertype for {} and {}", lhs, rhs)
Expand All @@ -122,6 +189,13 @@ impl ScalarFnVTable for Binary {
if lhs.is_primitive() && lhs.eq_ignore_nullability(rhs) {
return Ok(lhs.with_nullability(lhs.nullability() | rhs.nullability()));
}
if let (DType::Decimal(l, _), DType::Decimal(r, _)) = (lhs, rhs) {
let result = decimal_arithmetic_dtype(*operator, *l, *r)?;
return Ok(DType::Decimal(
result,
lhs.nullability() | rhs.nullability(),
));
}
vortex_bail!(
"incompatible types for arithmetic operation: {} {}",
lhs,
Expand Down Expand Up @@ -332,6 +406,58 @@ mod tests {
use crate::expr::or_collect;
use crate::expr::test_harness;
use crate::scalar::Scalar;

/// Plan-time decimal result dtypes must agree with the dtype observed after executing the
/// same expression (see `decimal_arithmetic_dtype`).
#[test]
fn decimal_arithmetic_dtype_matches_execution() -> VortexResult<()> {
    use vortex_buffer::buffer;

    use crate::Canonical;
    use crate::IntoArray;
    use crate::arrays::DecimalArray;
    use crate::dtype::DecimalDType;
    use crate::scalar::DecimalValue;
    use crate::scalar_fn::ScalarFnVTableExt;
    use crate::validity::Validity;

    // Left operand: a decimal(10, 2) column. Right operand: a decimal(5, 1) literal.
    let lhs_dec = DecimalDType::new(10, 2);
    let rhs_dec = DecimalDType::new(5, 1);
    let column = DecimalArray::new(buffer![100i128, 250, 1099], lhs_dec, Validity::NonNullable)
        .into_array();
    let literal = lit(Scalar::decimal(
        DecimalValue::I128(50),
        rhs_dec,
        Nullability::NonNullable,
    ));

    // Operator paired with the decimal type the planner is expected to derive.
    let cases = [
        (Operator::Add, DecimalDType::new(11, 2)),
        (Operator::Sub, DecimalDType::new(11, 2)),
        (Operator::Mul, DecimalDType::new(16, 3)),
        (Operator::Div, DecimalDType::new(15, 6)),
    ];

    for (op, expected) in cases {
        let expr = Binary.try_new_expr(op, [crate::expr::root(), literal.clone()])?;

        let planned = expr.return_dtype(column.dtype())?;
        assert_eq!(
            planned,
            DType::Decimal(expected, Nullability::NonNullable),
            "unexpected derived dtype for {op}"
        );

        let mut ctx = LEGACY_SESSION.create_execution_ctx();
        let canonical = column
            .clone()
            .apply(&expr)?
            .execute::<Canonical>(&mut ctx)?;
        let executed = canonical.into_array();
        assert_eq!(
            executed.dtype(),
            &planned,
            "derived dtype diverges from execution for {op}"
        );
    }
    Ok(())
}

#[test]
fn and_collect_balanced() {
let values = vec![lit(1), lit(2), lit(3), lit(4), lit(5)];
Expand Down
Loading
Loading