diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py
index f3cfabf1a5..5a584d116c 100644
--- a/pyiceberg/expressions/visitors.py
+++ b/pyiceberg/expressions/visitors.py
@@ -509,7 +509,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
 
     def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         eval_res = term.eval(self.struct)
-        return eval_res is not None and str(eval_res).startswith(str(literal.value))
+        return eval_res is not None and eval_res.startswith(literal.value)
 
     def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         return not self.visit_starts_with(term, literal)
@@ -712,7 +712,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
     def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         pos = term.ref().accessor.position
         field = self.partition_fields[pos]
-        prefix = str(literal.value)
+        prefix = literal.value
         len_prefix = len(prefix)
 
         if field.lower_bound is None:
@@ -736,7 +736,7 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
     def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         pos = term.ref().accessor.position
         field = self.partition_fields[pos]
-        prefix = str(literal.value)
+        prefix = literal.value
         len_prefix = len(prefix)
 
         if field.contains_null or field.lower_bound is None or field.upper_bound is None:
@@ -1408,12 +1408,12 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         if not isinstance(field.field_type, PrimitiveType):
             raise ValueError(f"Expected PrimitiveType: {field.field_type}")
 
-        prefix = str(literal.value)
+        prefix = literal.value
         len_prefix = len(prefix)
 
         lower_bound_bytes = self.lower_bounds.get(field_id)
         if lower_bound_bytes is not None:
-            lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
+            lower_bound = from_bytes(field.field_type, lower_bound_bytes)
 
             # truncate lower bound so that its length is not greater than the length of prefix
             if lower_bound and lower_bound[:len_prefix] > prefix:
@@ -1421,7 +1421,7 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
 
         upper_bound_bytes = self.upper_bounds.get(field_id)
         if upper_bound_bytes is not None:
-            upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
+            upper_bound = from_bytes(field.field_type, upper_bound_bytes)
 
             # truncate upper bound so that its length is not greater than the length of prefix
             if upper_bound is not None and upper_bound[:len_prefix] < prefix:
@@ -1439,7 +1439,7 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         if not isinstance(field.field_type, PrimitiveType):
             raise ValueError(f"Expected PrimitiveType: {field.field_type}")
 
-        prefix = str(literal.value)
+        prefix = literal.value
         len_prefix = len(prefix)
 
         # not_starts_with will match unless all values must start with the prefix. This happens when
@@ -1447,8 +1447,8 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
         lower_bound_bytes = self.lower_bounds.get(field_id)
         upper_bound_bytes = self.upper_bounds.get(field_id)
         if lower_bound_bytes is not None and upper_bound_bytes is not None:
-            lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
-            upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
+            lower_bound = from_bytes(field.field_type, lower_bound_bytes)
+            upper_bound = from_bytes(field.field_type, upper_bound_bytes)
 
             # if lower is shorter than the prefix then lower doesn't start with the prefix
             if len(lower_bound) < len_prefix:
@@ -1899,7 +1899,7 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression:
 
     def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> BooleanExpression:
         eval_res = term.eval(self.struct)
-        if eval_res is not None and str(eval_res).startswith(str(literal.value)):
+        if eval_res is not None and eval_res.startswith(literal.value):
             return AlwaysTrue()
         else:
             return AlwaysFalse()
diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py
index 7687d2e5a0..a0634d683c 100644
--- a/tests/expressions/test_visitors.py
+++ b/tests/expressions/test_visitors.py
@@ -80,6 +80,7 @@
 from pyiceberg.schema import Accessor, Schema
 from pyiceberg.typedef import Record
 from pyiceberg.types import (
+    BinaryType,
     BooleanType,
     DoubleType,
     FloatType,
@@ -1629,6 +1630,24 @@ def test_expression_evaluator_null() -> None:
     assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True
 
 
+def test_expression_evaluator_binary_starts_with() -> None:
+    schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
+    struct = Record(b"aa")
+    assert expression_evaluator(schema, StartsWith("x", b"a"), case_sensitive=True)(struct) is True
+    assert expression_evaluator(schema, StartsWith("x", b"aa"), case_sensitive=True)(struct) is True
+    assert expression_evaluator(schema, StartsWith("x", b"aaa"), case_sensitive=True)(struct) is False
+    assert expression_evaluator(schema, StartsWith("x", b"b"), case_sensitive=True)(struct) is False
+
+
+def test_expression_evaluator_binary_not_starts_with() -> None:
+    schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
+    struct = Record(b"aa")
+    assert expression_evaluator(schema, NotStartsWith("x", b"a"), case_sensitive=True)(struct) is False
+    assert expression_evaluator(schema, NotStartsWith("x", b"aa"), case_sensitive=True)(struct) is False
+    assert expression_evaluator(schema, NotStartsWith("x", b"aaa"), case_sensitive=True)(struct) is True
+    assert expression_evaluator(schema, NotStartsWith("x", b"b"), case_sensitive=True)(struct) is True
+
+
 def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
     """Test translate_column_names with matching column names."""
     # Create a bound expression using the original schema