mirror of https://github.com/apache/superset.git
feat(trino): add query cancellation (#21035)
This commit is contained in:
parent
2d1ba46844
commit
5113b01031
|
@ -14,6 +14,8 @@
|
||||||
# KIND, either express or implied. See the License for the
|
# KIND, either express or implied. See the License for the
|
||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
@ -90,7 +92,7 @@ class TrinoEngineSpec(PrestoEngineSpec):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_table_names(
|
def get_table_names(
|
||||||
cls,
|
cls,
|
||||||
database: "Database",
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
schema: Optional[str],
|
schema: Optional[str],
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -103,7 +105,7 @@ class TrinoEngineSpec(PrestoEngineSpec):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_view_names(
|
def get_view_names(
|
||||||
cls,
|
cls,
|
||||||
database: "Database",
|
database: Database,
|
||||||
inspector: Inspector,
|
inspector: Inspector,
|
||||||
schema: Optional[str],
|
schema: Optional[str],
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -114,7 +116,7 @@ class TrinoEngineSpec(PrestoEngineSpec):
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
|
def get_tracking_url(cls, cursor: Cursor) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
return cursor.info_uri
|
return cursor.info_uri
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
|
@ -127,14 +129,42 @@ class TrinoEngineSpec(PrestoEngineSpec):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
|
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
|
||||||
"""Updates progress information"""
|
|
||||||
tracking_url = cls.get_tracking_url(cursor)
|
tracking_url = cls.get_tracking_url(cursor)
|
||||||
if tracking_url:
|
if tracking_url:
|
||||||
query.tracking_url = tracking_url
|
query.tracking_url = tracking_url
|
||||||
session.commit()
|
|
||||||
|
# Adds the executed query id to the extra payload so the query can be cancelled
|
||||||
|
query.set_extra_json_key("cancel_query", cursor.stats["queryId"])
|
||||||
|
|
||||||
|
session.commit()
|
||||||
BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
|
BaseEngineSpec.handle_cursor(cursor=cursor, query=query, session=session)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def has_implicit_cancel(cls) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def cancel_query(cls, cursor: Any, query: Query, cancel_query_id: str) -> bool:
|
||||||
|
"""
|
||||||
|
Cancel query in the underlying database.
|
||||||
|
|
||||||
|
:param cursor: New cursor instance to the db of the query
|
||||||
|
:param query: Query instance
|
||||||
|
:param cancel_query_id: Trino `queryId`
|
||||||
|
:return: True if query cancelled successfully, False otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cursor.execute(
|
||||||
|
f"CALL system.runtime.kill_query(query_id => '{cancel_query_id}',"
|
||||||
|
"message => 'Query cancelled by Superset')"
|
||||||
|
)
|
||||||
|
cursor.fetchall() # needed to trigger the call
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_extra_params(database: "Database") -> Dict[str, Any]:
|
def get_extra_params(database: "Database") -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlalchemy.engine.Engine.connect")
|
||||||
|
def test_cancel_query_success(engine_mock: mock.Mock) -> None:
|
||||||
|
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||||
|
from superset.models.sql_lab import Query
|
||||||
|
|
||||||
|
query = Query()
|
||||||
|
cursor_mock = engine_mock.return_value.__enter__.return_value
|
||||||
|
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is True
|
||||||
|
|
||||||
|
|
||||||
|
@mock.patch("sqlalchemy.engine.Engine.connect")
|
||||||
|
def test_cancel_query_failed(engine_mock: mock.Mock) -> None:
|
||||||
|
from superset.db_engine_specs.trino import TrinoEngineSpec
|
||||||
|
from superset.models.sql_lab import Query
|
||||||
|
|
||||||
|
query = Query()
|
||||||
|
cursor_mock = engine_mock.raiseError.side_effect = Exception()
|
||||||
|
assert TrinoEngineSpec.cancel_query(cursor_mock, query, "123") is False
|
Loading…
Reference in New Issue