fix: Use g.user for getting the user_id for async queries (#14702)

* fix: Use g.user for getting the user_id

* Use id form for one user.id call

* Fix references to g.user

* Correct types

* Use if over try/catch

* Switch back to try/except
This commit is contained in:
Ben Reinhart 2021-05-21 14:30:13 -07:00 committed by GitHub
parent d5c008dd99
commit b38596fd96
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 23 additions and 11 deletions

View File

@ -595,7 +595,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
except AsyncQueryTokenException: except AsyncQueryTokenException:
return self.response_401() return self.response_401()
result = command.run_async() result = command.run_async(g.user.get_id())
return self.response(202, **result) return self.response(202, **result)
return self.get_data_response(command) return self.get_data_response(command)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Any, Dict from typing import Any, Dict, Optional
from flask import Request from flask import Request
from marshmallow import ValidationError from marshmallow import ValidationError
@ -67,8 +67,8 @@ class ChartDataCommand(BaseCommand):
return return_value return return_value
def run_async(self) -> Dict[str, Any]: def run_async(self, user_id: Optional[str]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id) job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, self._form_data) load_chart_data_into_cache.delay(job_metadata, self._form_data)
return job_metadata return job_metadata

View File

@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple
import jwt import jwt
import redis import redis
from flask import Flask, request, Request, Response, session from flask import Flask, g, request, Request, Response, session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -34,11 +34,13 @@ class AsyncQueryJobException(Exception):
pass pass
def build_job_metadata(channel_id: str, job_id: str, **kwargs: Any) -> Dict[str, Any]: def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
) -> Dict[str, Any]:
return { return {
"channel_id": channel_id, "channel_id": channel_id,
"job_id": job_id, "job_id": job_id,
"user_id": session.get("user_id"), "user_id": int(user_id) if user_id else None,
"status": kwargs.get("status"), "status": kwargs.get("status"),
"errors": kwargs.get("errors", []), "errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"), "result_url": kwargs.get("result_url"),
@ -115,7 +117,13 @@ class AsyncQueryManager:
def validate_session( # pylint: disable=unused-variable def validate_session( # pylint: disable=unused-variable
response: Response, response: Response,
) -> Response: ) -> Response:
user_id = session["user_id"] if "user_id" in session else None user_id = None
try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
reset_token = ( reset_token = (
not request.cookies.get(self._jwt_cookie_name) not request.cookies.get(self._jwt_cookie_name)
@ -161,9 +169,11 @@ class AsyncQueryManager:
logger.warning(exc) logger.warning(exc)
raise AsyncQueryTokenException("Failed to parse token") raise AsyncQueryTokenException("Failed to parse token")
def init_job(self, channel_id: str) -> Dict[str, Any]: def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
return build_job_metadata(channel_id, job_id, status=self.STATUS_PENDING) return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING
)
def read_events( def read_events(
self, channel: str, last_id: Optional[str] self, channel: str, last_id: Optional[str]

View File

@ -611,7 +611,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
async_channel_id = async_query_manager.parse_jwt_from_request( async_channel_id = async_query_manager.parse_jwt_from_request(
request request
)["channel"] )["channel"]
job_metadata = async_query_manager.init_job(async_channel_id) job_metadata = async_query_manager.init_job(
async_channel_id, g.user.get_id()
)
load_explore_json_into_cache.delay( load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force job_metadata, form_data, response_type, force
) )