Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions pyiceberg/expressions/visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo

def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
eval_res = term.eval(self.struct)
return eval_res is not None and str(eval_res).startswith(str(literal.value))
return eval_res is not None and eval_res.startswith(literal.value)

def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
return not self.visit_starts_with(term, literal)
Expand Down Expand Up @@ -712,7 +712,7 @@ def visit_less_than_or_equal(self, term: BoundTerm, literal: LiteralValue) -> bo
def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
pos = term.ref().accessor.position
field = self.partition_fields[pos]
prefix = str(literal.value)
prefix = literal.value
len_prefix = len(prefix)

if field.lower_bound is None:
Expand All @@ -736,7 +736,7 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
pos = term.ref().accessor.position
field = self.partition_fields[pos]
prefix = str(literal.value)
prefix = literal.value
len_prefix = len(prefix)

if field.contains_null or field.lower_bound is None or field.upper_bound is None:
Expand Down Expand Up @@ -1408,20 +1408,20 @@ def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

prefix = str(literal.value)
prefix = literal.value
len_prefix = len(prefix)

lower_bound_bytes = self.lower_bounds.get(field_id)
if lower_bound_bytes is not None:
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
lower_bound = from_bytes(field.field_type, lower_bound_bytes)

# truncate lower bound so that its length is not greater than the length of prefix
if lower_bound and lower_bound[:len_prefix] > prefix:
return ROWS_CANNOT_MATCH

upper_bound_bytes = self.upper_bounds.get(field_id)
if upper_bound_bytes is not None:
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
upper_bound = from_bytes(field.field_type, upper_bound_bytes)

# truncate upper bound so that its length is not greater than the length of prefix
if upper_bound is not None and upper_bound[:len_prefix] < prefix:
Expand All @@ -1439,16 +1439,16 @@ def visit_not_starts_with(self, term: BoundTerm, literal: LiteralValue) -> bool:
if not isinstance(field.field_type, PrimitiveType):
raise ValueError(f"Expected PrimitiveType: {field.field_type}")

prefix = str(literal.value)
prefix = literal.value
len_prefix = len(prefix)

# not_starts_with will match unless all values must start with the prefix. This happens when
# the lower and upper bounds both start with the prefix.
lower_bound_bytes = self.lower_bounds.get(field_id)
upper_bound_bytes = self.upper_bounds.get(field_id)
if lower_bound_bytes is not None and upper_bound_bytes is not None:
lower_bound = str(from_bytes(field.field_type, lower_bound_bytes))
upper_bound = str(from_bytes(field.field_type, upper_bound_bytes))
lower_bound = from_bytes(field.field_type, lower_bound_bytes)
upper_bound = from_bytes(field.field_type, upper_bound_bytes)

# if lower is shorter than the prefix then lower doesn't start with the prefix
if len(lower_bound) < len_prefix:
Expand Down Expand Up @@ -1899,7 +1899,7 @@ def visit_not_in(self, term: BoundTerm, literals: set[L]) -> BooleanExpression:

def visit_starts_with(self, term: BoundTerm, literal: LiteralValue) -> BooleanExpression:
eval_res = term.eval(self.struct)
if eval_res is not None and str(eval_res).startswith(str(literal.value)):
if eval_res is not None and eval_res.startswith(literal.value):
return AlwaysTrue()
else:
return AlwaysFalse()
Expand Down
19 changes: 19 additions & 0 deletions tests/expressions/test_visitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from pyiceberg.schema import Accessor, Schema
from pyiceberg.typedef import Record
from pyiceberg.types import (
BinaryType,
BooleanType,
DoubleType,
FloatType,
Expand Down Expand Up @@ -1629,6 +1630,24 @@ def test_expression_evaluator_null() -> None:
assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True


def test_expression_evaluator_binary_starts_with() -> None:
schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
struct = Record(b"aa")
assert expression_evaluator(schema, StartsWith("x", b"a"), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, StartsWith("x", b"aa"), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, StartsWith("x", b"aaa"), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, StartsWith("x", b"b"), case_sensitive=True)(struct) is False


def test_expression_evaluator_binary_not_starts_with() -> None:
schema = Schema(NestedField(1, "x", BinaryType(), required=False), schema_id=1)
struct = Record(b"aa")
assert expression_evaluator(schema, NotStartsWith("x", b"a"), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotStartsWith("x", b"aa"), case_sensitive=True)(struct) is False
assert expression_evaluator(schema, NotStartsWith("x", b"aaa"), case_sensitive=True)(struct) is True
assert expression_evaluator(schema, NotStartsWith("x", b"b"), case_sensitive=True)(struct) is True


def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None:
"""Test translate_column_names with matching column names."""
# Create a bound expression using the original schema
Expand Down