slightly decouple sql_json, queries, and results http endpoints from biz logic (#8626)

* slightly decouple sql_json, queries, and results http endpoints from biz logic

* fix syntax errors

* add some type annotations, fix a bug

* remove unnecessary var decl and assign

* add a lot more type annotations to fix tests

* fix mypy issues
Dave Smith 2019-11-22 10:12:48 -08:00 committed by Kim Truong
parent a72a39502f
commit aafbfd3b4e
2 changed files with 47 additions and 30 deletions
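The thread running through these changes is a split of each Flask endpoint into a thin route method and a `*_exec` method that takes plain Python values, so the business logic can be exercised without a live request context. A minimal sketch of the pattern under that reading (the class and helpers below are illustrative stand-ins, not Superset's real view code):

from typing import Any, Dict


class SketchView:
    """Toy stand-in for a Flask-AppBuilder view; decorators omitted."""

    def sql_json(self, request_json: Dict[str, Any]) -> Dict[str, Any]:
        # Thin HTTP-facing wrapper: unpack the request, delegate immediately.
        return self.sql_json_exec(request_json)

    def sql_json_exec(self, query_params: Dict[str, Any]) -> Dict[str, Any]:
        # Business logic sees only a plain dict, so a unit test can call it
        # directly, without spinning up a Flask request context.
        sql: str = query_params.get("sql") or ""
        return {"status": "pending", "sql": sql}


# The HTTP layer is now skippable in tests:
assert SketchView().sql_json_exec({"sql": "SELECT 1"})["sql"] == "SELECT 1"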


@@ -645,13 +645,13 @@ def pessimistic_connection_handling(some_engine):
 class QueryStatus:
     """Enum-type class for query statuses"""
 
-    STOPPED = "stopped"
-    FAILED = "failed"
-    PENDING = "pending"
-    RUNNING = "running"
-    SCHEDULED = "scheduled"
-    SUCCESS = "success"
-    TIMED_OUT = "timed_out"
+    STOPPED: str = "stopped"
+    FAILED: str = "failed"
+    PENDING: str = "pending"
+    RUNNING: str = "running"
+    SCHEDULED: str = "scheduled"
+    SUCCESS: str = "success"
+    TIMED_OUT: str = "timed_out"
 
 
 def notify_user_about_perm_udate(granter, user, role, datasource, tpl_name, config):
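Annotating the `QueryStatus` constants as `str` makes their type explicit to mypy, which is what surfaces the bug fixed later in this diff: `status` had been annotated `bool` while holding a `QueryStatus` string. A toy reproduction (hypothetical snippet, not the project's code) that mypy rejects:

class QueryStatus:
    PENDING: str = "pending"
    RUNNING: str = "running"


async_flag = False
# mypy reports something like:
#   error: Incompatible types in assignment
#          (expression has type "str", variable has type "bool")
status: bool = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING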


@@ -20,7 +20,7 @@ import re
 from contextlib import closing
 from datetime import datetime, timedelta
 from enum import Enum
-from typing import List, Optional, Union
+from typing import cast, List, Optional, Union
 from urllib import parse
 
 import backoff
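Worth remembering when reading the hunks below: `typing.cast` is a no-op at runtime. It returns its second argument unchanged and neither converts nor validates, so `cast(int, query_params.get("database_id"))` only tells mypy to treat the value as `int`; a missing key still yields `None`. A quick demonstration:

from typing import cast

value = cast(int, "not an int at all")
print(value, type(value))  # not an int at all <class 'str'> -- no runtime check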
@@ -2491,6 +2491,9 @@ class Superset(BaseSupersetView):
     @expose("/results/<key>/")
     @event_logger.log_this
     def results(self, key):
+        return self.results_exec(key)
+
+    def results_exec(self, key: str):
         """Serves a key off of the results backend
 
         It is possible to pass the `rows` query argument to limit the number
@@ -2527,7 +2530,9 @@
         )
 
         payload = utils.zlib_decompress(blob, decode=not results_backend_use_msgpack)
-        obj = _deserialize_results_payload(payload, query, results_backend_use_msgpack)
+        obj: dict = _deserialize_results_payload(
+            payload, query, cast(bool, results_backend_use_msgpack)
+        )
 
         if "rows" in request.args:
             try:
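Two things to note here: `results_exec` still reads `request.args` for the `rows` limit, so the decoupling of this endpoint is only partial, and the `results_backend_use_msgpack` flag drives both the decompression (`decode=` plausibly controls whether `zlib_decompress` returns `str` or `bytes`) and the choice of deserializer. A rough sketch of that shape, with a hypothetical helper standing in for `_deserialize_results_payload`:

import json
import zlib
from typing import Any, Union


def deserialize_payload(payload: Union[bytes, str], use_msgpack: bool) -> Any:
    # Hypothetical stand-in: msgpack payloads stay bytes, JSON payloads
    # arrive as an already-decoded str.
    if use_msgpack:
        import msgpack  # optional dependency, as with Superset's backend

        return msgpack.loads(payload, raw=False)
    return json.loads(payload)


blob = zlib.compress(json.dumps({"status": "success"}).encode())
payload = zlib.decompress(blob).decode()  # the decode=True path: bytes -> str
print(deserialize_payload(payload, use_msgpack=False))  # {'status': 'success'}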
@@ -2722,34 +2727,39 @@ class Superset(BaseSupersetView):
     @expose("/sql_json/", methods=["POST"])
     @event_logger.log_this
     def sql_json(self):
+        return self.sql_json_exec(request.json)
+
+    def sql_json_exec(self, query_params: dict):
         """Runs arbitrary sql and returns data as json"""
         # Collect Values
-        database_id: int = request.json.get("database_id")
-        schema: str = request.json.get("schema")
-        sql: str = request.json.get("sql")
+        database_id: int = cast(int, query_params.get("database_id"))
+        schema: str = cast(str, query_params.get("schema"))
+        sql: str = cast(str, query_params.get("sql"))
         try:
             template_params: dict = json.loads(
-                request.json.get("templateParams") or "{}"
+                query_params.get("templateParams") or "{}"
             )
-        except json.decoder.JSONDecodeError:
+        except json.JSONDecodeError:
             logging.warning(
-                f"Invalid template parameter {request.json.get('templateParams')}"
+                f"Invalid template parameter {query_params.get('templateParams')}"
                 " specified. Defaulting to empty dict"
             )
             template_params = {}
-        limit = request.json.get("queryLimit") or app.config["SQL_MAX_ROW"]
-        async_flag: bool = request.json.get("runAsync")
+        limit: int = query_params.get("queryLimit") or app.config["SQL_MAX_ROW"]
+        async_flag: bool = cast(bool, query_params.get("runAsync"))
         if limit < 0:
             logging.warning(
                 f"Invalid limit of {limit} specified. Defaulting to max limit."
             )
             limit = 0
-        select_as_cta: bool = request.json.get("select_as_cta")
-        tmp_table_name: str = request.json.get("tmp_table_name")
-        client_id: str = request.json.get("client_id") or utils.shortid()[:10]
-        sql_editor_id: str = request.json.get("sql_editor_id")
-        tab_name: str = request.json.get("tab")
-        status: bool = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
+        select_as_cta: bool = cast(bool, query_params.get("select_as_cta"))
+        tmp_table_name: str = cast(str, query_params.get("tmp_table_name"))
+        client_id: str = cast(
+            str, query_params.get("client_id") or utils.shortid()[:10]
+        )
+        sql_editor_id: str = cast(str, query_params.get("sql_editor_id"))
+        tab_name: str = cast(str, query_params.get("tab"))
+        status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
 
         session = db.session()
         mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
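The `or`-fallbacks for `limit` and `client_id` treat any falsy value as absent, which is why `limit: int` is a sound annotation even though `.get()` alone would return `None` for a missing key. One corner case follows from the idiom itself: a client-supplied `queryLimit` of `0` also triggers the fallback. A small sketch (with an assumed stand-in config value):

SQL_MAX_ROW = 100_000  # assumed illustrative value for app.config["SQL_MAX_ROW"]


def resolve_limit(query_params: dict) -> int:
    # Mirrors the idiom above: None, 0, and "" all fall back.
    limit: int = query_params.get("queryLimit") or SQL_MAX_ROW
    if limit < 0:
        limit = 0  # negative limits are clamped, as in the handler
    return limit


print(resolve_limit({}))                  # 100000 (missing -> fallback)
print(resolve_limit({"queryLimit": 0}))   # 100000 (0 is falsy -> fallback too)
print(resolve_limit({"queryLimit": -5}))  # 0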
@@ -2817,9 +2827,11 @@ class Superset(BaseSupersetView):
 
         # Flag for whether or not to expand data
         # (feature that will expand Presto row objects and arrays)
-        expand_data: bool = is_feature_enabled(
-            "PRESTO_EXPAND_DATA"
-        ) and request.json.get("expand_data")
+        expand_data: bool = cast(
+            bool,
+            is_feature_enabled("PRESTO_EXPAND_DATA")
+            and query_params.get("expand_data"),
+        )
 
         # Async request.
         if async_flag:
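The `cast(bool, ...)` wrapper is needed here because Python's `and` returns one of its operands, not a `bool`: with the feature flag on, the expression evaluates to whatever `query_params.get("expand_data")` returned, possibly `None`. The cast satisfies mypy but performs no conversion, as a short check shows:

def is_feature_enabled(name: str) -> bool:
    return True  # pretend the flag is on, for illustration


query_params: dict = {}  # "expand_data" key absent
expand_data = is_feature_enabled("PRESTO_EXPAND_DATA") and query_params.get("expand_data")
print(expand_data)  # None, not False: `and` returned the right-hand operand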
@@ -2904,16 +2916,21 @@ class Superset(BaseSupersetView):
     @has_access_api
     @expose("/queries/<last_updated_ms>")
     def queries(self, last_updated_ms):
-        """Get the updated queries."""
+        """
+        Get the updated queries.
+
+        :param last_updated_ms: unix time, milliseconds
+        """
+        last_updated_ms_int = int(float(last_updated_ms)) if last_updated_ms else 0
+        return self.queries_exec(last_updated_ms_int)
+
+    def queries_exec(self, last_updated_ms_int: int):
         stats_logger.incr("queries")
         if not g.user.get_id():
             return json_error_response(
                 "Please login to access the queries.", status=403
             )
 
-        # Unix time, milliseconds.
-        last_updated_ms_int = int(float(last_updated_ms)) if last_updated_ms else 0
-
         # UTC date time, same that is stored in the DB.
         last_updated_dt = utils.EPOCH + timedelta(seconds=last_updated_ms_int / 1000)
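For reference, the millisecond-epoch conversion on the last line is the naive-UTC recipe: assuming `utils.EPOCH` is `datetime(1970, 1, 1)`, adding a `timedelta` of `ms / 1000` seconds agrees with `datetime.utcfromtimestamp`:

from datetime import datetime, timedelta

EPOCH = datetime(1970, 1, 1)  # assumed to match superset.utils.EPOCH

last_updated_ms_int = 1_574_446_368_000  # 2019-11-22 18:12:48 UTC, in ms
last_updated_dt = EPOCH + timedelta(seconds=last_updated_ms_int / 1000)
assert last_updated_dt == datetime.utcfromtimestamp(last_updated_ms_int / 1000)
print(last_updated_dt)  # 2019-11-22 18:12:48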