refactor: Cleanup user get_id/get_user_id (#20492)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2022-06-24 17:57:04 -07:00 committed by GitHub
parent c56e37cda2
commit 3483446c28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
27 changed files with 182 additions and 137 deletions

View File

@ -805,7 +805,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
charts = ChartDAO.find_by_ids(requested_ids) charts = ChartDAO.find_by_ids(requested_ids)
if not charts: if not charts:
return self.response_404() return self.response_404()
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id()) favorited_chart_ids = ChartDAO.favorited_ids(charts)
res = [ res = [
{"id": request_id, "value": request_id in favorited_chart_ids} {"id": request_id, "value": request_id in favorited_chart_ids}
for request_id in requested_ids for request_id in requested_ids

View File

@ -25,6 +25,7 @@ from superset.dao.base import BaseDAO
from superset.extensions import db from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName from superset.models.core import FavStar, FavStarClassName
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import get_user_id
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource from superset.connectors.base.models import BaseDatasource
@ -70,7 +71,7 @@ class ChartDAO(BaseDAO):
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def favorited_ids(charts: List[Slice], current_user_id: int) -> List[FavStar]: def favorited_ids(charts: List[Slice]) -> List[FavStar]:
ids = [chart.id for chart in charts] ids = [chart.id for chart in charts]
return [ return [
star.obj_id star.obj_id
@ -78,7 +79,7 @@ class ChartDAO(BaseDAO):
.filter( .filter(
FavStar.class_name == FavStarClassName.CHART, FavStar.class_name == FavStarClassName.CHART,
FavStar.obj_id.in_(ids), FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id, FavStar.user_id == get_user_id(),
) )
.all() .all()
] ]

View File

@ -21,7 +21,7 @@ import logging
from typing import Any, Dict, Optional, TYPE_CHECKING from typing import Any, Dict, Optional, TYPE_CHECKING
import simplejson import simplejson
from flask import current_app, g, make_response, request, Response from flask import current_app, make_response, request, Response
from flask_appbuilder.api import expose, protect from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _ from flask_babel import gettext as _
from marshmallow import ValidationError from marshmallow import ValidationError
@ -44,7 +44,7 @@ from superset.connectors.base.models import BaseDatasource
from superset.exceptions import QueryObjectValidationError from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger from superset.extensions import event_logger
from superset.utils.async_query_manager import AsyncQueryTokenException from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.core import create_zip, json_int_dttm_ser from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
from superset.views.base import CsvResponse, generate_download_headers from superset.views.base import CsvResponse, generate_download_headers
from superset.views.base_api import statsd_metrics from superset.views.base_api import statsd_metrics
@ -324,7 +324,7 @@ class ChartDataRestApi(ChartRestApi):
except AsyncQueryTokenException: except AsyncQueryTokenException:
return self.response_401() return self.response_401()
result = async_command.run(form_data, g.user.get_id()) result = async_command.run(form_data, get_user_id())
return self.response(202, **result) return self.response(202, **result)
def _send_chart_response( def _send_chart_response(

View File

@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
jwt_data = async_query_manager.parse_jwt_from_request(request) jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"] self._async_channel_id = jwt_data["channel"]
def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]: def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data) load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata return job_metadata

View File

@ -942,9 +942,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
dashboards = DashboardDAO.find_by_ids(requested_ids) dashboards = DashboardDAO.find_by_ids(requested_ids)
if not dashboards: if not dashboards:
return self.response_404() return self.response_404()
favorited_dashboard_ids = DashboardDAO.favorited_ids( favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
dashboards, g.user.get_id()
)
res = [ res = [
{"id": request_id, "value": request_id in favorited_dashboard_ids} {"id": request_id, "value": request_id in favorited_dashboard_ids}
for request_id in requested_ids for request_id in requested_ids

View File

@ -29,6 +29,7 @@ from superset.extensions import db
from superset.models.core import FavStar, FavStarClassName from superset.models.core import FavStar, FavStarClassName
from superset.models.dashboard import Dashboard from superset.models.dashboard import Dashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.utils.core import get_user_id
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -274,9 +275,7 @@ class DashboardDAO(BaseDAO):
return dashboard return dashboard
@staticmethod @staticmethod
def favorited_ids( def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
dashboards: List[Dashboard], current_user_id: int
) -> List[FavStar]:
ids = [dash.id for dash in dashboards] ids = [dash.id for dash in dashboards]
return [ return [
star.obj_id star.obj_id
@ -284,7 +283,7 @@ class DashboardDAO(BaseDAO):
.filter( .filter(
FavStar.class_name == FavStarClassName.DASHBOARD, FavStar.class_name == FavStarClassName.DASHBOARD,
FavStar.obj_id.in_(ids), FavStar.obj_id.in_(ids),
FavStar.user_id == current_user_id, FavStar.user_id == get_user_id(),
) )
.all() .all()
] ]

View File

@ -29,6 +29,7 @@ from superset.models.dashboard import Dashboard
from superset.models.embedded_dashboard import EmbeddedDashboard from superset.models.embedded_dashboard import EmbeddedDashboard
from superset.models.slice import Slice from superset.models.slice import Slice
from superset.security.guest_token import GuestTokenResourceType, GuestUser from superset.security.guest_token import GuestTokenResourceType, GuestUser
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter, is_user_admin from superset.views.base import BaseFilter, is_user_admin
from superset.views.base_api import BaseFavoriteFilter from superset.views.base_api import BaseFavoriteFilter
@ -57,9 +58,9 @@ class DashboardCreatedByMeFilter(BaseFilter): # pylint: disable=too-few-public-
return query.filter( return query.filter(
or_( or_(
Dashboard.created_by_fk # pylint: disable=comparison-with-callable Dashboard.created_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(), == get_user_id(),
Dashboard.changed_by_fk # pylint: disable=comparison-with-callable Dashboard.changed_by_fk # pylint: disable=comparison-with-callable
== g.user.get_user_id(), == get_user_id(),
) )
) )
@ -126,17 +127,14 @@ class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-metho
users_favorite_dash_query = db.session.query(FavStar.obj_id).filter( users_favorite_dash_query = db.session.query(FavStar.obj_id).filter(
and_( and_(
FavStar.user_id == security_manager.user_model.get_user_id(), FavStar.user_id == get_user_id(),
FavStar.class_name == "Dashboard", FavStar.class_name == "Dashboard",
) )
) )
owner_ids_query = ( owner_ids_query = (
db.session.query(Dashboard.id) db.session.query(Dashboard.id)
.join(Dashboard.owners) .join(Dashboard.owners)
.filter( .filter(security_manager.user_model.id == get_user_id())
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
) )
feature_flagged_filters = [] feature_flagged_filters = []

View File

@ -41,7 +41,11 @@ from typing_extensions import TypedDict
from superset.datasets.commands.exceptions import DatasetNotFoundError from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetTemplateException from superset.exceptions import SupersetTemplateException
from superset.extensions import feature_flag_manager from superset.extensions import feature_flag_manager
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters from superset.utils.core import (
convert_legacy_filters_into_adhoc,
get_user_id,
merge_extra_filters,
)
from superset.utils.memoized import memoized from superset.utils.memoized import memoized
if TYPE_CHECKING: if TYPE_CHECKING:
@ -115,9 +119,10 @@ class ExtraCache:
""" """
if hasattr(g, "user") and g.user: if hasattr(g, "user") and g.user:
id_ = get_user_id()
if add_to_cache_keys: if add_to_cache_keys:
self.cache_key_wrapper(g.user.get_id()) self.cache_key_wrapper(id_)
return g.user.get_id() return id_
return None return None
def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]: def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:

View File

@ -67,4 +67,4 @@ def get_uuid_namespace(seed: str) -> UUID:
def get_owner(user: User) -> Optional[int]: def get_owner(user: User) -> Optional[int]:
return user.get_user_id() if not user.is_anonymous else None return user.id if not user.is_anonymous else None

View File

@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.orm import relationship from sqlalchemy.orm import relationship
from superset import db from superset import db
from superset.utils.core import get_user_id
Base = declarative_base() Base = declarative_base()
@ -63,17 +64,10 @@ dashboard_user = Table(
class AuditMixin: class AuditMixin:
@classmethod
def get_user_id(cls):
try:
return g.user.id
except Exception:
return None
@declared_attr @declared_attr
def created_by_fk(cls): def created_by_fk(cls):
return Column( return Column(
Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False Integer, ForeignKey("ab_user.id"), default=get_user_id, nullable=False
) )
@declared_attr @declared_attr

View File

@ -34,6 +34,7 @@ from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from superset.models.tags import ObjectTypes, TagTypes from superset.models.tags import ObjectTypes, TagTypes
from superset.utils.core import get_user_id
Base = declarative_base() Base = declarative_base()
@ -54,7 +55,7 @@ class AuditMixinNullable(AuditMixin):
return Column( return Column(
Integer, Integer,
ForeignKey("ab_user.id"), ForeignKey("ab_user.id"),
default=self.get_user_id, default=get_user_id,
nullable=True, nullable=True,
) )
@ -63,8 +64,8 @@ class AuditMixinNullable(AuditMixin):
return Column( return Column(
Integer, Integer,
ForeignKey("ab_user.id"), ForeignKey("ab_user.id"),
default=self.get_user_id, default=get_user_id,
onupdate=self.get_user_id, onupdate=get_user_id,
nullable=True, nullable=True,
) )

View File

@ -40,6 +40,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound
from sqlalchemy_utils import UUIDType from sqlalchemy_utils import UUIDType
from superset.common.db_query_status import QueryStatus from superset.common.db_query_status import QueryStatus
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -384,7 +385,7 @@ class AuditMixinNullable(AuditMixin):
return sa.Column( return sa.Column(
sa.Integer, sa.Integer,
sa.ForeignKey("ab_user.id"), sa.ForeignKey("ab_user.id"),
default=self.get_user_id, default=get_user_id,
nullable=True, nullable=True,
) )
@ -393,8 +394,8 @@ class AuditMixinNullable(AuditMixin):
return sa.Column( return sa.Column(
sa.Integer, sa.Integer,
sa.ForeignKey("ab_user.id"), sa.ForeignKey("ab_user.id"),
default=self.get_user_id, default=get_user_id,
onupdate=self.get_user_id, onupdate=get_user_id,
nullable=True, nullable=True,
) )

View File

@ -16,11 +16,11 @@
# under the License. # under the License.
from typing import Any from typing import Any
from flask import g
from flask_sqlalchemy import BaseQuery from flask_sqlalchemy import BaseQuery
from superset import security_manager from superset import security_manager
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.utils.core import get_user_id
from superset.views.base import BaseFilter from superset.views.base import BaseFilter
@ -33,5 +33,5 @@ class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
:returns: query :returns: query
""" """
if not security_manager.can_access_all_queries(): if not security_manager.can_access_all_queries():
query = query.filter(Query.user_id == g.user.get_user_id()) query = query.filter(Query.user_id == get_user_id())
return query return query

View File

@ -75,7 +75,7 @@ from superset.security.guest_token import (
GuestTokenUser, GuestTokenUser,
GuestUser, GuestUser,
) )
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType
from superset.utils.urls import get_url_host from superset.utils.urls import get_url_host
if TYPE_CHECKING: if TYPE_CHECKING:
@ -529,7 +529,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
view_menu_names = ( view_menu_names = (
base_query.join(assoc_user_role) base_query.join(assoc_user_role)
.join(self.user_model) .join(self.user_model)
.filter(self.user_model.id == g.user.get_id()) .filter(self.user_model.id == get_user_id())
.filter(self.permission_model.name == permission_name) .filter(self.permission_model.name == permission_name)
).all() ).all()
return {s.name for s in view_menu_names} return {s.name for s in view_menu_names}
@ -1252,10 +1252,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
@staticmethod @staticmethod
def raise_for_user_activity_access(user_id: int) -> None: def raise_for_user_activity_access(user_id: int) -> None:
user = g.user if g.user and g.user.get_id() else None if not get_user_id() or (
if not user or (
not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"] not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"]
and user_id != user.id and user_id != get_user_id()
): ):
raise SupersetSecurityException( raise SupersetSecurityException(
SupersetError( SupersetError(

View File

@ -28,7 +28,7 @@ from superset import is_feature_enabled
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql_parse import CtasMethod from superset.sql_parse import CtasMethod
from superset.utils import core as utils from superset.utils import core as utils
from superset.utils.core import apply_max_row_limit from superset.utils.core import apply_max_row_limit, get_user_id
from superset.utils.dates import now_as_float from superset.utils.dates import now_as_float
from superset.views.utils import get_cta_schema_name from superset.views.utils import get_cta_schema_name
@ -64,7 +64,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
self.create_table_as_select = None self.create_table_as_select = None
self.database = None self.database = None
self._init_from_query_params(query_params) self._init_from_query_params(query_params)
self.user_id = self._get_user_id() self.user_id = get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10]) self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
def set_query(self, query: Query) -> None: def set_query(self, query: Query) -> None:
@ -111,12 +111,6 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
limit = 0 limit = 0
return limit return limit
def _get_user_id(self) -> Optional[int]: # pylint: disable=no-self-use
try:
return g.user.get_id() if g.user else None
except RuntimeError:
return None
def is_run_asynchronous(self) -> bool: def is_run_asynchronous(self) -> bool:
return self.async_flag return self.async_flag

View File

@ -21,7 +21,9 @@ from typing import Any, Dict, List, Optional, Tuple
import jwt import jwt
import redis import redis
from flask import Flask, g, request, Request, Response, session from flask import Flask, request, Request, Response, session
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,12 +37,12 @@ class AsyncQueryJobException(Exception):
def build_job_metadata( def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
) -> Dict[str, Any]: ) -> Dict[str, Any]:
return { return {
"channel_id": channel_id, "channel_id": channel_id,
"job_id": job_id, "job_id": job_id,
"user_id": int(user_id) if user_id else None, "user_id": user_id,
"status": kwargs.get("status"), "status": kwargs.get("status"),
"errors": kwargs.get("errors", []), "errors": kwargs.get("errors", []),
"result_url": kwargs.get("result_url"), "result_url": kwargs.get("result_url"),
@ -113,13 +115,7 @@ class AsyncQueryManager:
@app.after_request @app.after_request
def validate_session(response: Response) -> Response: def validate_session(response: Response) -> Response:
user_id = None user_id = get_user_id()
try:
user_id = g.user.get_id()
user_id = int(user_id)
except Exception: # pylint: disable=broad-except
pass
reset_token = ( reset_token = (
not request.cookies.get(self._jwt_cookie_name) not request.cookies.get(self._jwt_cookie_name)
@ -161,7 +157,7 @@ class AsyncQueryManager:
logger.warning("Parse jwt failed", exc_info=True) logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex raise AsyncQueryTokenException("Failed to parse token") from ex
def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]: def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
job_id = str(uuid.uuid4()) job_id = str(uuid.uuid4())
return build_job_metadata( return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING channel_id, job_id, user_id, status=self.STATUS_PENDING

View File

@ -1422,13 +1422,36 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
def get_username() -> Optional[str]: def get_username() -> Optional[str]:
"""Get username if within the flask context, otherwise return noffin'""" """
Get username (if defined) associated with the current user.
:returns: The username
"""
try: try:
return g.user.username return g.user.username
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
return None return None
def get_user_id() -> Optional[int]:
"""
Get the user identifier (if defined) associated with the current user.
Though the Flask-AppBuilder `User` and Flask-Login `AnonymousUserMixin` and
`UserMixin` models provide a convenience `get_id` method, for generality, the
identifier is encoded as a `str` whereas in Superset all identifiers are encoded as
an `int`.
returns: The user identifier
"""
try:
return g.user.id
except Exception: # pylint: disable=broad-except
return None
@contextmanager @contextmanager
def override_user(user: Optional[User]) -> Iterator[Any]: def override_user(user: Optional[User]) -> Iterator[Any]:
""" """

View File

@ -41,6 +41,8 @@ from flask_appbuilder.const import API_URI_RIS_KEY
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from typing_extensions import Literal from typing_extensions import Literal
from superset.utils.core import get_user_id
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.stats_logger import BaseStatsLogger from superset.stats_logger import BaseStatsLogger
@ -133,10 +135,7 @@ class AbstractEventLogger(ABC):
duration_ms = int(duration.total_seconds() * 1000) if duration else None duration_ms = int(duration.total_seconds() * 1000) if duration else None
# Initial try and grab user_id via flask.g.user # Initial try and grab user_id via flask.g.user
try: user_id = get_user_id()
user_id = g.user.get_id()
except Exception: # pylint: disable=broad-except
user_id = None
# Whenever a user is not bounded to a session we # Whenever a user is not bounded to a session we
# need to add them back before logging to capture user_id # need to add them back before logging to capture user_id
@ -144,7 +143,7 @@ class AbstractEventLogger(ABC):
try: try:
session = current_app.appbuilder.get_session session = current_app.appbuilder.get_session
session.add(g.user) session.add(g.user)
user_id = g.user.get_id() user_id = get_user_id()
except Exception as ex: # pylint: disable=broad-except except Exception as ex: # pylint: disable=broad-except
logging.warning(ex) logging.warning(ex)
user_id = None user_id = None

View File

@ -76,6 +76,7 @@ from superset.models.reports import ReportRecipientType
from superset.superset_typing import FlaskResponse from superset.superset_typing import FlaskResponse
from superset.translations.utils import get_language_pack from superset.translations.utils import get_language_pack
from superset.utils import core as utils from superset.utils import core as utils
from superset.utils.core import get_user_id
from .utils import bootstrap_user_data from .utils import bootstrap_user_data
@ -623,10 +624,7 @@ class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
owner_ids_query = ( owner_ids_query = (
db.session.query(models.SqlaTable.id) db.session.query(models.SqlaTable.id)
.join(models.SqlaTable.owners) .join(models.SqlaTable.owners)
.filter( .filter(security_manager.user_model.id == get_user_id())
security_manager.user_model.id
== security_manager.user_model.get_user_id()
)
) )
return query.filter( return query.filter(
or_( or_(

View File

@ -18,7 +18,7 @@ import functools
import logging import logging
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
from flask import Blueprint, g, request, Response from flask import Blueprint, request, Response
from flask_appbuilder import AppBuilder, Model, ModelRestApi from flask_appbuilder import AppBuilder, Model, ModelRestApi
from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.filters import BaseFilter, Filters from flask_appbuilder.models.filters import BaseFilter, Filters
@ -38,7 +38,7 @@ from superset.schemas import error_payload_content
from superset.sql_lab import Query as SqllabQuery from superset.sql_lab import Query as SqllabQuery
from superset.stats_logger import BaseStatsLogger from superset.stats_logger import BaseStatsLogger
from superset.superset_typing import FlaskResponse from superset.superset_typing import FlaskResponse
from superset.utils.core import time_function from superset.utils.core import get_user_id, time_function
from superset.views.base import handle_api_exception from superset.views.base import handle_api_exception
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -145,7 +145,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods
return query return query
users_favorite_query = db.session.query(FavStar.obj_id).filter( users_favorite_query = db.session.query(FavStar.obj_id).filter(
and_( and_(
FavStar.user_id == g.user.get_id(), FavStar.user_id == get_user_id(),
FavStar.class_name == self.class_name, FavStar.class_name == self.class_name,
) )
) )

View File

@ -21,6 +21,8 @@ from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound from sqlalchemy.orm.exc import NoResultFound
from superset.utils.core import get_user_id
def validate_owner(value: int) -> None: def validate_owner(value: int) -> None:
try: try:
@ -113,8 +115,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
@staticmethod @staticmethod
def set_owners(instance: Model, owners: List[int]) -> None: def set_owners(instance: Model, owners: List[int]) -> None:
owner_objs = [] owner_objs = []
if g.user.get_id() not in owners: user_id = get_user_id()
owners.append(g.user.get_id()) if user_id and user_id not in owners:
owners.append(user_id)
for owner_id in owners: for owner_id in owners:
user = current_app.appbuilder.get_session.query( user = current_app.appbuilder.get_session.query(
current_app.appbuilder.sm.user_model current_app.appbuilder.sm.user_model

View File

@ -132,6 +132,7 @@ from superset.utils.cache import etag_cache
from superset.utils.core import ( from superset.utils.core import (
apply_max_row_limit, apply_max_row_limit,
DatasourceType, DatasourceType,
get_user_id,
ReservedUrlParameters, ReservedUrlParameters,
) )
from superset.utils.dates import now_as_float from superset.utils.dates import now_as_float
@ -673,7 +674,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
request request
)["channel"] )["channel"]
job_metadata = async_query_manager.init_job( job_metadata = async_query_manager.init_job(
async_channel_id, g.user.get_id() async_channel_id, get_user_id()
) )
load_explore_json_into_cache.delay( load_explore_json_into_cache.delay(
job_metadata, form_data, response_type, force job_metadata, form_data, response_type, force
@ -1885,13 +1886,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, class_name: str, obj_id: int, action: str self, class_name: str, obj_id: int, action: str
) -> FlaskResponse: ) -> FlaskResponse:
"""Toggle favorite stars on Slices and Dashboard""" """Toggle favorite stars on Slices and Dashboard"""
if not g.user.get_id(): if not get_user_id():
return json_error_response("ERROR: Favstar toggling denied", status=403) return json_error_response("ERROR: Favstar toggling denied", status=403)
session = db.session() session = db.session()
count = 0 count = 0
favs = ( favs = (
session.query(FavStar) session.query(FavStar)
.filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id()) .filter_by(class_name=class_name, obj_id=obj_id, user_id=get_user_id())
.all() .all()
) )
if action == "select": if action == "select":
@ -1900,7 +1901,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
FavStar( FavStar(
class_name=class_name, class_name=class_name,
obj_id=obj_id, obj_id=obj_id,
user_id=g.user.get_id(), user_id=get_user_id(),
dttm=datetime.now(), dttm=datetime.now(),
) )
) )
@ -2582,7 +2583,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod @staticmethod
def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse: def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse:
stats_logger.incr("queries") stats_logger.incr("queries")
if not g.user.get_id(): if not get_user_id():
return json_error_response( return json_error_response(
"Please login to access the queries.", status=403 "Please login to access the queries.", status=403
) )
@ -2592,9 +2593,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
sql_queries = ( sql_queries = (
db.session.query(Query) db.session.query(Query)
.filter( .filter(Query.user_id == get_user_id(), Query.changed_on >= last_updated_dt)
Query.user_id == g.user.get_id(), Query.changed_on >= last_updated_dt
)
.all() .all()
) )
dict_queries = {q.client_id: q.to_dict() for q in sql_queries} dict_queries = {q.client_id: q.to_dict() for q in sql_queries}
@ -2620,10 +2619,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
search_user_id = int(cast(int, request.args.get("user_id"))) search_user_id = int(cast(int, request.args.get("user_id")))
except ValueError: except ValueError:
return Response(status=400, mimetype="application/json") return Response(status=400, mimetype="application/json")
if search_user_id != g.user.get_user_id(): if search_user_id != get_user_id():
return Response(status=403, mimetype="application/json") return Response(status=403, mimetype="application/json")
else: else:
search_user_id = g.user.get_user_id() search_user_id = get_user_id()
database_id = request.args.get("database_id") database_id = request.args.get("database_id")
search_text = request.args.get("search_text") search_text = request.args.get("search_text")
status = request.args.get("status") status = request.args.get("status")
@ -2676,14 +2675,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/welcome/") @expose("/welcome/")
def welcome(self) -> FlaskResponse: def welcome(self) -> FlaskResponse:
"""Personalized welcome page""" """Personalized welcome page"""
if not g.user or not g.user.get_id(): if not get_user_id():
if conf["PUBLIC_ROLE_LIKE"]: if conf["PUBLIC_ROLE_LIKE"]:
return self.render_template("superset/public_welcome.html") return self.render_template("superset/public_welcome.html")
return redirect(appbuilder.get_url_for_login) return redirect(appbuilder.get_url_for_login)
welcome_dashboard_id = ( welcome_dashboard_id = (
db.session.query(UserAttribute.welcome_dashboard_id) db.session.query(UserAttribute.welcome_dashboard_id)
.filter_by(user_id=g.user.get_id()) .filter_by(user_id=get_user_id())
.scalar() .scalar()
) )
if welcome_dashboard_id: if welcome_dashboard_id:
@ -2728,7 +2727,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
) )
@staticmethod @staticmethod
def _get_sqllab_tabs(user_id: int) -> Dict[str, Any]: def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]:
# send list of tab state ids # send list of tab state ids
tabs_state = ( tabs_state = (
db.session.query(TabState.id, TabState.label) db.session.query(TabState.id, TabState.label)
@ -2780,7 +2779,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
payload = { payload = {
"defaultDbId": config["SQLLAB_DEFAULT_DBID"], "defaultDbId": config["SQLLAB_DEFAULT_DBID"],
"common": common_bootstrap_payload(), "common": common_bootstrap_payload(),
**self._get_sqllab_tabs(g.user.get_id()), **self._get_sqllab_tabs(get_user_id()),
} }
form_data = request.form.get("form_data") form_data = request.form.get("form_data")

View File

@ -29,6 +29,7 @@ from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
from superset.superset_typing import FlaskResponse from superset.superset_typing import FlaskResponse
from superset.utils import core as utils from superset.utils import core as utils
from superset.utils.core import get_user_id
from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
@ -136,7 +137,7 @@ class TabStateView(BaseSupersetView):
def post(self) -> FlaskResponse: # pylint: disable=no-self-use def post(self) -> FlaskResponse: # pylint: disable=no-self-use
query_editor = json.loads(request.form["queryEditor"]) query_editor = json.loads(request.form["queryEditor"])
tab_state = TabState( tab_state = TabState(
user_id=g.user.get_id(), user_id=get_user_id(),
label=query_editor.get("title", "Untitled Query"), label=query_editor.get("title", "Untitled Query"),
active=True, active=True,
database_id=query_editor["dbId"], database_id=query_editor["dbId"],
@ -147,7 +148,7 @@ class TabStateView(BaseSupersetView):
) )
( (
db.session.query(TabState) db.session.query(TabState)
.filter_by(user_id=g.user.get_id()) .filter_by(user_id=get_user_id())
.update({"active": False}) .update({"active": False})
) )
db.session.add(tab_state) db.session.add(tab_state)
@ -157,7 +158,7 @@ class TabStateView(BaseSupersetView):
@has_access_api @has_access_api
@expose("/<int:tab_state_id>", methods=["DELETE"]) @expose("/<int:tab_state_id>", methods=["DELETE"])
def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()): if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403) return Response(status=403)
db.session.query(TabState).filter(TabState.id == tab_state_id).delete( db.session.query(TabState).filter(TabState.id == tab_state_id).delete(
@ -172,7 +173,7 @@ class TabStateView(BaseSupersetView):
@has_access_api @has_access_api
@expose("/<int:tab_state_id>", methods=["GET"]) @expose("/<int:tab_state_id>", methods=["GET"])
def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()): if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403) return Response(status=403)
tab_state = db.session.query(TabState).filter_by(id=tab_state_id).first() tab_state = db.session.query(TabState).filter_by(id=tab_state_id).first()
@ -190,12 +191,12 @@ class TabStateView(BaseSupersetView):
owner_id = _get_owner_id(tab_state_id) owner_id = _get_owner_id(tab_state_id)
if owner_id is None: if owner_id is None:
return Response(status=404) return Response(status=404)
if owner_id != int(g.user.get_id()): if owner_id != get_user_id():
return Response(status=403) return Response(status=403)
( (
db.session.query(TabState) db.session.query(TabState)
.filter_by(user_id=g.user.get_id()) .filter_by(user_id=get_user_id())
.update({"active": TabState.id == tab_state_id}) .update({"active": TabState.id == tab_state_id})
) )
db.session.commit() db.session.commit()
@ -204,7 +205,7 @@ class TabStateView(BaseSupersetView):
@has_access_api @has_access_api
@expose("<int:tab_state_id>", methods=["PUT"]) @expose("<int:tab_state_id>", methods=["PUT"])
def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
if _get_owner_id(tab_state_id) != int(g.user.get_id()): if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403) return Response(status=403)
fields = {k: json.loads(v) for k, v in request.form.to_dict().items()} fields = {k: json.loads(v) for k, v in request.form.to_dict().items()}
@ -217,7 +218,7 @@ class TabStateView(BaseSupersetView):
def migrate_query( # pylint: disable=no-self-use def migrate_query( # pylint: disable=no-self-use
self, tab_state_id: int self, tab_state_id: int
) -> FlaskResponse: ) -> FlaskResponse:
if _get_owner_id(tab_state_id) != int(g.user.get_id()): if _get_owner_id(tab_state_id) != get_user_id():
return Response(status=403) return Response(status=403)
client_id = json.loads(request.form["queryId"]) client_id = json.loads(request.form["queryId"])
@ -244,7 +245,7 @@ class TabStateView(BaseSupersetView):
.filter( .filter(
and_( and_(
Query.client_id != client_id, Query.client_id != client_id,
Query.user_id == g.user.get_id(), Query.user_id == get_user_id(),
Query.sql_editor_id == str(tab_state_id), Query.sql_editor_id == str(tab_state_id),
), ),
) )
@ -257,7 +258,7 @@ class TabStateView(BaseSupersetView):
db.session.query(Query).filter_by( db.session.query(Query).filter_by(
client_id=client_id, client_id=client_id,
user_id=g.user.get_id(), user_id=get_user_id(),
sql_editor_id=str(tab_state_id), sql_editor_id=str(tab_state_id),
).delete(synchronize_session=False) ).delete(synchronize_session=False)
db.session.commit() db.session.commit()
@ -327,4 +328,4 @@ class SqlLab(BaseSupersetView):
logger.warning( logger.warning(
"This endpoint is deprecated and will be removed in the next major release" "This endpoint is deprecated and will be removed in the next major release"
) )
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.get_id())) return redirect(f"/savedqueryview/list/?_flt_0_user={get_user_id()}")

View File

@ -18,11 +18,12 @@
"""Unit tests for Superset""" """Unit tests for Superset"""
import json import json
import unittest import unittest
from typing import Optional
from unittest import mock from unittest import mock
import pytest import pytest
from flask import g
from flask.ctx import AppContext from flask.ctx import AppContext
from pytest_mock import MockFixture
from sqlalchemy import inspect from sqlalchemy import inspect
from tests.integration_tests.fixtures.birth_names_dashboard import ( from tests.integration_tests.fixtures.birth_names_dashboard import (
@ -42,7 +43,7 @@ from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models from superset.models import core as models
from superset.models.datasource_access_request import DatasourceAccessRequest from superset.models.datasource_access_request import DatasourceAccessRequest
from superset.utils.core import get_username, override_user from superset.utils.core import get_user_id, get_username, override_user
from superset.utils.database import get_example_database from superset.utils.database import get_example_database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -524,18 +525,40 @@ class TestRequestAccess(SupersetTestCase):
session.commit() session.commit()
@pytest.mark.parametrize(
"username,user_id",
[
(None, None),
("alpha", 5),
("gamma", 2),
],
)
def test_get_user_id(
app_context: AppContext,
mocker: MockFixture,
username: Optional[str],
user_id: Optional[int],
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
mock_g.user = security_manager.find_user(username)
assert get_user_id() == user_id
@pytest.mark.parametrize( @pytest.mark.parametrize(
"username", "username",
[ [
None, None,
"alpha",
"gamma", "gamma",
], ],
) )
def test_get_username(app_context: AppContext, username: str) -> None: def test_get_username(
assert not hasattr(g, "user") app_context: AppContext,
assert get_username() is None mocker: MockFixture,
username: Optional[str],
g.user = security_manager.find_user(username) ) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
mock_g.user = security_manager.find_user(username)
assert get_username() == username assert get_username() == username
@ -543,26 +566,32 @@ def test_get_username(app_context: AppContext, username: str) -> None:
"username", "username",
[ [
None, None,
"alpha",
"gamma", "gamma",
], ],
) )
def test_override_user(app_context: AppContext, username: str) -> None: def test_override_user(
app_context: AppContext,
mocker: MockFixture,
username: str,
) -> None:
mock_g = mocker.patch("superset.utils.core.g", spec={})
admin = security_manager.find_user(username="admin") admin = security_manager.find_user(username="admin")
user = security_manager.find_user(username) user = security_manager.find_user(username)
assert not hasattr(g, "user") assert not hasattr(mock_g, "user")
with override_user(user): with override_user(user):
assert g.user == user assert mock_g.user == user
assert not hasattr(g, "user") assert not hasattr(mock_g, "user")
g.user = admin mock_g.user = admin
with override_user(user): with override_user(user):
assert g.user == user assert mock_g.user == user
assert g.user == admin assert mock_g.user == admin
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -112,7 +112,7 @@ class TestEventLogger(unittest.TestCase):
) )
self.assertGreaterEqual(payload["duration_ms"], 100) self.assertGreaterEqual(payload["duration_ms"], 100)
@patch("superset.utils.log.g", spec={}) @patch("superset.utils.core.g", spec={})
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15) @freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
def test_context_manager_log(self, mock_g): def test_context_manager_log(self, mock_g):
class DummyEventLogger(AbstractEventLogger): class DummyEventLogger(AbstractEventLogger):
@ -144,12 +144,12 @@ class TestEventLogger(unittest.TestCase):
assert logger.records == [ assert logger.records == [
{ {
"records": [{"path": "/", "engine": "bar"}], "records": [{"path": "/", "engine": "bar"}],
"user_id": "2", "user_id": 2,
"duration": 15000.0, "duration": 15000.0,
} }
] ]
@patch("superset.utils.log.g", spec={}) @patch("superset.utils.core.g", spec={})
def test_context_manager_log_with_context(self, mock_g): def test_context_manager_log_with_context(self, mock_g):
class DummyEventLogger(AbstractEventLogger): class DummyEventLogger(AbstractEventLogger):
def __init__(self): def __init__(self):
@ -191,12 +191,12 @@ class TestEventLogger(unittest.TestCase):
"payload_override": {"engine": "sqllite"}, "payload_override": {"engine": "sqllite"},
} }
], ],
"user_id": "2", "user_id": 2,
"duration": 5558756000, "duration": 5558756000,
} }
] ]
@patch("superset.utils.log.g", spec={}) @patch("superset.utils.core.g", spec={})
def test_log_with_context_user_null(self, mock_g): def test_log_with_context_user_null(self, mock_g):
class DummyEventLogger(AbstractEventLogger): class DummyEventLogger(AbstractEventLogger):
def __init__(self): def __init__(self):

