fix: revert fix(sqllab): Force trino client async execution (#24859) (#25541)

This commit is contained in:
Ville Brofeldt 2023-10-13 04:58:20 -07:00 committed by GitHub
parent ef1807cd7e
commit e56e0de458
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 18 additions and 114 deletions

View File

@ -1066,24 +1066,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
query object"""
# TODO: Fix circular import error caused by importing sql_lab.Query
@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
For most implementations this just makes calls to `execute` and
`handle_cursor` consecutively, but in some engines (e.g. Trino) we may
need to handle client limitations such as lack of async support and
perform a more complicated operation to get information from the cursor
in a timely manner and facilitate operations such as query stop
"""
logger.debug("Query %d: Running query: %s", query.id, sql)
cls.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
return f"{cls.engine} error: {cls._extract_error_message(ex)}"

View File

@ -18,8 +18,6 @@ from __future__ import annotations
import contextlib
import logging
import threading
import time
from typing import Any, TYPE_CHECKING
import simplejson as json
@ -153,22 +151,15 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
"""
Handle a trino client cursor.
WARNING: if you execute a query, it will block until complete and you
will not be able to handle the cursor until complete. Use
`execute_with_cursor` instead, to handle this asynchronously.
"""
# Adds the executed query id to the extra payload so the query can be cancelled
cancel_query_id = cursor.query_id
logger.debug("Query %d: queryId %s found in cursor", query.id, cancel_query_id)
query.set_extra_json_key(key=QUERY_CANCEL_KEY, value=cancel_query_id)
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
# Adds the executed query id to the extra payload so the query can be cancelled
query.set_extra_json_key(
key=QUERY_CANCEL_KEY,
value=(cancel_query_id := cursor.stats["queryId"]),
)
session.commit()
# if query cancelation was requested prior to the handle_cursor call, but
@ -182,51 +173,6 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
super().handle_cursor(cursor=cursor, query=query, session=session)
@classmethod
def execute_with_cursor(
cls, cursor: Any, sql: str, query: Query, session: Session
) -> None:
"""
Trigger execution of a query and handle the resulting cursor.
Trino's client blocks until the query is complete, so we need to run it
in another thread and invoke `handle_cursor` to poll for the query ID
to appear on the cursor in parallel.
"""
execute_result: dict[str, Any] = {}
def _execute(results: dict[str, Any]) -> None:
logger.debug("Query %d: Running query: %s", query.id, sql)
# Pass result / exception information back to the parent thread
try:
cls.execute(cursor, sql)
results["complete"] = True
except Exception as ex: # pylint: disable=broad-except
results["complete"] = True
results["error"] = ex
execute_thread = threading.Thread(target=_execute, args=(execute_result,))
execute_thread.start()
# Wait for a query ID to be available before handling the cursor, as
# it's required by that method; it may never become available on error.
while not cursor.query_id and not execute_result.get("complete"):
time.sleep(0.1)
logger.debug("Query %d: Handling cursor", query.id)
cls.handle_cursor(cursor, query, session)
# Block until the query completes; same behaviour as the client itself
logger.debug("Query %d: Waiting for query to complete", query.id)
while not execute_result.get("complete"):
time.sleep(0.5)
# Unfortunately we'll mangle the stack trace due to the thread, but
# throwing the original exception allows mapping database errors as normal
if err := execute_result.get("error"):
raise err
@classmethod
def prepare_cancel_query(cls, query: Query, session: Session) -> None:
if QUERY_CANCEL_KEY not in query.extra:

View File

@ -191,7 +191,7 @@ def get_sql_results( # pylint: disable=too-many-arguments
return handle_query_error(ex, query, session)
def execute_sql_statement( # pylint: disable=too-many-arguments
def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statements
sql_statement: str,
query: Query,
session: Session,
@ -271,7 +271,10 @@ def execute_sql_statement( # pylint: disable=too-many-arguments
)
session.commit()
with stats_timing("sqllab.query.time_executing_query", stats_logger):
db_engine_spec.execute_with_cursor(cursor, sql, query, session)
logger.debug("Query %d: Running query: %s", query.id, sql)
db_engine_spec.execute(cursor, sql, async_=True)
logger.debug("Query %d: Handling cursor", query.id)
db_engine_spec.handle_cursor(cursor, query, session)
with stats_timing("sqllab.query.time_fetching_results", stats_logger):
logger.debug(

View File

@ -352,7 +352,7 @@ def test_handle_cursor_early_cancel(
query_id = "myQueryId"
cursor_mock = engine_mock.return_value.__enter__.return_value
cursor_mock.query_id = query_id
cursor_mock.stats = {"queryId": query_id}
session_mock = mocker.MagicMock()
query = Query()
@ -366,32 +366,3 @@ def test_handle_cursor_early_cancel(
assert cancel_query_mock.call_args[1]["cancel_query_id"] == query_id
else:
assert cancel_query_mock.call_args is None
def test_execute_with_cursor_in_parallel(mocker: MockerFixture):
"""Test that `execute_with_cursor` fetches query ID from the cursor"""
from superset.db_engine_specs.trino import TrinoEngineSpec
query_id = "myQueryId"
mock_cursor = mocker.MagicMock()
mock_cursor.query_id = None
mock_query = mocker.MagicMock()
mock_session = mocker.MagicMock()
def _mock_execute(*args, **kwargs):
mock_cursor.query_id = query_id
mock_cursor.execute.side_effect = _mock_execute
TrinoEngineSpec.execute_with_cursor(
cursor=mock_cursor,
sql="SELECT 1 FROM foo",
query=mock_query,
session=mock_session,
)
mock_query.set_extra_json_key.assert_called_once_with(
key=QUERY_CANCEL_KEY, value=query_id
)

View File

@ -55,8 +55,8 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
)
database.apply_limit_to_sql.assert_called_with("SELECT 42 AS answer", 2, force=True)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", query, session
db_engine_spec.execute.assert_called_with(
cursor, "SELECT 42 AS answer LIMIT 2", async_=True
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)
@ -106,8 +106,10 @@ def test_execute_sql_statement_with_rls(
101,
force=True,
)
db_engine_spec.execute_with_cursor.assert_called_with(
cursor, "SELECT * FROM sales WHERE organization_id=42 LIMIT 101", query, session
db_engine_spec.execute.assert_called_with(
cursor,
"SELECT * FROM sales WHERE organization_id=42 LIMIT 101",
async_=True,
)
SupersetResultSet.assert_called_with([(42,)], cursor.description, db_engine_spec)