fix: Use g.user for getting the user_id for async queries (#14702)

* fix: Use g.user for getting the user_id

* Use id form for one user.id call

* Fix references to g.user

* Correct types

* Use if over try/catch

* Switch back to try/except
This commit is contained in:
Ben Reinhart 2021-05-21 14:30:13 -07:00 committed by GitHub
parent d5c008dd99
commit b38596fd96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 11 deletions

View File

@ -595,7 +595,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except AsyncQueryTokenException:
return self.response_401()
result = command.run_async()
result = command.run_async(g.user.get_id())
return self.response(202, **result)
return self.get_data_response(command)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any, Dict, Optional
from flask import Request
from marshmallow import ValidationError
@ -67,8 +67,8 @@ class ChartDataCommand(BaseCommand):
return return_value
def run_async(self) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id)
def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, self._form_data)
return job_metadata

View File

@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple
import jwt
import redis
from flask import Flask, request, Request, Response, session
from flask import Flask, g, request, Request, Response, session
logger = logging.getLogger(__name__)
@ -34,11 +34,13 @@ class AsyncQueryJobException(Exception):
pass
def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]:
def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
) -> Dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
"user_id": session.get("user_id"),
"user_id": int(user_id) if user_id else None,
"status": kwargs.get("status"),
"errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"),
@ -115,7 +117,13 @@ class AsyncQueryManager:
def validate_session( # pylint: disable=unused-variable
response: Response,
) -> Response:
user_id = session["user_id"] if "user_id" in session else None
user_id = None
try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
reset_token = (
not request.cookies.get(self._jwt_cookie_name)
@ -161,9 +169,11 @@ class AsyncQueryManager:
logger.warning(exc)
raise AsyncQueryTokenException("Failed to parse token")
def init_job(self, channel_id: str) -> Dict[str, Any]:
def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING)
return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING
)
def read_events(
self, channel: str, last_id: Optional[str]

View File

@ -611,7 +611,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
async_channel_id = async_query_manager.parse_jwt_from_request(
request
)["channel"]
job_metadata = async_query_manager.init_job(async_channel_id)
job_metadata = async_query_manager.init_job(
async_channel_id, g.user.get_id()
)
load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force
)