refactor: Cleanup user get_id/get_user_id (#20492)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2022-06-24 17:57:04 -07:00 committed by GitHub
parent c56e37cda2
commit 3483446c28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 182 additions and 137 deletions

View File

@ -805,7 +805,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
charts = ChartDAO.find_by_ids(requested_ids)
if not charts:
return self.response_404()
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id())
favorited_chart_ids = ChartDAO.favorited_ids(charts)
res = [
{"id": request_id, "value": request_id in favorited_chart_ids}
for request_id in requested_ids

View File

@ -25,6 +25,7 @@ from superset.dao.base import BaseDAO
from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import Slice
from superset.utils.core import get_user_id
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -70,7 +71,7 @@ class ChartDAO(BaseDAO):
db.session.commit()
@staticmethod
def favorited_ids(charts: List[Slice], current_user_id: int) -> List[FavStar]:
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
ids = [chart.id for chart in charts]
return [
star.obj_id
@ -78,7 +79,7 @@ class ChartDAO(BaseDAO):
.filter(
FavStar.class_name == FavStarClassName.CHART,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]

View File

@ -21,7 +21,7 @@ import logging
from typing import Any, Dict, Optional, TYPE_CHECKING
import simplejson
from flask import current_app, g, make_response, request, Response
from flask import current_app, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
@ -44,7 +44,7 @@ from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import create_zip, json_int_dttm_ser
from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics
@ -324,7 +324,7 @@ class ChartDataRestApi(ChartRestApi):
except AsyncQueryTokenException:
return self.response_401()
result = async_command.run(form_data, g.user.get_id())
result = async_command.run(form_data, get_user_id())
return self.response(202, **result)
def _send_chart_response(

View File

@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata

View File

@ -942,9 +942,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards:
return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids(
dashboards, g.user.get_id()
)
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids}
for request_id in requested_ids

View File

@ -29,6 +29,7 @@ from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.utils.core import get_user_id
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
logger = logging.getLogger(__name__)
@ -274,9 +275,7 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
def favorited_ids(
dashboards: List[Dashboard], current_user_id: int
) -> List[FavStar]:
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
@ -284,7 +283,7 @@ class DashboardDAO(BaseDAO):
.filter(
FavStar.class_name == FavStarClassName.DASHBOARD,
FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id,
FavStar.user_id == get_user_id(),
)
.all()
]

View File

@ -29,6 +29,7 @@ from superset.models.dashboard import Dashboard
from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter, is_user_admin
from superset.views.base_api import BaseFavoriteFilter
@ -57,9 +58,9 @@ class DashboardCreatedByMeFilter(BaseFilter): # pylint: disable=too-few-public-
return query.filter(
or_(
Dashboard.created_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
== get_user_id(),
Dashboard.changed_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(),
== get_user_id(),
)
)
@ -126,17 +127,14 @@ class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-metho
users_favorite_dash_query = db.session.query(FavStar.obj_id).filter(
and_(
FavStar.user_id == security_manager.user_model.get_user_id(),
FavStar.user_id == get_user_id(),
FavStar.class_name == "Dashboard",
)
)
owner_ids_query = (
db.session.query(Dashboard.id)
.join(Dashboard.owners)
.filter(
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
.filter(security_manager.user_model.id == get_user_id())
)
feature_flagged_filters = []

View File

@ -41,7 +41,11 @@ from typing_extensions import TypedDict
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
from superset.utils.core import (
convert_legacy_filters_into_adhoc,
get_user_id,
merge_extra_filters,
)
from superset.utils.memoized import memoized
if TYPE_CHECKING:
@ -115,9 +119,10 @@ class ExtraCache:
"""
if hasattr(g, "user") and g.user:
id_ = get_user_id()
if add_to_cache_keys:
self.cache_key_wrapper(g.user.get_id())
return g.user.get_id()
self.cache_key_wrapper(id_)
return id_
return None
def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:

View File

@ -67,4 +67,4 @@ def get_uuid_namespace(seed: str) -> UUID:
def get_owner(user: User) -> Optional[int]:
return user.get_user_id() if not user.is_anonymous else None
return user.id if not user.is_anonymous else None

View File

@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import relationship
from superset import db
from superset.utils.core import get_user_id
Base = declarative_base()
@ -63,17 +64,10 @@ dashboard_user = Table(
class AuditMixin:
@classmethod
def get_user_id(cls):
try:
return g.user.id
except Exception:
return None
@declared_attr
def created_by_fk(cls):
return Column(
Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False
Integer, ForeignKey("ab_user.id"), default=get_user_id, nullable=False
)
@declared_attr

View File

@ -34,6 +34,7 @@ from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from superset.models.tags import ObjectTypes, TagTypes
from superset.utils.core import get_user_id
Base = declarative_base()
@ -54,7 +55,7 @@ class AuditMixinNullable(AuditMixin):
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)
@ -63,8 +64,8 @@ class AuditMixinNullable(AuditMixin):
return Column(
Integer,
ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

View File

@ -40,6 +40,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy_utils import UUIDType
from superset.common.db_query_status import QueryStatus
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@ -384,7 +385,7 @@ class AuditMixinNullable(AuditMixin):
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
default=get_user_id,
nullable=True,
)
@ -393,8 +394,8 @@ class AuditMixinNullable(AuditMixin):
return sa.Column(
sa.Integer,
sa.ForeignKey("ab_user.id"),
default=self.get_user_id,
onupdate=self.get_user_id,
default=get_user_id,
onupdate=get_user_id,
nullable=True,
)

View File

@ -16,11 +16,11 @@
# under the License.
from typing import Any
from flask import g
from flask_sqlalchemy import BaseQuery
from superset import security_manager
from superset.models.sql_lab import Query
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter
@ -33,5 +33,5 @@ class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
:returns: query
"""
if not security_manager.can_access_all_queries():
query = query.filter(Query.user_id == g.user.get_user_id())
query = query.filter(Query.user_id == get_user_id())
return query

View File

@ -75,7 +75,7 @@ from superset.security.guest_token import (
GuestTokenUser,
GuestUser,
)
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType
from superset.utils.urls import get_url_host
if TYPE_CHECKING:
@ -529,7 +529,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
view_menu_names = (
base_query.join(assoc_user_role)
.join(self.user_model)
.filter(self.user_model.id == g.user.get_id())
.filter(self.user_model.id == get_user_id())
.filter(self.permission_model.name == permission_name)
).all()
return {s.name for s in view_menu_names}
@ -1252,10 +1252,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
@staticmethod
def raise_for_user_activity_access(user_id: int) -> None:
user = g.user if g.user and g.user.get_id() else None
if not user or (
if not get_user_id() or (
not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"]
and user_id != user.id
and user_id != get_user_id()
):
raise SupersetSecurityException(
SupersetError(

View File

@ -28,7 +28,7 @@ from superset import is_feature_enabled
from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod
from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit
from superset.utils.core import apply_max_row_limit, get_user_id
from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name
@ -64,7 +64,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
self.user_id = self._get_user_id()
self.user_id = get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
def set_query(self, query: Query) -> None:
@ -111,12 +111,6 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
limit = 0
return limit
def _get_user_id(self) -> Optional[int]: # pylint: disable=no-self-use
try:
return g.user.get_id() if g.user else None
except RuntimeError:
return None
def is_run_asynchronous(self) -> bool:
return self.async_flag

View File

@ -21,7 +21,9 @@ from typing import Any, Dict, List, Optional, Tuple
import jwt
import redis
from flask import Flask, g, request, Request, Response, session
from flask import Flask, request, Request, Response, session
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@ -35,12 +37,12 @@ class AsyncQueryJobException(Exception):
def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
) -> Dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
"user_id": int(user_id) if user_id else None,
"user_id": user_id,
"status": kwargs.get("status"),
"errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"),
@ -113,13 +115,7 @@ class AsyncQueryManager:
@app.after_request
def validate_session(response: Response) -> Response:
user_id = None
try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
user_id = get_user_id()
reset_token = (
not request.cookies.get(self._jwt_cookie_name)
@ -161,7 +157,7 @@ class AsyncQueryManager:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex
def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING

View File

@ -1422,13 +1422,36 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
def get_username() -> Optional[str]:
"""Get username if within the flask context, otherwise return noffin'"""
"""
Get username (if defined) associated with the current user.
:returns: The username
"""
try:
return g.user.username
except Exception: # pylint: disable=broad-except
return None
def get_user_id() -> Optional[int]:
"""
Get the user identifier (if defined) associated with the current user.
Though the Flask-AppBuilder `User` and Flask-Login `AnonymousUserMixin` and
`UserMixin` models provide a convenience `get_id` method, for generality, the
identifier is encoded as a `str` whereas in Superset all identifiers are encoded as
an `int`.
returns: The user identifier
"""
try:
return g.user.id
except Exception: # pylint: disable=broad-except
return None
@contextmanager
def override_user(user: Optional[User]) -> Iterator[Any]:
"""

View File

@ -41,6 +41,8 @@ from flask_appbuilder.const import API_URI_RIS_KEY
from sqlalchemy.exc import SQLAlchemyError
from typing_extensions import Literal
from superset.utils.core import get_user_id
if TYPE_CHECKING:
from superset.stats_logger import BaseStatsLogger
@ -133,10 +135,7 @@ class AbstractEventLogger(ABC):
duration_ms = int(duration.total_seconds() * 1000) if duration else None
# Initial try and grab user_id via flask.g.user
try:
user_id = g.user.get_id()
except Exception: # pylint: disable=broad-except
user_id = None
user_id = get_user_id()
# Whenever a user is not bounded to a session we
# need to add them back before logging to capture user_id
@ -144,7 +143,7 @@ class AbstractEventLogger(ABC):
try:
session = current_app.appbuilder.get_session
session.add(g.user)
user_id = g.user.get_id()
user_id = get_user_id()
except Exception as ex: # pylint: disable=broad-except
logging.warning(ex)
user_id = None

View File

@ -76,6 +76,7 @@ from superset.models.reports import ReportRecipientType
from superset.superset_typing import FlaskResponse
from superset.translations.utils import get_language_pack
from superset.utils import core as utils
from superset.utils.core import get_user_id
from .utils import bootstrap_user_data
@ -623,10 +624,7 @@ class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
owner_ids_query = (
db.session.query(models.SqlaTable.id)
.join(models.SqlaTable.owners)
.filter(
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
.filter(security_manager.user_model.id == get_user_id())
)
return query.filter(
or_(

View File

@ -18,7 +18,7 @@ import functools
import logging
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from flask import Blueprint, g, request, Response
from flask import Blueprint, request, Response
from flask_appbuilder import AppBuilder, Model, ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters
@ -38,7 +38,7 @@ from superset.schemas import error_payload_content
from superset.sql_lab import Query as SqllabQuery
from superset.stats_logger import BaseStatsLogger
from superset.superset_typing import FlaskResponse
from superset.utils.core import time_function
from superset.utils.core import get_user_id, time_function
from superset.views.base import handle_api_exception
logger = logging.getLogger(__name__)
@ -145,7 +145,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods
return query
users_favorite_query = db.session.query(FavStar.obj_id).filter(
and_(
FavStar.user_id == g.user.get_id(),
FavStar.user_id == get_user_id(),
FavStar.class_name == self.class_name,
)
)

View File

@ -21,6 +21,8 @@ from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound
from superset.utils.core import get_user_id
def validate_owner(value: int) -> None:
try:
@ -113,8 +115,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
@staticmethod
def set_owners(instance: Model, owners: List[int]) -> None:
owner_objs = []
if g.user.get_id() not in owners:
owners.append(g.user.get_id())
user_id = get_user_id()
if user_id and user_id not in owners:
owners.append(user_id)
for owner_id in owners:
user = current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model

View File

@ -132,6 +132,7 @@ from superset.utils.cache import etag_cache
from superset.utils.core import (
apply_max_row_limit,
DatasourceType,
get_user_id,
ReservedUrlParameters,
)
from superset.utils.dates import now_as_float
@ -673,7 +674,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
request
)["channel"]
job_metadata = async_query_manager.init_job(
async_channel_id, g.user.get_id()
async_channel_id, get_user_id()
)
load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force
@ -1885,13 +1886,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, class_name: str, obj_id: int, action: str
) -> FlaskResponse:
"""Toggle favorite stars on Slices and Dashboard"""
if not g.user.get_id():
if not get_user_id():
return json_error_response("ERROR: Favstar toggling denied", status=403)
session = db.session()
count = 0
favs = (
session.query(FavStar)
.filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
.filter_by(class_name=class_name, obj_id=obj_id, user_id=get_user_id())
.all()
)
if action == "select":
@ -1900,7 +1901,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
FavStar(
class_name=class_name,
obj_id=obj_id,
user_id=g.user.get_id(),
user_id=get_user_id(),
dttm=datetime.now(),
)
)
@ -2582,7 +2583,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod
def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse:
stats_logger.incr("queries")
if not g.user.get_id():
if not get_user_id():
return json_error_response(
"Please login to access the queries.", status=403
)
@ -2592,9 +2593,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
sql_queries = (
db.session.query(Query)
.filter(
Query.user_id == g.user.get_id(), Query.changed_on >= last_updated_dt
)
.filter(Query.user_id == get_user_id(), Query.changed_on >= last_updated_dt)
.all()
)
dict_queries = {q.client_id: q.to_dict() for q in sql_queries}
@ -2620,10 +2619,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
search_user_id = int(cast(int, request.args.get("user_id")))
except ValueError:
return Response(status=400, mimetype="application/json")
if search_user_id != g.user.get_user_id():
if search_user_id != get_user_id():
return Response(status=403, mimetype="application/json")
else:
search_user_id = g.user.get_user_id()
search_user_id = get_user_id()
database_id = request.args.get("database_id")
search_text = request.args.get("search_text")
status = request.args.get("status")
@ -2676,14 +2675,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/welcome/")
def welcome(self) -> FlaskResponse:
"""Personalized welcome page"""
if not g.user or not g.user.get_id():
if not get_user_id():
if conf["PUBLIC_ROLE_LIKE"]:
return self.render_template("superset/public_welcome.html")
return redirect(appbuilder.get_url_for_login)
welcome_dashboard_id = (
db.session.query(UserAttribute.welcome_dashboard_id)
.filter_by(user_id=g.user.get_id())
.filter_by(user_id=get_user_id())
.scalar()
)
if welcome_dashboard_id:
@ -2728,7 +2727,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
@staticmethod
def _get_sqllab_tabs(user_id: int) -> Dict[str, Any]:
def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]:
# send list of tab state ids
tabs_state = (
db.session.query(TabState.id, TabState.label)
@ -2780,7 +2779,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
payload = {
"defaultDbId": config["SQLLAB_DEFAULT_DBID"],
"common": common_bootstrap_payload(),
**self._get_sqllab_tabs(g.user.get_id()),
**self._get_sqllab_tabs(get_user_id()),
}
form_data = request.form.get("form_data")

View File

@ -29,6 +29,7 @@ from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.superset_typing import FlaskResponse
from superset.utils import core as utils
from superset.utils.core import get_user_id
from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
@ -136,7 +137,7 @@ class TabStateView(BaseSupersetView):
def post(self) -> FlaskResponse: # pylint: disable=no-self-use
query_editor = json.loads(request.form["queryEditor"])
tab_state = TabState(
user_id=g.user.get_id(),
user_id=get_user_id(),
label=query_editor.get("title", "Untitled Query"),
active=True,
database_id=query_editor["dbId"],
@ -147,7 +148,7 @@ class TabStateView(BaseSupersetView):
)
(
db.session.query(TabState)
.filter_by(user_id=g.user.get_id())
.filter_by(user_id=get_user_id())
.update({"active": False})
)
db.session.add(tab_state)
@ -157,7 +158,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["DELETE"])
def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403)
db.session.query(TabState).filter(TabState.id == tab_state_id).delete(
@ -172,7 +173,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("/<int:tab_state_id>", methods=["GET"])
def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403)
tab_state = db.session.query(TabState).filter_by(id=tab_state_id).first()
@ -190,12 +191,12 @@ class TabStateView(BaseSupersetView):
owner_id = _get_owner_id(tab_state_id)
if owner_id is None:
return Response(status=404)
if owner_id != int(g.user.get_id()):
if owner_id != get_user_id():
return Response(status=403)
(
db.session.query(TabState)
.filter_by(user_id=g.user.get_id())
.filter_by(user_id=get_user_id())
.update({"active": TabState.id == tab_state_id})
)
db.session.commit()
@ -204,7 +205,7 @@ class TabStateView(BaseSupersetView):
@has_access_api
@expose("<int:tab_state_id>", methods=["PUT"])
def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403)
fields = {k: json.loads(v) for k, v in request.form.to_dict().items()}
@ -217,7 +218,7 @@ class TabStateView(BaseSupersetView):
def migrate_query( # pylint: disable=no-self-use
self, tab_state_id: int
) -> FlaskResponse:
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403)
client_id = json.loads(request.form["queryId"])
@ -244,7 +245,7 @@ class TabStateView(BaseSupersetView):
.filter(
and_(
Query.client_id != client_id,
Query.user_id == g.user.get_id(),
Query.user_id == get_user_id(),
Query.sql_editor_id == str(tab_state_id),
),
)
@ -257,7 +258,7 @@ class TabStateView(BaseSupersetView):
db.session.query(Query).filter_by(
client_id=client_id,
user_id=g.user.get_id(),
user_id=get_user_id(),
sql_editor_id=str(tab_state_id),
).delete(synchronize_session=False)
db.session.commit()
@ -327,4 +328,4 @@ class SqlLab(BaseSupersetView):
logger.warning(
"This endpoint is deprecated and will be removed in the next major release"
)
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.get_id()))
return redirect(f"/savedqueryview/list/?_flt_0_user={get_user_id()}")

View File

@ -18,11 +18,12 @@
"""Unit tests for Superset"""
import json
import unittest
from typing import Optional
from unittest import mock
import pytest
from flask import g
from flask.ctx import AppContext
from pytest_mock import MockFixture
from sqlalchemy import inspect
from tests.integration_tests.fixtures.birth_names_dashboard import (
@ -42,7 +43,7 @@ from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_username, override_user
from superset.utils.core import get_user_id, get_username, override_user
from superset.utils.database import get_example_database
from .base_tests import SupersetTestCase
@ -524,18 +525,40 @@ class TestRequestAccess(SupersetTestCase):
session.commit()
@pytest.mark.parametrize(
"username,user_id",
[
(None, None),
("alpha", 5),
("gamma", 2),
],
)
def test_get_user_id(
app_context: AppContext,
mocker: MockFixture,
username: Optional[str],
user_id: Optional[int],
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
mock_g.user = security_manager.find_user(username)
assert get_user_id() == user_id
@pytest.mark.parametrize(
"username",
[
None,
"alpha",
"gamma",
],
)
def test_get_username(app_context: AppContext, username: str) -> None:
assert not hasattr(g, "user")
assert get_username() is None
g.user = security_manager.find_user(username)
def test_get_username(
app_context: AppContext,
mocker: MockFixture,
username: Optional[str],
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
mock_g.user = security_manager.find_user(username)
assert get_username() == username
@ -543,26 +566,32 @@ def test_get_username(app_context: AppContext, username: str) -> None:
"username",
[
None,
"alpha",
"gamma",
],
)
def test_override_user(app_context: AppContext, username: str) -> None:
def test_override_user(
app_context: AppContext,
mocker: MockFixture,
username: str,
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username)
assert not hasattr(g, "user")
assert not hasattr(mock_g, "user")
with override_user(user):
assert g.user == user
assert mock_g.user == user
assert not hasattr(g, "user")
assert not hasattr(mock_g, "user")
g.user = admin
mock_g.user = admin
with override_user(user):
assert g.user == user
assert mock_g.user == user
assert g.user == admin
assert mock_g.user == admin
if __name__ == "__main__":

View File

@ -112,7 +112,7 @@ class TestEventLogger(unittest.TestCase):
)
self.assertGreaterEqual(payload["duration_ms"], 100)
@patch("superset.utils.log.g", spec={})
@patch("superset.utils.core.g", spec={})
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
def test_context_manager_log(self, mock_g):
class DummyEventLogger(AbstractEventLogger):
@ -144,12 +144,12 @@ class TestEventLogger(unittest.TestCase):
assert logger.records == [
{
"records": [{"path": "/", "engine": "bar"}],
"user_id": "2",
"user_id": 2,
"duration": 15000.0,
}
]
@patch("superset.utils.log.g", spec={})
@patch("superset.utils.core.g", spec={})
def test_context_manager_log_with_context(self, mock_g):
class DummyEventLogger(AbstractEventLogger):
def __init__(self):
@ -191,12 +191,12 @@ class TestEventLogger(unittest.TestCase):
"payload_override": {"engine": "sqllite"},
}
],
"user_id": "2",
"user_id": 2,
"duration": 5558756000,
}
]
@patch("superset.utils.log.g", spec={})
@patch("superset.utils.core.g", spec={})
def test_log_with_context_user_null(self, mock_g):
class DummyEventLogger(AbstractEventLogger):
def __init__(self):

View File

@ -412,8 +412,9 @@ class TestRolePermission(SupersetTestCase):
# TODO test slice permission
@patch("superset.security.manager.g")
def test_schemas_accessible_by_user_admin(self, mock_g):
mock_g.user = security_manager.find_user("admin")
@patch("superset.utils.core.g")
def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g):
mock_g.user = mock_sm_g.user = security_manager.find_user("admin")
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
@ -422,10 +423,11 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(schemas, ["1", "2", "3"]) # no changes
@patch("superset.security.manager.g")
def test_schemas_accessible_by_user_schema_access(self, mock_g):
@patch("superset.utils.core.g")
def test_schemas_accessible_by_user_schema_access(self, mock_sm_g, mock_g):
# User has schema access to the schema 1
create_schema_perm("[examples].[1]")
mock_g.user = security_manager.find_user("gamma")
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
@ -436,9 +438,10 @@ class TestRolePermission(SupersetTestCase):
delete_schema_perm("[examples].[1]")
@patch("superset.security.manager.g")
def test_schemas_accessible_by_user_datasource_access(self, mock_g):
@patch("superset.utils.core.g")
def test_schemas_accessible_by_user_datasource_access(self, mock_sm_g, mock_g):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
mock_g.user = security_manager.find_user("gamma")
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(
@ -447,10 +450,13 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(schemas, ["temp_schema"])
@patch("superset.security.manager.g")
def test_schemas_accessible_by_user_datasource_and_schema_access(self, mock_g):
@patch("superset.utils.core.g")
def test_schemas_accessible_by_user_datasource_and_schema_access(
self, mock_sm_g, mock_g
):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
create_schema_perm("[examples].[2]")
mock_g.user = security_manager.find_user("gamma")
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context():
database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user(

View File

@ -32,6 +32,7 @@ from superset.tasks.async_queries import (
load_chart_data_into_cache,
load_explore_json_into_cache,
)
from superset.utils.core import get_user_id
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -218,7 +219,7 @@ class TestAsyncQueries(SupersetTestCase):
ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual("1", g.user.get_id())
self.assertEqual(1, get_user_id())
del g.user
@ -226,22 +227,22 @@ class TestAsyncQueries(SupersetTestCase):
ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user"))
self.assertTrue(g.user.is_anonymous)
self.assertEqual(None, g.user.get_id())
self.assertEqual(None, get_user_id())
del g.user
g.user = security_manager.get_user_by_id(2)
self.assertEqual("2", g.user.get_id())
self.assertEqual(2, get_user_id())
ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual("2", g.user.get_id())
self.assertEqual(2, get_user_id())
ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous)
self.assertEqual("2", g.user.get_id())
self.assertEqual(2, get_user_id())
if g_user_is_set:
g.user = original_g_user