View File

@ -412,8 +412,9 @@ class TestRolePermission(SupersetTestCase):
# TODO test slice permission # TODO test slice permission
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
def test_schemas_accessible_by_user_admin(self, mock_g): @patch("superset.utils.core.g")
mock_g.user = security_manager.find_user("admin") def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g):
mock_g.user = mock_sm_g.user = security_manager.find_user("admin")
with self.client.application.test_request_context(): with self.client.application.test_request_context():
database = get_example_database() database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user( schemas = security_manager.get_schemas_accessible_by_user(
@ -422,10 +423,11 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(schemas, ["1", "2", "3"]) # no changes self.assertEqual(schemas, ["1", "2", "3"]) # no changes
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
def test_schemas_accessible_by_user_schema_access(self, mock_g): @patch("superset.utils.core.g")
def test_schemas_accessible_by_user_schema_access(self, mock_sm_g, mock_g):
# User has schema access to the schema 1 # User has schema access to the schema 1
create_schema_perm("[examples].[1]") create_schema_perm("[examples].[1]")
mock_g.user = security_manager.find_user("gamma") mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context(): with self.client.application.test_request_context():
database = get_example_database() database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user( schemas = security_manager.get_schemas_accessible_by_user(
@ -436,9 +438,10 @@ class TestRolePermission(SupersetTestCase):
delete_schema_perm("[examples].[1]") delete_schema_perm("[examples].[1]")
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
def test_schemas_accessible_by_user_datasource_access(self, mock_g): @patch("superset.utils.core.g")
def test_schemas_accessible_by_user_datasource_access(self, mock_sm_g, mock_g):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB. # User has schema access to the datasource temp_schema.wb_health_population in examples DB.
mock_g.user = security_manager.find_user("gamma") mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context(): with self.client.application.test_request_context():
database = get_example_database() database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user( schemas = security_manager.get_schemas_accessible_by_user(
@ -447,10 +450,13 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(schemas, ["temp_schema"]) self.assertEqual(schemas, ["temp_schema"])
@patch("superset.security.manager.g") @patch("superset.security.manager.g")
def test_schemas_accessible_by_user_datasource_and_schema_access(self, mock_g): @patch("superset.utils.core.g")
def test_schemas_accessible_by_user_datasource_and_schema_access(
self, mock_sm_g, mock_g
):
# User has schema access to the datasource temp_schema.wb_health_population in examples DB. # User has schema access to the datasource temp_schema.wb_health_population in examples DB.
create_schema_perm("[examples].[2]") create_schema_perm("[examples].[2]")
mock_g.user = security_manager.find_user("gamma") mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
with self.client.application.test_request_context(): with self.client.application.test_request_context():
database = get_example_database() database = get_example_database()
schemas = security_manager.get_schemas_accessible_by_user( schemas = security_manager.get_schemas_accessible_by_user(

View File

@ -32,6 +32,7 @@ from superset.tasks.async_queries import (
load_chart_data_into_cache, load_chart_data_into_cache,
load_explore_json_into_cache, load_explore_json_into_cache,
) )
from superset.utils.core import get_user_id
from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import ( from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices, load_birth_names_dashboard_with_slices,
@ -218,7 +219,7 @@ class TestAsyncQueries(SupersetTestCase):
ensure_user_is_set(1) ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user")) self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous) self.assertFalse(g.user.is_anonymous)
self.assertEqual("1", g.user.get_id()) self.assertEqual(1, get_user_id())
del g.user del g.user
@ -226,22 +227,22 @@ class TestAsyncQueries(SupersetTestCase):
ensure_user_is_set(None) ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user")) self.assertTrue(hasattr(g, "user"))
self.assertTrue(g.user.is_anonymous) self.assertTrue(g.user.is_anonymous)
self.assertEqual(None, g.user.get_id()) self.assertEqual(None, get_user_id())
del g.user del g.user
g.user = security_manager.get_user_by_id(2) g.user = security_manager.get_user_by_id(2)
self.assertEqual("2", g.user.get_id()) self.assertEqual(2, get_user_id())
ensure_user_is_set(1) ensure_user_is_set(1)
self.assertTrue(hasattr(g, "user")) self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous) self.assertFalse(g.user.is_anonymous)
self.assertEqual("2", g.user.get_id()) self.assertEqual(2, get_user_id())
ensure_user_is_set(None) ensure_user_is_set(None)
self.assertTrue(hasattr(g, "user")) self.assertTrue(hasattr(g, "user"))
self.assertFalse(g.user.is_anonymous) self.assertFalse(g.user.is_anonymous)
self.assertEqual("2", g.user.get_id()) self.assertEqual(2, get_user_id())
if g_user_is_set: if g_user_is_set:
g.user = original_g_user g.user = original_g_user