diff --git a/KERNEL_REV b/KERNEL_REV index 696572aef..e61c2567b 100644 --- a/KERNEL_REV +++ b/KERNEL_REV @@ -1 +1 @@ -f4ee6fec78aabce8c0ea9c1ff47fc11b8191d013 +3991d8b4677f9fa8d3bdf607f3db875cd21d3304 diff --git a/src/databricks/sql/backend/kernel/client.py b/src/databricks/sql/backend/kernel/client.py index cb3b0b7ba..3a342e854 100644 --- a/src/databricks/sql/backend/kernel/client.py +++ b/src/databricks/sql/backend/kernel/client.py @@ -16,13 +16,6 @@ - ``query_tags`` on execute is not supported (kernel exposes ``statement_conf`` but PyO3 doesn't surface it). -- ``get_tables`` with a non-empty ``table_types`` filter applies - the filter client-side; today the kernel returns the full - ``SHOW TABLES`` shape unchanged. The connector's existing - ``ResultSetFilter.filter_tables_by_type`` is keyed on - ``SeaResultSet`` not ``KernelResultSet``, so we punt and let - the caller see all rows — documented as a known gap in the - design doc. - Volume PUT/GET (staging operations): kernel has no Volume API yet. Users on Thrift-only paths. """ @@ -32,7 +25,8 @@ import logging import threading import uuid -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union +from collections import OrderedDict +from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union from databricks.sql.backend.databricks_client import DatabricksClient from databricks.sql.backend.kernel._errors import ( @@ -52,7 +46,6 @@ from databricks.sql.exc import ( InterfaceError, NotSupportedError, - OperationalError, ProgrammingError, ) from databricks.sql.thrift_api.TCLIService import ttypes @@ -76,6 +69,16 @@ # on staging ops it can't service — see ``execute_command``. _STAGING_VERBS = ("PUT", "GET", "REMOVE") +# Upper bound on the per-session ``_closed_commands`` registry. The set +# only needs to remember *recently* closed async command ids long enough +# for a client still holding the id to poll ``get_query_state`` and see +# ``CLOSED`` (rather than the SUCCEEDED fall-through). Bounding it (FIFO +# eviction) prevents unbounded growth on a long-lived session that opens +# and closes many async commands. An evicted (very old) id degrades from +# CLOSED -> SUCCEEDED in ``get_query_state`` — consistent with the +# never-tracked path, not a correctness break. +_CLOSED_COMMANDS_MAX = 10_000 + def _strip_leading_sql_comments(sql: str) -> str: """Strip leading whitespace and SQL comments (``-- …`` line and @@ -107,6 +110,39 @@ def _strip_leading_sql_comments(sql: str) -> str: return sql[i:] +def _none_if_blank(value: Optional[str]) -> Optional[str]: + """Map an empty/whitespace-only metadata filter to ``None`` + ("match all"), matching the Thrift backend's effective behaviour. + + The kernel's ``Identifier`` / ``LikePattern`` reject ``""`` with + ``InvalidArgument`` (-> ``ProgrammingError``); ``None`` is the + kernel's canonical "match all". Applied to schema / table / column + *pattern* args (which otherwise keep ``%`` / ``_`` as real LIKE + wildcards).""" + if value is None: + return None + return value if value.strip() else None + + +def _catalog_or_none(value: Optional[str]) -> Optional[str]: + """Normalise a catalog filter: ``None`` / blank / ``'%'`` / ``'*'`` + all mean "all catalogs" -> ``None``. + + This makes ``columns(catalog='%')`` behave like + ``tables(catalog='%')`` / ``schemas(catalog='%')`` — the kernel + already treats blank/``%``/``*`` as "all catalogs" for SHOW SCHEMAS + / SHOW TABLES (``is_null_or_wildcard``) but treats the catalog as an + exact identifier for SHOW COLUMNS, so the three diverged. Normalising + connector-side makes them symmetric. This intentionally diverges from + raw-Thrift literalness (Thrift treats ``%`` as a literal catalog + name) in favour of JDBC "catalog is exact-or-all, not a pattern" + + internal consistency. Catalog is the only arg normalised this way; + schema/table/column patterns keep ``%`` / ``*`` as LIKE wildcards.""" + if value is None or not value.strip() or value in ("%", "*"): + return None + return value + + def _is_staging_statement(operation: str) -> bool: """True iff ``operation`` is a volume/staging statement (PUT / GET / REMOVE). @@ -219,8 +255,11 @@ def __init__( # closed (via ``close_command`` or ``close_session``). Lets # ``get_query_state`` report ``CLOSED`` for them rather than # the SUCCEEDED fall-through used for the never-tracked sync - # path. Same lock as ``_async_handles``. - self._closed_commands: Set[str] = set() + # path. Same lock as ``_async_handles``. Bounded FIFO (see + # ``_record_closed`` / ``_CLOSED_COMMANDS_MAX``) so it can't grow + # without limit on a long-lived session. Used as an ordered set + # (values are ignored). + self._closed_commands: "OrderedDict[str, None]" = OrderedDict() self._async_handles_lock = threading.RLock() # Sync-execute cancellers keyed by ``id(cursor)``. A blocking # ``execute()`` sets ``cursor.active_command_id`` only AFTER it @@ -355,7 +394,7 @@ def close_session(self, session_id: SessionId) -> None: self._async_handles.clear() self._async_statements.clear() for guid, _ in tracked: - self._closed_commands.add(guid) + self._record_closed(guid) for _, handle in tracked: # Per-handle close errors are non-fatal — PEP 249 # discourages raising from session close — so log and @@ -487,6 +526,27 @@ def execute_command( # produced to reap it. close_stmt = False except Exception as exc: + # Failed sync execute: publish the server-issued + # statement id (observed mid-execute via the canceller's + # inflight slot, still registered here — the finally pops + # it) so the cursor's query_id reflects the FAILED query, + # matching the Thrift backend which sets active_command_id + # on every execute regardless of outcome. statement_id() + # is None for a pre-id failure (transport error on the + # initial POST) — then leave active_command_id untouched. + # Best-effort; never mask the original failure. + try: + with self._sync_cancellers_lock: + canceller = self._sync_cancellers.get(id(cursor)) + stmt_id = ( + canceller.statement_id() if canceller is not None else None + ) + if stmt_id: + cursor.active_command_id = CommandId.from_sea_statement_id( + stmt_id + ) + except Exception: + pass raise _wrap_kernel_exception("execute_command", exc) from exc finally: with self._sync_cancellers_lock: @@ -502,7 +562,21 @@ def execute_command( pass command_id = CommandId.from_sea_statement_id(executed.statement_id) - cursor.active_command_id = command_id + # Surface the affected-row count for DML (INSERT/UPDATE/DELETE/ + # MERGE) as ``cursor.rowcount`` instead of the hardcoded ``-1``. + # ``num_modified_rows`` is ``None`` for SELECT (and warehouses + # that don't report it) → leave ``rowcount`` at its ``-1`` + # default. ``getattr`` guards against an older kernel wheel that + # predates the pyo3 getter. NB the Thrift backend also hardcodes + # ``-1`` here, so this makes the kernel path *exceed* Thrift. + try: + modified = getattr(executed, "num_modified_rows", None) + if callable(modified): + modified = modified() + except Exception: + modified = None + if modified is not None: + cursor.rowcount = modified # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which # can itself raise ``KernelError`` (or, in principle, a PyO3 # native exception) — wrap the construction so callers see a @@ -574,7 +648,7 @@ def close_command(self, command_id: CommandId) -> None: if handle is not None: # Record the close so ``get_query_state`` can report # ``CLOSED`` (not ``SUCCEEDED``) for this command. - self._closed_commands.add(command_id.guid) + self._record_closed(command_id.guid) if handle is None: logger.debug("close_command: no tracked handle for %s", command_id) # Still drop the parent Statement if somehow tracked without @@ -650,36 +724,17 @@ def get_execution_result( stream = async_exec.await_result() except Exception as exc: raise _wrap_kernel_exception("get_execution_result", exc) from exc - # The async-exec handle's role ends once it has produced the - # ``ResultStream`` — keeping it around (and tracked in - # ``_async_handles``) would leak the server-side - # ``ExecutedAsyncStatement`` until ``close_session`` swept it - # up, since ``KernelResultSet.close`` only closes the stream - # it wraps. Drop tracking and fire-and-forget the close. - with self._async_handles_lock: - self._async_handles.pop(command_id.guid, None) - stmt = self._async_statements.pop(command_id.guid, None) - self._closed_commands.add(command_id.guid) - try: - async_exec.close() - except Exception as exc: - logger.warning( - "Error closing async_exec after await_result for %s: %s", - command_id, - exc, - ) - # The parent Statement is no longer needed once the async handle - # has produced its ResultStream. Close to release server-side - # tracking; matches the sync path's eager Statement close. - if stmt is not None: - try: - stmt.close() - except Exception as exc: - logger.warning( - "Error closing async statement after await_result for %s: %s", - command_id, - exc, - ) + # Do NOT close/drop the async handle here. The kernel's + # ``await_result()`` is idempotent and re-callable (it re-polls + + # re-materialises a fresh ``ResultStream`` each time), so keeping + # the handle tracked lets ``get_async_execution_result()`` be + # called more than once — matching the Thrift backend, where the + # operation handle stays valid (re-fetchable) until an explicit + # ``close_command`` / ``close_session``. The prior eager close + # made a second call raise ``ProgrammingError(unknown + # command_id)``. The handle + parent Statement are still reaped + # by ``close_command`` / ``close_session``, so this does not leak. + # # ``KernelResultSet.__init__`` calls ``arrow_schema()`` which # can raise — map that to PEP 249 too. try: @@ -697,7 +752,17 @@ def _make_result_set( ) -> "ResultSet": """Build a ``KernelResultSet`` from any kernel handle. Used by sync execute, ``get_execution_result``, and all metadata - paths to keep construction in one place.""" + paths to keep construction in one place. + + Sets ``cursor.active_command_id`` here so every result-producing + path — sync execute, async fetch, AND metadata — leaves the + cursor pointing at the command that produced the current result + set. This matches the Thrift backend, which sets it + unconditionally in ``_handle_execute_response``. Without it, + ``cursor.query_id`` / ``get_query_state`` would stay pinned to a + prior query after a metadata call (the metadata methods mint a + synthetic command id but previously never published it).""" + cursor.active_command_id = command_id return KernelResultSet( connection=cursor.connection, backend=self, @@ -707,6 +772,17 @@ def _make_result_set( buffer_size_bytes=cursor.buffer_size_bytes, ) + def _record_closed(self, guid: str) -> None: + """Record an async command guid as closed, bounded FIFO. + + Caller must hold ``_async_handles_lock``. Evicts the oldest + entries past ``_CLOSED_COMMANDS_MAX`` so the registry can't grow + unbounded on a long-lived session.""" + self._closed_commands[guid] = None + self._closed_commands.move_to_end(guid) + while len(self._closed_commands) > _CLOSED_COMMANDS_MAX: + self._closed_commands.popitem(last=False) + def _synthetic_command_id(self) -> CommandId: """Metadata calls don't produce a server statement id; mint a synthetic UUID so the ``ResultSet`` still has a stable @@ -746,8 +822,8 @@ def get_schemas( raise InterfaceError("get_schemas requires an open session.") try: stream = self._kernel_session.metadata().list_schemas( - catalog=catalog_name, - schema_pattern=schema_name, + catalog=_catalog_or_none(catalog_name), + schema_pattern=_none_if_blank(schema_name), ) return self._make_result_set(stream, cursor, self._synthetic_command_id()) except Exception as exc: @@ -767,45 +843,18 @@ def get_tables( if self._kernel_session is None: raise InterfaceError("get_tables requires an open session.") try: + # ``table_types`` is filtered kernel-side (the kernel applies + # it to the reshaped result, case-insensitively as of the + # batch-3 kernel change), so we forward it and let the kernel + # do the work — no connector-side drain + refilter. Passing it + # through preserves streaming for large schemas. stream = self._kernel_session.metadata().list_tables( - catalog=catalog_name, - schema_pattern=schema_name, - table_pattern=table_name, - table_types=table_types, - ) - if not table_types: - return self._make_result_set( - stream, cursor, self._synthetic_command_id() - ) - # The kernel today returns the unfiltered ``SHOW TABLES`` - # shape regardless of ``table_types``. Drain to a single - # Arrow table and apply the same client-side filter the - # native SEA backend uses. The filter is **case-sensitive** - # — matches the SEA backend's documented behaviour, and - # mirrors how the warehouse reports the values - # (``TABLE`` / ``VIEW`` / ``SYSTEM_TABLE`` — uppercase). - # Look the column up by name rather than positional index - # so a future kernel reshape of ``SHOW TABLES`` doesn't - # silently filter the wrong column. - from databricks.sql.backend.sea.utils.filters import ResultSetFilter - - full_table = _drain_kernel_handle(stream) - if "TABLE_TYPE" not in full_table.schema.names: - raise OperationalError( - "kernel get_tables result is missing a TABLE_TYPE " - f"column; got {full_table.schema.names!r}" - ) - filtered_table = ResultSetFilter._filter_arrow_table( - full_table, - column_name="TABLE_TYPE", - allowed_values=table_types, - case_sensitive=True, - ) - return self._make_result_set( - _StaticArrowHandle(filtered_table), - cursor, - self._synthetic_command_id(), + catalog=_catalog_or_none(catalog_name), + schema_pattern=_none_if_blank(schema_name), + table_pattern=_none_if_blank(table_name), + table_types=table_types if table_types else None, ) + return self._make_result_set(stream, cursor, self._synthetic_command_id()) except Exception as exc: raise _wrap_kernel_exception("get_tables", exc) from exc @@ -830,10 +879,10 @@ def get_columns( # Thrift backend's `getColumns(null, …)` behaviour from # the user's perspective. stream = self._kernel_session.metadata().list_columns( - catalog=catalog_name, - schema_pattern=schema_name, - table_pattern=table_name, - column_pattern=column_name, + catalog=_catalog_or_none(catalog_name), + schema_pattern=_none_if_blank(schema_name), + table_pattern=_none_if_blank(table_name), + column_pattern=_none_if_blank(column_name), ) return self._make_result_set(stream, cursor, self._synthetic_command_id()) except Exception as exc: @@ -1006,55 +1055,3 @@ def _read_pem_bytes(path: str, label: str) -> bytes: "kernel TLS config." ) return data - - -def _drain_kernel_handle(handle: Any) -> Any: - """Drain a kernel ResultStream / ExecutedStatement into a single - ``pyarrow.Table``. Used by ``get_tables`` to apply a client-side - ``table_types`` filter on a metadata result; cheap because - metadata streams are small.""" - import pyarrow - - schema = handle.arrow_schema() - batches = [] - while True: - batch = handle.fetch_next_batch() - if batch is None: - break - if batch.num_rows > 0: - batches.append(batch) - try: - handle.close() - except Exception: - # Non-fatal — the surrounding ``get_tables`` call has already - # captured the result data, and the handle's server-side - # state will be reaped by the kernel's Drop impl. - pass - return pyarrow.Table.from_batches(batches, schema=schema) - - -class _StaticArrowHandle: - """Duck-typed kernel handle that replays a pre-built - ``pyarrow.Table`` through ``arrow_schema()`` / - ``fetch_next_batch()`` / ``close()``. Used to wrap a - post-processed table (e.g., the ``table_types``-filtered output - of ``get_tables``) so it flows back through the normal - ``KernelResultSet`` path.""" - - def __init__(self, table: Any) -> None: - self._schema = table.schema - self._batches = list(table.to_batches()) - self._idx = 0 - - def arrow_schema(self) -> Any: - return self._schema - - def fetch_next_batch(self) -> Optional[Any]: - if self._idx >= len(self._batches): - return None - batch = self._batches[self._idx] - self._idx += 1 - return batch - - def close(self) -> None: - self._batches = [] diff --git a/src/databricks/sql/client.py b/src/databricks/sql/client.py index e66dd897c..7b94fc98d 100755 --- a/src/databricks/sql/client.py +++ b/src/databricks/sql/client.py @@ -1590,10 +1590,9 @@ def columns( Names can contain % wildcards. - Note: on ``use_kernel=True``, ``catalog_name`` is required — - the kernel's underlying ``SHOW COLUMNS`` cannot span catalogs. - Passing ``catalog_name=None`` raises ``ProgrammingError``. The - Thrift and native SEA backends accept ``catalog_name=None``. + ``catalog_name=None`` is accepted on all backends and matches + columns across every catalog (the kernel issues ``SHOW COLUMNS`` + over all catalogs). :returns self """ diff --git a/tests/e2e/test_kernel_backend.py b/tests/e2e/test_kernel_backend.py index f25c60630..38195b4bd 100644 --- a/tests/e2e/test_kernel_backend.py +++ b/tests/e2e/test_kernel_backend.py @@ -21,6 +21,8 @@ from __future__ import annotations +from uuid import uuid4 + import pytest import databricks.sql as sql @@ -252,6 +254,112 @@ def test_metadata_columns(conn): assert len(rows) > 0 +# ── Metadata filter normalization (batch 3) ─────────────────────── + + +def test_schemas_with_empty_string_filter_matches_all(conn): + """An empty-string schema pattern normalizes to match-all rather + than raising ``ProgrammingError`` (kernel rejects ``""``) — locks + ``_none_if_blank`` on the pattern args.""" + with conn.cursor() as cur: + cur.schemas(catalog_name="main", schema_name="") + rows = cur.fetchall() + assert len(rows) > 0 + + +def test_tables_table_types_filter_is_case_insensitive(conn): + """Lowercase ``table_types=['view']`` / uppercase ``['TABLE']`` + each match the right object regardless of case — locks the + kernel-side case-insensitive ``table_types`` match (batch-3 kernel + B2) end-to-end plus the connector drain removal (the filter now + runs kernel-side, not client-side). + + Self-contained: creates a table + a view over it in the session's + default (writable) schema, scopes the lookup to a unique name + prefix, and drops both afterward — no dependency on which + workspace schemas happen to contain views.""" + sfx = str(uuid4()).replace("-", "_") + tbl = f"dbsql_kernel_tt_t_{sfx}" + vw = f"dbsql_kernel_tt_v_{sfx}" + name_pat = f"dbsql_kernel_tt_%_{sfx}" + with conn.cursor() as cur: + cur.execute("SELECT current_catalog(), current_schema()") + cat, sch = cur.fetchall()[0] + try: + cur.execute(f"CREATE TABLE {tbl} (n INT)") + cur.execute(f"CREATE VIEW {vw} AS SELECT * FROM {tbl}") + + def _names_and_types(): + rows = cur.fetchall() + cols = [d[0] for d in cur.description] + ni, ti = cols.index("TABLE_NAME"), cols.index("TABLE_TYPE") + return {(r[ni], r[ti]) for r in rows} + + # Lowercase 'view' must match the VIEW (and only it). + cur.tables( + catalog_name=cat, + schema_name=sch, + table_name=name_pat, + table_types=["view"], + ) + assert _names_and_types() == {(vw, "VIEW")} + + # Uppercase 'TABLE' must match the TABLE (and only it). + cur.tables( + catalog_name=cat, + schema_name=sch, + table_name=name_pat, + table_types=["TABLE"], + ) + assert _names_and_types() == {(tbl, "TABLE")} + finally: + cur.execute(f"DROP VIEW IF EXISTS {vw}") + cur.execute(f"DROP TABLE IF EXISTS {tbl}") + + +# ── Cursor-state tracking (batch 3) ─────────────────────────────── + + +def test_metadata_call_publishes_active_command_id(conn): + """A metadata call leaves the cursor pointing at the command that + produced the current result set (Thrift parity) — ``query_id`` is + populated rather than stale/None after ``catalogs()``.""" + with conn.cursor() as cur: + cur.catalogs() + cur.fetchall() + assert cur.active_command_id is not None + assert cur.query_id is not None + + +def test_dml_rowcount_wiring_does_not_break_dml(conn): + """The ``num_modified_rows`` → ``cursor.rowcount`` wiring must not + break DML execution, and ``rowcount`` is a well-formed int. + + The affected-row count itself is only surfaced when the warehouse + reports ``num_modified_rows`` (absent on some warehouses, including + parts of dogfood — then ``rowcount`` stays at its ``-1`` default, + matching the Thrift backend). The positive-count mapping is locked + by the unit test; here we assert the path runs end-to-end and the + rows really landed. Self-contained in the writable default schema.""" + sfx = str(uuid4()).replace("-", "_") + tbl = f"dbsql_kernel_rc_{sfx}" + with conn.cursor() as cur: + try: + cur.execute(f"CREATE TABLE {tbl} (n INT)") + cur.execute(f"INSERT INTO {tbl} VALUES (1), (2), (3)") + # rowcount is a real int (>= -1); never a MagicMock / None / + # crash from the getattr wiring. + assert isinstance(cur.rowcount, int) + assert ( + cur.rowcount == 3 or cur.rowcount == -1 + ), f"unexpected rowcount {cur.rowcount!r}" + # The INSERT genuinely modified the table. + cur.execute(f"SELECT COUNT(*) FROM {tbl}") + assert cur.fetchall()[0][0] == 3 + finally: + cur.execute(f"DROP TABLE IF EXISTS {tbl}") + + # ── Session configuration ───────────────────────────────────────── diff --git a/tests/unit/test_kernel_client.py b/tests/unit/test_kernel_client.py index 42ac36197..ab4c7dc48 100644 --- a/tests/unit/test_kernel_client.py +++ b/tests/unit/test_kernel_client.py @@ -930,20 +930,21 @@ def test_kernel_error_during_result_set_construction_is_mapped(): # --------------------------------------------------------------------------- -def test_get_execution_result_closes_async_exec_and_drops_tracking(): - """The ``ExecutedAsyncStatement`` handle's role ends once it - produces a ``ResultStream`` via ``await_result()``. The client - must close the async_exec and drop the tracking entry there — - otherwise ``KernelResultSet.close()`` (which only closes the - stream) leaves the executed handle leaked server-side until - ``close_session`` sweeps.""" +def test_get_execution_result_retains_handle_for_recall(): + """``get_execution_result`` must NOT eagerly close/drop the + ``ExecutedAsyncStatement`` after the first ``await_result()``. + The kernel's ``await_result()`` is idempotent (``&self``, no + consumed flag), so the connector keeps the handle tracked — + matching Thrift, where ``get_async_execution_result()`` can be + re-called until the command is explicitly closed. The handle is + reaped later by ``close_command``/``close_session``.""" c = _make_client() c._kernel_session = MagicMock() async_exec = MagicMock() fake_stream = MagicMock() fake_stream.arrow_schema.return_value = pa.schema([("n", pa.int64())]) async_exec.await_result.return_value = fake_stream - cid = CommandId.from_sea_statement_id("async-leak-test") + cid = CommandId.from_sea_statement_id("async-recall-test") c._async_handles[cid.guid] = async_exec cursor = MagicMock() cursor.arraysize = 100 @@ -951,47 +952,48 @@ def test_get_execution_result_closes_async_exec_and_drops_tracking(): c.get_execution_result(cid, cursor=cursor) - # async_exec must be closed and dropped from tracking; the - # closed-commands set records it. - assert async_exec.close.called - assert cid.guid not in c._async_handles - assert cid.guid in c._closed_commands + # Handle retained + not closed; not recorded as closed yet. + assert not async_exec.close.called + assert c._async_handles.get(cid.guid) is async_exec + assert cid.guid not in c._closed_commands -def test_get_execution_result_does_not_raise_on_async_exec_close_failure(): - """A failure to close the async_exec is non-fatal — the result - stream has already been returned by ``await_result()`` and the - kernel's Drop will reap server-side state.""" +def test_get_execution_result_is_re_callable(): + """A second ``get_execution_result`` for the same async command + succeeds (Thrift-parity re-fetch). ``await_result()`` is called + once per invocation and neither call raises.""" c = _make_client() c._kernel_session = MagicMock() async_exec = MagicMock() fake_stream = MagicMock() fake_stream.arrow_schema.return_value = pa.schema([("n", pa.int64())]) async_exec.await_result.return_value = fake_stream - async_exec.close.side_effect = _FakeKernelError(code="Unavailable") - cid = CommandId.from_sea_statement_id("async-close-fail") + cid = CommandId.from_sea_statement_id("async-recall-twice") c._async_handles[cid.guid] = async_exec cursor = MagicMock() cursor.arraysize = 100 cursor.buffer_size_bytes = 1024 - # Must not raise. - rs = c.get_execution_result(cid, cursor=cursor) - assert rs is not None - assert cid.guid not in c._async_handles + rs1 = c.get_execution_result(cid, cursor=cursor) + rs2 = c.get_execution_result(cid, cursor=cursor) + + assert rs1 is not None and rs2 is not None + assert async_exec.await_result.call_count == 2 + # Still tracked after the second fetch — only explicit close reaps it. + assert c._async_handles.get(cid.guid) is async_exec # --------------------------------------------------------------------------- -# get_tables table_types client-side filter (m2) +# get_tables — table_types is filtered kernel-side (no connector drain) # --------------------------------------------------------------------------- def _make_tables_stream() -> MagicMock: """Build a fake stream that mimics the kernel's ``list_tables`` - output shape (5 cols ending in TABLE_TYPE at index 5 — the - connector matches what SEA produces, which has 5 metadata cols - before TABLE_TYPE). Returns a fixed table with mixed table types - so the filter has something to discriminate.""" + output shape. The kernel applies the ``table_types`` filter + itself, so the connector now forwards ``table_types`` and returns + this stream unchanged — these tests mock the kernel filter away + and only assert the forwarded args + pass-through behaviour.""" schema = pa.schema( [ ("TABLE_CAT", pa.string()), @@ -1021,12 +1023,15 @@ def _make_tables_stream() -> MagicMock: return stream -def test_get_tables_with_table_types_filters_rows(): +def test_get_tables_forwards_table_types_to_kernel(): + """``table_types`` is forwarded verbatim to the kernel's + ``list_tables`` (which filters case-insensitively) — the + connector no longer drains + refilters client-side. The stream + flows back through the normal ``KernelResultSet`` path unchanged.""" c = _make_client() c._kernel_session = MagicMock() - c._kernel_session.metadata.return_value.list_tables.return_value = ( - _make_tables_stream() - ) + list_tables = c._kernel_session.metadata.return_value.list_tables + list_tables.return_value = _make_tables_stream() cursor = MagicMock() cursor.arraysize = 100 cursor.buffer_size_bytes = 1024 @@ -1036,21 +1041,28 @@ def test_get_tables_with_table_types_filters_rows(): max_rows=10, max_bytes=1, cursor=cursor, - table_types=["TABLE"], + table_types=["view"], + ) + + list_tables.assert_called_once_with( + catalog=None, + schema_pattern=None, + table_pattern=None, + table_types=["view"], ) + # Stream is returned as-is — no connector-side row filtering. The + # mock kernel doesn't filter, so all three rows pass through. table = rs.fetchall_arrow() - assert table.num_rows == 2 - assert set(table.column("TABLE_TYPE").to_pylist()) == {"TABLE"} + assert table.num_rows == 3 -def test_get_tables_without_table_types_returns_full_stream(): - """No filter → kernel result flows through unchanged via the - normal ``KernelResultSet`` path (no drain-and-rewrap).""" +def test_get_tables_without_table_types_passes_none(): + """No filter → ``table_types=None`` forwarded; stream flows + through unchanged via the normal ``KernelResultSet`` path.""" c = _make_client() c._kernel_session = MagicMock() - c._kernel_session.metadata.return_value.list_tables.return_value = ( - _make_tables_stream() - ) + list_tables = c._kernel_session.metadata.return_value.list_tables + list_tables.return_value = _make_tables_stream() cursor = MagicMock() cursor.arraysize = 100 cursor.buffer_size_bytes = 1024 @@ -1062,10 +1074,334 @@ def test_get_tables_without_table_types_returns_full_stream(): cursor=cursor, table_types=None, ) + + list_tables.assert_called_once_with( + catalog=None, + schema_pattern=None, + table_pattern=None, + table_types=None, + ) table = rs.fetchall_arrow() assert table.num_rows == 3 +# --------------------------------------------------------------------------- +# Cursor-state tracking (T7) — active_command_id consistency +# --------------------------------------------------------------------------- + + +def _stream_with_schema() -> MagicMock: + """A minimal fake kernel handle whose ``arrow_schema()`` returns a + real schema so ``KernelResultSet.__init__`` succeeds.""" + stream = MagicMock() + stream.arrow_schema.return_value = pa.schema([("x", pa.int64())]) + return stream + + +def test_metadata_call_sets_active_command_id(): + """Metadata methods mint a synthetic command id and must publish + it on the cursor (Thrift sets ``active_command_id`` unconditionally + in ``_handle_execute_response``). Without this, ``cursor.query_id`` + would stay pinned to a prior query after a metadata browse.""" + c = _make_client() + c._kernel_session = MagicMock() + c._kernel_session.metadata.return_value.list_catalogs.return_value = ( + _stream_with_schema() + ) + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + cursor.active_command_id = None + + c.get_catalogs(session_id=MagicMock(), max_rows=1, max_bytes=1, cursor=cursor) + + assert cursor.active_command_id is not None + # The synthetic id is a UUID-shaped guid the cursor can attribute + # logs to (rather than a stale prior-query id). + assert cursor.active_command_id.guid + + +def test_sync_execute_sets_active_command_id(): + """A successful sync execute publishes the server statement id on + the cursor.""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + cursor.active_command_id = None + + stmt = MagicMock() + stmt.execute.return_value = MagicMock( + statement_id="stmt-abc", + num_modified_rows=None, + arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])), + ) + c._kernel_session.statement.return_value = stmt + + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert cursor.active_command_id is not None + assert cursor.active_command_id.guid == "stmt-abc" + + +def test_failed_sync_execute_sets_active_command_id_from_canceller(): + """When execute() fails *after* the server assigned a statement id, + the canceller's inflight slot holds that id. The connector reads it + and publishes ``active_command_id`` before re-raising, so the cursor + can correlate the failed query (telemetry / log lookup).""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + cursor.active_command_id = None + + canceller = MagicMock() + canceller.statement_id.return_value = "failed-stmt-id" + stmt = MagicMock() + stmt.canceller.return_value = canceller + stmt.execute.side_effect = RuntimeError("boom after id assigned") + c._kernel_session.statement.return_value = stmt + + with pytest.raises(Exception): + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert cursor.active_command_id is not None + assert cursor.active_command_id.guid == "failed-stmt-id" + + +def test_failed_sync_execute_leaves_active_command_id_untouched_when_no_id(): + """A pre-id transport failure (canceller has no statement id yet) + must leave ``active_command_id`` untouched — there's nothing to + correlate, and clobbering it would lie about cursor state.""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + sentinel = object() + cursor.active_command_id = sentinel + + canceller = MagicMock() + canceller.statement_id.return_value = None # no id observed yet + stmt = MagicMock() + stmt.canceller.return_value = canceller + stmt.execute.side_effect = RuntimeError("connect refused") + c._kernel_session.statement.return_value = stmt + + with pytest.raises(Exception): + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert cursor.active_command_id is sentinel + + +def test_sync_execute_forwards_num_modified_rows_to_rowcount(): + """DML reports a real ``cursor.rowcount`` from the kernel's + ``num_modified_rows`` instead of the hardcoded ``-1``.""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + cursor.rowcount = -1 + + stmt = MagicMock() + stmt.execute.return_value = MagicMock( + statement_id="dml-1", + num_modified_rows=42, + arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])), + ) + c._kernel_session.statement.return_value = stmt + + c.execute_command( + operation="INSERT INTO t VALUES (1)", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert cursor.rowcount == 42 + + +def test_sync_execute_leaves_rowcount_default_when_num_modified_rows_none(): + """SELECT (and warehouses that don't report it) → ``None`` leaves + ``rowcount`` at its ``-1`` default.""" + c = _make_client() + c._kernel_session = MagicMock() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + cursor.rowcount = -1 + + stmt = MagicMock() + stmt.execute.return_value = MagicMock( + statement_id="select-1", + num_modified_rows=None, + arrow_schema=MagicMock(return_value=pa.schema([("x", pa.int64())])), + ) + c._kernel_session.statement.return_value = stmt + + c.execute_command( + operation="SELECT 1", + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + lz4_compression=False, + cursor=cursor, + use_cloud_fetch=False, + parameters=[], + async_op=False, + enforce_embedded_schema_correctness=False, + ) + + assert cursor.rowcount == -1 + + +def test_closed_commands_is_bounded(): + """``_closed_commands`` is a bounded FIFO — inserting past the cap + evicts the oldest. A recently-closed command still reports CLOSED; + an evicted (very old) one falls through to the SUCCEEDED + sync-default, which is consistent with never having been tracked.""" + c = _make_client() + cap = kernel_client._CLOSED_COMMANDS_MAX + first = CommandId.from_sea_statement_id("closed-0") + with c._async_handles_lock: + c._record_closed(first.guid) + for i in range(1, cap + 5): + c._record_closed(CommandId.from_sea_statement_id(f"closed-{i}").guid) + + assert len(c._closed_commands) <= cap + # Oldest evicted → no longer reports CLOSED. + assert c.get_query_state(first) == CommandState.SUCCEEDED + # Most-recent retained → reports CLOSED. + recent = CommandId.from_sea_statement_id(f"closed-{cap + 4}") + assert c.get_query_state(recent) == CommandState.CLOSED + + +# --------------------------------------------------------------------------- +# Metadata filter normalization — wildcard catalog + empty-string patterns +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("wildcard", ["%", "*", "", " "]) +def test_get_columns_normalizes_wildcard_catalog_to_none(wildcard): + """``catalog_name`` of ``%``/``*``/blank → ``None`` (all-catalogs), + matching JDBC exact-or-all semantics and keeping the three metadata + methods symmetric.""" + c = _make_client() + c._kernel_session = MagicMock() + list_columns = c._kernel_session.metadata.return_value.list_columns + list_columns.return_value = _stream_with_schema() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + c.get_columns( + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + cursor=cursor, + catalog_name=wildcard, + schema_name="s", + table_name="t", + column_name="c", + ) + + list_columns.assert_called_once_with( + catalog=None, + schema_pattern="s", + table_pattern="t", + column_pattern="c", + ) + + +def test_get_schemas_normalizes_blank_pattern_to_none(): + """An empty-string schema pattern → ``None`` (match-all), mapping + the kernel's ``InvalidArgument``-on-``""`` to Thrift's effective + match-all. ``%``/``*`` stay as real LIKE wildcards on patterns.""" + c = _make_client() + c._kernel_session = MagicMock() + list_schemas = c._kernel_session.metadata.return_value.list_schemas + list_schemas.return_value = _stream_with_schema() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + c.get_schemas( + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + cursor=cursor, + catalog_name="main", + schema_name="", + ) + + list_schemas.assert_called_once_with(catalog="main", schema_pattern=None) + + +def test_get_schemas_keeps_wildcard_pattern(): + """A ``%`` schema pattern is a real LIKE wildcard — passed through, + NOT normalized to None.""" + c = _make_client() + c._kernel_session = MagicMock() + list_schemas = c._kernel_session.metadata.return_value.list_schemas + list_schemas.return_value = _stream_with_schema() + cursor = MagicMock() + cursor.arraysize = 100 + cursor.buffer_size_bytes = 1024 + + c.get_schemas( + session_id=MagicMock(), + max_rows=1, + max_bytes=1, + cursor=cursor, + catalog_name="main", + schema_name="prod_%", + ) + + list_schemas.assert_called_once_with(catalog="main", schema_pattern="prod_%") + + # --------------------------------------------------------------------------- # TLS translation: SSLOptions -> kernel Session tls_* kwargs. # ---------------------------------------------------------------------------