refactor: sql lab: handling command exceptions (#16852)

* chore: support error_type in SupersetException and method to convert the exception to dictionary

* refactor handling of command exceptions
* fix updating the query status when the query was not created
Author: ofekisr
Date: 2021-09-29 16:20:42 +03:00
Committed by: GitHub
Parent: 3d8cc15cba
Commit: 3f784cc1c7
6 changed files with 194 additions and 70 deletions

superset/errors.py

@@ -218,3 +218,9 @@ class SupersetError:
]
}
)
def to_dict(self) -> Dict[str, Any]:
rv = {"message": self.message, "error_type": self.error_type}
if self.extra:
rv["extra"] = self.extra # type: ignore
return rv
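
For reference, a minimal sketch of what the new SupersetError.to_dict() returns. The level field and ErrorLevel come from the existing SupersetError dataclass rather than this diff, and the example values are made up:

from superset.errors import ErrorLevel, SupersetError, SupersetErrorType

error = SupersetError(
    message="The query exceeded the 30 seconds timeout.",
    error_type=SupersetErrorType.GENERIC_BACKEND_ERROR,
    level=ErrorLevel.ERROR,
    extra={"timeout": 30},
)
# message and error_type are always present; extra is included only when truthy
error.to_dict()
# {'message': 'The query exceeded the 30 seconds timeout.',
#  'error_type': <SupersetErrorType.GENERIC_BACKEND_ERROR: 'GENERIC_BACKEND_ERROR'>,
#  'extra': {'timeout': 30}}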

superset/exceptions.py

@@ -28,17 +28,35 @@ class SupersetException(Exception):
message = ""
def __init__(
self, message: str = "", exception: Optional[Exception] = None,
self,
message: str = "",
exception: Optional[Exception] = None,
error_type: Optional[SupersetErrorType] = None,
) -> None:
if message:
self.message = message
self._exception = exception
self._error_type = error_type
super().__init__(self.message)
@property
def exception(self) -> Optional[Exception]:
return self._exception
@property
def error_type(self) -> Optional[SupersetErrorType]:
return self._error_type
def to_dict(self) -> Dict[str, Any]:
rv = {}
if hasattr(self, "message"):
rv["message"] = self.message
if self.error_type:
rv["error_type"] = self.error_type
if self.exception is not None and hasattr(self.exception, "to_dict"):
rv = {**rv, **self.exception.to_dict()} # type: ignore
return rv
class SupersetErrorException(SupersetException):
"""Exceptions with a single SupersetErrorType associated with them"""
@@ -49,6 +67,9 @@ class SupersetErrorException(SupersetException):
if status is not None:
self.status = status
def to_dict(self) -> Dict[str, Any]:
return self.error.to_dict()
class SupersetGenericErrorException(SupersetErrorException):
"""Exceptions that are too generic to have their own type"""

superset/sqllab/command.py

@@ -47,6 +47,7 @@ from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.queries.dao import QueryDAO
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.utils import core as utils
@@ -68,18 +69,18 @@ CommandResult = Dict[str, Any]
class ExecuteSqlCommand(BaseCommand):
execution_context: SqlJsonExecutionContext
log_params: Optional[Dict[str, Any]] = None
session: Session
_execution_context: SqlJsonExecutionContext
_log_params: Optional[Dict[str, Any]] = None
_session: Session
def __init__(
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> None:
self.execution_context = execution_context
self.log_params = log_params
self.session = db.session()
self._execution_context = execution_context
self._log_params = log_params
self._session = db.session()
def validate(self) -> None:
pass
@@ -88,30 +89,29 @@ class ExecuteSqlCommand(BaseCommand):
self,
) -> CommandResult:
"""Runs arbitrary sql and returns data as json"""
try:
query = self._get_existing_query()
if self.is_query_handled(query):
self._execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}
except (SqlLabException, SupersetErrorsException) as ex:
raise ex
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex
query = self._get_existing_query(self.execution_context, self.session)
if self.is_query_handled(query):
self.execution_context.set_query(query) # type: ignore
status = SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
else:
status = self._run_sql_json_exec_from_scratch()
return {
"status": status,
"payload": self._create_payload_from_execution_context(status),
}
@classmethod
def _get_existing_query(
cls, execution_context: SqlJsonExecutionContext, session: Session
) -> Optional[Query]:
def _get_existing_query(self) -> Optional[Query]:
query = (
session.query(Query)
self._session.query(Query)
.filter_by(
client_id=execution_context.client_id,
user_id=execution_context.user_id,
sql_editor_id=execution_context.sql_editor_id,
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
sql_editor_id=self._execution_context.sql_editor_id,
)
.one_or_none()
)
@@ -126,25 +126,24 @@ class ExecuteSqlCommand(BaseCommand):
]
def _run_sql_json_exec_from_scratch(self) -> SqlJsonExecutionStatus:
self.execution_context.set_database(self._get_the_query_db())
query = self.execution_context.create_query()
self._execution_context.set_database(self._get_the_query_db())
query = self._execution_context.create_query()
self._save_new_query(query)
try:
self._save_new_query(query)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query)
self.execution_context.set_query(query)
self._execution_context.set_query(query)
rendered_query = self._render_query()
self._set_query_limit_if_required(rendered_query)
return self._execute_query(rendered_query)
except Exception as ex:
query.status = QueryStatus.FAILED
self.session.commit()
self._session.commit()
raise ex
def _get_the_query_db(self) -> Database:
mydb = self.session.query(Database).get(self.execution_context.database_id)
mydb = self._session.query(Database).get(self._execution_context.database_id)
self._validate_query_db(mydb)
return mydb
@@ -160,12 +159,12 @@ class ExecuteSqlCommand(BaseCommand):
def _save_new_query(self, query: Query) -> None:
try:
self.session.add(query)
self.session.flush()
self.session.commit() # shouldn't be necessary
self._session.add(query)
self._session.flush()
self._session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex), exc_info=True)
self.session.rollback()
self._session.rollback()
if not query.id:
raise SupersetGenericErrorException(
__(
@@ -181,7 +180,7 @@ class ExecuteSqlCommand(BaseCommand):
query.set_extra_json_key("errors", [dataclasses.asdict(ex.error)])
query.status = QueryStatus.FAILED
query.error_message = ex.error.message
self.session.commit()
self._session.commit()
raise SupersetErrorException(ex.error, status=403) from ex
def _render_query(self) -> str:
@@ -205,18 +204,18 @@ class ExecuteSqlCommand(BaseCommand):
error=SupersetErrorType.MISSING_TEMPLATE_PARAMS_ERROR,
extra={
"undefined_parameters": list(undefined_parameters),
"template_parameters": self.execution_context.template_params,
"template_parameters": self._execution_context.template_params,
},
)
query = self.execution_context.query
query = self._execution_context.query
try:
template_processor = get_template_processor(
database=query.database, query=query
)
rendered_query = template_processor.process_template(
query.sql, **self.execution_context.template_params
query.sql, **self._execution_context.template_params
)
validate(rendered_query, template_processor)
except TemplateError as ex:
@@ -235,24 +234,24 @@ class ExecuteSqlCommand(BaseCommand):
def _is_required_to_set_limit(self) -> bool:
return not (
config.get("SQLLAB_CTAS_NO_LIMIT") and self.execution_context.select_as_cta
config.get("SQLLAB_CTAS_NO_LIMIT") and self._execution_context.select_as_cta
)
def _set_query_limit(self, rendered_query: str) -> None:
db_engine_spec = self.execution_context.database.db_engine_spec # type: ignore
db_engine_spec = self._execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
self.execution_context.limit,
self._execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
self._execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
self.execution_context.query.limiting_factor = LimitingFactor.QUERY
self._execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
self.execution_context.query.limiting_factor = (
self._execution_context.query.limiting_factor = (
LimitingFactor.QUERY_AND_DROPDOWN
)
self.execution_context.query.limit = min(
self._execution_context.query.limit = min(
lim for lim in limits if lim is not None
)
@@ -260,7 +259,7 @@ class ExecuteSqlCommand(BaseCommand):
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
# Async request.
if self.execution_context.is_run_asynchronous():
if self._execution_context.is_run_asynchronous():
return self._sql_json_async(rendered_query)
return self._sql_json_sync(rendered_query)
@@ -271,7 +270,7 @@ class ExecuteSqlCommand(BaseCommand):
:param rendered_query: the rendered query to perform by workers
:return: A Flask Response
"""
query = self.execution_context.query
query = self._execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
@@ -285,8 +284,8 @@ class ExecuteSqlCommand(BaseCommand):
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)
# Explicitly forget the task to ensure the task metadata is removed from the
@@ -312,14 +311,14 @@ class ExecuteSqlCommand(BaseCommand):
query.set_extra_json_key("errors", [error_payload])
query.status = QueryStatus.FAILED
query.error_message = message
self.session.commit()
self._session.commit()
raise SupersetErrorException(error) from ex
# Update saved query with execution info from the query execution
QueryDAO.update_saved_query_exec_info(query_id)
self.session.commit()
self._session.commit()
return SqlJsonExecutionStatus.QUERY_IS_RUNNING
def _sql_json_sync(self, rendered_query: str) -> SqlJsonExecutionStatus:
@@ -329,7 +328,7 @@ class ExecuteSqlCommand(BaseCommand):
:param rendered_query: The rendered query (included templates)
:raises: SupersetTimeoutException
"""
query = self.execution_context.query
query = self._execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
@@ -339,7 +338,7 @@ class ExecuteSqlCommand(BaseCommand):
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)
self.execution_context.set_execution_result(data)
self._execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
@@ -362,7 +361,7 @@ class ExecuteSqlCommand(BaseCommand):
def _get_sql_results_with_timeout(
self, timeout: int, rendered_query: str, timeout_msg: str,
) -> Optional[SqlResults]:
query = self.execution_context.query
query = self._execution_context.query
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
@@ -373,8 +372,8 @@ class ExecuteSqlCommand(BaseCommand):
user_name=g.user.username
if g.user and hasattr(g.user, "username")
else None,
expand_data=self.execution_context.expand_data,
log_params=self.log_params,
expand_data=self._execution_context.expand_data,
log_params=self._log_params,
)
@classmethod
@@ -389,9 +388,9 @@ class ExecuteSqlCommand(BaseCommand):
if status == SqlJsonExecutionStatus.HAS_RESULTS:
return self._to_payload_results_based(
self.execution_context.get_execution_result() or {}
self._execution_context.get_execution_result() or {}
)
return self._to_payload_query_based(self.execution_context.query)
return self._to_payload_query_based(self._execution_context.query)
def _to_payload_results_based( # pylint: disable=no-self-use
self, execution_result: SqlResults
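
A hedged usage sketch of the refactored command from a caller's point of view. The constructor signature, the shape of run()'s result, and the SqlLabException wrapping are taken from this diff; the execute_sql_json helper and the raw payload dict are illustrative, and in practice this runs inside a Flask request context, as in the sql_json view further below:

from typing import Any, Dict, Optional

from superset.sqllab.command import ExecuteSqlCommand
from superset.sqllab.exceptions import SqlLabException
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext


def execute_sql_json(
    payload: Dict[str, Any], log_params: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    execution_context = SqlJsonExecutionContext(payload)
    command = ExecuteSqlCommand(execution_context, log_params)
    try:
        # run() returns {"status": <SqlJsonExecutionStatus>, "payload": ...};
        # known SQL Lab / Superset errors are re-raised as-is, anything
        # unexpected is wrapped in SqlLabException by the command itself.
        return command.run()
    except SqlLabException as ex:
        return {"errors": [ex.to_dict()]}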

superset/sqllab/exceptions.py (new file)

@@ -0,0 +1,83 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import os
from typing import Optional, TYPE_CHECKING
from superset.errors import SupersetError, SupersetErrorType
from superset.exceptions import SupersetException
MSG_FORMAT = "Failed to execute {}"
if TYPE_CHECKING:
from superset.utils.sqllab_execution_context import SqlJsonExecutionContext
class SqlLabException(SupersetException):
sql_json_execution_context: SqlJsonExecutionContext
failed_reason_msg: str
suggestion_help_msg: Optional[str]
def __init__( # pylint: disable=too-many-arguments
self,
sql_json_execution_context: SqlJsonExecutionContext,
error_type: Optional[SupersetErrorType] = None,
reason_message: Optional[str] = None,
exception: Optional[Exception] = None,
suggestion_help_msg: Optional[str] = None,
) -> None:
self.sql_json_execution_context = sql_json_execution_context
self.failed_reason_msg = self._get_reason(reason_message, exception)
self.suggestion_help_msg = suggestion_help_msg
if error_type is None:
if exception is not None:
if (
hasattr(exception, "error_type")
and exception.error_type is not None # type: ignore
):
error_type = exception.error_type # type: ignore
elif hasattr(exception, "error") and isinstance(
exception.error, SupersetError # type: ignore
):
error_type = exception.error.error_type # type: ignore
else:
error_type = SupersetErrorType.GENERIC_BACKEND_ERROR
super().__init__(self._generate_message(), exception, error_type)
def _generate_message(self) -> str:
msg = MSG_FORMAT.format(self.sql_json_execution_context.get_query_details())
if self.failed_reason_msg:
msg = msg + self.failed_reason_msg
if self.suggestion_help_msg is not None:
msg = "{} {} {}".format(msg, os.linesep, self.suggestion_help_msg)
return msg
@classmethod
def _get_reason(
cls, reason_message: Optional[str] = None, exception: Optional[Exception] = None
) -> str:
if reason_message is not None:
return ": {}".format(reason_message)
if exception is not None:
if hasattr(exception, "get_message"):
return ": {}".format(exception.get_message()) # type: ignore
if hasattr(exception, "message"):
return ": {}".format(exception.message) # type: ignore
return ": {}".format(str(exception))
return ""

superset/utils/sqllab_execution_context.py

@@ -22,6 +22,7 @@ from dataclasses import dataclass
from typing import Any, cast, Dict, Optional, TYPE_CHECKING
from flask import g
from sqlalchemy.orm.exc import DetachedInstanceError
from superset import is_feature_enabled
from superset.models.sql_lab import Query
@@ -177,6 +178,15 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
client_id=self.client_id_or_short_id,
)
def get_query_details(self) -> str:
try:
if self.query:
if self.query.id:
return "query '{}' - '{}'".format(self.query.id, self.query.sql)
except DetachedInstanceError:
pass
return "query '{}'".format(self.sql)
class CreateTableAsSelect: # pylint: disable=too-few-public-methods
ctas_method: CtasMethod
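
Illustrative only: the two shapes get_query_details() can produce, mirroring the format strings above (the id and SQL values are made up):

"query '{}' - '{}'".format(42, "SELECT 1")  # a persisted Query with an id is attached
"query '{}'".format("SELECT 1")             # no query yet, or the instance is detached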

superset/views/core.py

@@ -100,6 +100,7 @@ from superset.sql_parse import ParsedQuery, Table
from superset.sql_validators import get_validator_by_name
from superset.sqllab.command import CommandResult, ExecuteSqlCommand
from superset.sqllab.command_status import SqlJsonExecutionStatus
from superset.sqllab.exceptions import SqlLabException
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sqllab.utils import apply_display_max_row_configuration_if_require
from superset.tasks.async_queries import load_explore_json_into_cache
@@ -2434,13 +2435,17 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@event_logger.log_this
@expose("/sql_json/", methods=["POST"])
def sql_json(self) -> FlaskResponse:
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
execution_context = SqlJsonExecutionContext(request.json)
command = ExecuteSqlCommand(execution_context, log_params)
command_result: CommandResult = command.run()
return self._create_response_from_execution_context(command_result)
try:
log_params = {
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
execution_context = SqlJsonExecutionContext(request.json)
command = ExecuteSqlCommand(execution_context, log_params)
command_result: CommandResult = command.run()
return self._create_response_from_execution_context(command_result)
except SqlLabException as ex:
payload = {"errors": [ex.to_dict()]}
return json_error_response(status=ex.status, payload=payload)
def _create_response_from_execution_context( # pylint: disable=invalid-name, no-self-use
self, command_result: CommandResult,
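
For completeness, a sketch of the error body the new except-branch sends back. The HTTP status comes from the status attribute SqlLabException inherits from SupersetException (not shown in this diff), and the error_type string assumes SupersetErrorType serializes to its name:

# payload passed to json_error_response(status=ex.status, payload=payload)
payload = {
    "errors": [
        {
            "message": "Failed to execute query '42' - 'SELECT 1': connection refused",
            "error_type": "GENERIC_BACKEND_ERROR",
        }
    ]
}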