mirror of https://github.com/apache/superset.git
refactor: Cleanup user get_id/get_user_id (#20492)
Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
parent
c56e37cda2
commit
3483446c28
|
@ -805,7 +805,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
|
|||
charts = ChartDAO.find_by_ids(requested_ids)
|
||||
if not charts:
|
||||
return self.response_404()
|
||||
favorited_chart_ids = ChartDAO.favorited_ids(charts, g.user.get_id())
|
||||
favorited_chart_ids = ChartDAO.favorited_ids(charts)
|
||||
res = [
|
||||
{"id": request_id, "value": request_id in favorited_chart_ids}
|
||||
for request_id in requested_ids
|
||||
|
|
|
@ -25,6 +25,7 @@ from superset.dao.base import BaseDAO
|
|||
from superset.extensions import db
|
||||
from superset.models.core import FavStar, FavStarClassName
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
@ -70,7 +71,7 @@ class ChartDAO(BaseDAO):
|
|||
db.session.commit()
|
||||
|
||||
@staticmethod
|
||||
def favorited_ids(charts: List[Slice], current_user_id: int) -> List[FavStar]:
|
||||
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
|
||||
ids = [chart.id for chart in charts]
|
||||
return [
|
||||
star.obj_id
|
||||
|
@ -78,7 +79,7 @@ class ChartDAO(BaseDAO):
|
|||
.filter(
|
||||
FavStar.class_name == FavStarClassName.CHART,
|
||||
FavStar.obj_id.in_(ids),
|
||||
FavStar.user_id == current_user_id,
|
||||
FavStar.user_id == get_user_id(),
|
||||
)
|
||||
.all()
|
||||
]
|
||||
|
|
|
@ -21,7 +21,7 @@ import logging
|
|||
from typing import Any, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
import simplejson
|
||||
from flask import current_app, g, make_response, request, Response
|
||||
from flask import current_app, make_response, request, Response
|
||||
from flask_appbuilder.api import expose, protect
|
||||
from flask_babel import gettext as _
|
||||
from marshmallow import ValidationError
|
||||
|
@ -44,7 +44,7 @@ from superset.connectors.base.models import BaseDatasource
|
|||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.extensions import event_logger
|
||||
from superset.utils.async_query_manager import AsyncQueryTokenException
|
||||
from superset.utils.core import create_zip, json_int_dttm_ser
|
||||
from superset.utils.core import create_zip, get_user_id, json_int_dttm_ser
|
||||
from superset.views.base import CsvResponse, generate_download_headers
|
||||
from superset.views.base_api import statsd_metrics
|
||||
|
||||
|
@ -324,7 +324,7 @@ class ChartDataRestApi(ChartRestApi):
|
|||
except AsyncQueryTokenException:
|
||||
return self.response_401()
|
||||
|
||||
result = async_command.run(form_data, g.user.get_id())
|
||||
result = async_command.run(form_data, get_user_id())
|
||||
return self.response(202, **result)
|
||||
|
||||
def _send_chart_response(
|
||||
|
|
|
@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
|
|||
jwt_data = async_query_manager.parse_jwt_from_request(request)
|
||||
self._async_channel_id = jwt_data["channel"]
|
||||
|
||||
def run(self, form_data: Dict[str, Any], user_id: Optional[str]) -> Dict[str, Any]:
|
||||
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
|
||||
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
|
||||
load_chart_data_into_cache.delay(job_metadata, form_data)
|
||||
return job_metadata
|
||||
|
|
|
@ -942,9 +942,7 @@ class DashboardRestApi(BaseSupersetModelRestApi):
|
|||
dashboards = DashboardDAO.find_by_ids(requested_ids)
|
||||
if not dashboards:
|
||||
return self.response_404()
|
||||
favorited_dashboard_ids = DashboardDAO.favorited_ids(
|
||||
dashboards, g.user.get_id()
|
||||
)
|
||||
favorited_dashboard_ids = DashboardDAO.favorited_ids(dashboards)
|
||||
res = [
|
||||
{"id": request_id, "value": request_id in favorited_dashboard_ids}
|
||||
for request_id in requested_ids
|
||||
|
|
|
@ -29,6 +29,7 @@ from superset.extensions import db
|
|||
from superset.models.core import FavStar, FavStarClassName
|
||||
from superset.models.dashboard import Dashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.utils.dashboard_filter_scopes_converter import copy_filter_scopes
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -274,9 +275,7 @@ class DashboardDAO(BaseDAO):
|
|||
return dashboard
|
||||
|
||||
@staticmethod
|
||||
def favorited_ids(
|
||||
dashboards: List[Dashboard], current_user_id: int
|
||||
) -> List[FavStar]:
|
||||
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
|
||||
ids = [dash.id for dash in dashboards]
|
||||
return [
|
||||
star.obj_id
|
||||
|
@ -284,7 +283,7 @@ class DashboardDAO(BaseDAO):
|
|||
.filter(
|
||||
FavStar.class_name == FavStarClassName.DASHBOARD,
|
||||
FavStar.obj_id.in_(ids),
|
||||
FavStar.user_id == current_user_id,
|
||||
FavStar.user_id == get_user_id(),
|
||||
)
|
||||
.all()
|
||||
]
|
||||
|
|
|
@ -29,6 +29,7 @@ from superset.models.dashboard import Dashboard
|
|||
from superset.models.embedded_dashboard import EmbeddedDashboard
|
||||
from superset.models.slice import Slice
|
||||
from superset.security.guest_token import GuestTokenResourceType, GuestUser
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.views.base import BaseFilter, is_user_admin
|
||||
from superset.views.base_api import BaseFavoriteFilter
|
||||
|
||||
|
@ -57,9 +58,9 @@ class DashboardCreatedByMeFilter(BaseFilter): # pylint: disable=too-few-public-
|
|||
return query.filter(
|
||||
or_(
|
||||
Dashboard.created_by_fk # pylint: disable=comparison-with-callable
|
||||
== g.user.get_user_id(),
|
||||
== get_user_id(),
|
||||
Dashboard.changed_by_fk # pylint: disable=comparison-with-callable
|
||||
== g.user.get_user_id(),
|
||||
== get_user_id(),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -126,17 +127,14 @@ class DashboardAccessFilter(BaseFilter): # pylint: disable=too-few-public-metho
|
|||
|
||||
users_favorite_dash_query = db.session.query(FavStar.obj_id).filter(
|
||||
and_(
|
||||
FavStar.user_id == security_manager.user_model.get_user_id(),
|
||||
FavStar.user_id == get_user_id(),
|
||||
FavStar.class_name == "Dashboard",
|
||||
)
|
||||
)
|
||||
owner_ids_query = (
|
||||
db.session.query(Dashboard.id)
|
||||
.join(Dashboard.owners)
|
||||
.filter(
|
||||
security_manager.user_model.id
|
||||
== security_manager.user_model.get_user_id()
|
||||
)
|
||||
.filter(security_manager.user_model.id == get_user_id())
|
||||
)
|
||||
|
||||
feature_flagged_filters = []
|
||||
|
|
|
@ -41,7 +41,11 @@ from typing_extensions import TypedDict
|
|||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.exceptions import SupersetTemplateException
|
||||
from superset.extensions import feature_flag_manager
|
||||
from superset.utils.core import convert_legacy_filters_into_adhoc, merge_extra_filters
|
||||
from superset.utils.core import (
|
||||
convert_legacy_filters_into_adhoc,
|
||||
get_user_id,
|
||||
merge_extra_filters,
|
||||
)
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -115,9 +119,10 @@ class ExtraCache:
|
|||
"""
|
||||
|
||||
if hasattr(g, "user") and g.user:
|
||||
id_ = get_user_id()
|
||||
if add_to_cache_keys:
|
||||
self.cache_key_wrapper(g.user.get_id())
|
||||
return g.user.get_id()
|
||||
self.cache_key_wrapper(id_)
|
||||
return id_
|
||||
return None
|
||||
|
||||
def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
|
||||
|
|
|
@ -67,4 +67,4 @@ def get_uuid_namespace(seed: str) -> UUID:
|
|||
|
||||
|
||||
def get_owner(user: User) -> Optional[int]:
|
||||
return user.get_user_id() if not user.is_anonymous else None
|
||||
return user.id if not user.is_anonymous else None
|
||||
|
|
|
@ -34,6 +34,7 @@ from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|||
from sqlalchemy.orm import relationship
|
||||
|
||||
from superset import db
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
@ -63,17 +64,10 @@ dashboard_user = Table(
|
|||
|
||||
|
||||
class AuditMixin:
|
||||
@classmethod
|
||||
def get_user_id(cls):
|
||||
try:
|
||||
return g.user.id
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@declared_attr
|
||||
def created_by_fk(cls):
|
||||
return Column(
|
||||
Integer, ForeignKey("ab_user.id"), default=cls.get_user_id, nullable=False
|
||||
Integer, ForeignKey("ab_user.id"), default=get_user_id, nullable=False
|
||||
)
|
||||
|
||||
@declared_attr
|
||||
|
|
|
@ -34,6 +34,7 @@ from sqlalchemy import Column, DateTime, Enum, ForeignKey, Integer, String
|
|||
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
||||
|
||||
from superset.models.tags import ObjectTypes, TagTypes
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
@ -54,7 +55,7 @@ class AuditMixinNullable(AuditMixin):
|
|||
return Column(
|
||||
Integer,
|
||||
ForeignKey("ab_user.id"),
|
||||
default=self.get_user_id,
|
||||
default=get_user_id,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
@ -63,8 +64,8 @@ class AuditMixinNullable(AuditMixin):
|
|||
return Column(
|
||||
Integer,
|
||||
ForeignKey("ab_user.id"),
|
||||
default=self.get_user_id,
|
||||
onupdate=self.get_user_id,
|
||||
default=get_user_id,
|
||||
onupdate=get_user_id,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -40,6 +40,7 @@ from sqlalchemy.orm.exc import MultipleResultsFound
|
|||
from sqlalchemy_utils import UUIDType
|
||||
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -384,7 +385,7 @@ class AuditMixinNullable(AuditMixin):
|
|||
return sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey("ab_user.id"),
|
||||
default=self.get_user_id,
|
||||
default=get_user_id,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
@ -393,8 +394,8 @@ class AuditMixinNullable(AuditMixin):
|
|||
return sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey("ab_user.id"),
|
||||
default=self.get_user_id,
|
||||
onupdate=self.get_user_id,
|
||||
default=get_user_id,
|
||||
onupdate=get_user_id,
|
||||
nullable=True,
|
||||
)
|
||||
|
||||
|
|
|
@ -16,11 +16,11 @@
|
|||
# under the License.
|
||||
from typing import Any
|
||||
|
||||
from flask import g
|
||||
from flask_sqlalchemy import BaseQuery
|
||||
|
||||
from superset import security_manager
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils.core import get_user_id
|
||||
from superset.views.base import BaseFilter
|
||||
|
||||
|
||||
|
@ -33,5 +33,5 @@ class QueryFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
|||
:returns: query
|
||||
"""
|
||||
if not security_manager.can_access_all_queries():
|
||||
query = query.filter(Query.user_id == g.user.get_user_id())
|
||||
query = query.filter(Query.user_id == get_user_id())
|
||||
return query
|
||||
|
|
|
@ -75,7 +75,7 @@ from superset.security.guest_token import (
|
|||
GuestTokenUser,
|
||||
GuestUser,
|
||||
)
|
||||
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType
|
||||
from superset.utils.core import DatasourceName, get_user_id, RowLevelSecurityFilterType
|
||||
from superset.utils.urls import get_url_host
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -529,7 +529,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
view_menu_names = (
|
||||
base_query.join(assoc_user_role)
|
||||
.join(self.user_model)
|
||||
.filter(self.user_model.id == g.user.get_id())
|
||||
.filter(self.user_model.id == get_user_id())
|
||||
.filter(self.permission_model.name == permission_name)
|
||||
).all()
|
||||
return {s.name for s in view_menu_names}
|
||||
|
@ -1252,10 +1252,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
|
||||
@staticmethod
|
||||
def raise_for_user_activity_access(user_id: int) -> None:
|
||||
user = g.user if g.user and g.user.get_id() else None
|
||||
if not user or (
|
||||
if not get_user_id() or (
|
||||
not current_app.config["ENABLE_BROAD_ACTIVITY_ACCESS"]
|
||||
and user_id != user.id
|
||||
and user_id != get_user_id()
|
||||
):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
|
|
|
@ -28,7 +28,7 @@ from superset import is_feature_enabled
|
|||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import CtasMethod
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import apply_max_row_limit
|
||||
from superset.utils.core import apply_max_row_limit, get_user_id
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.views.utils import get_cta_schema_name
|
||||
|
||||
|
@ -64,7 +64,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
|
|||
self.create_table_as_select = None
|
||||
self.database = None
|
||||
self._init_from_query_params(query_params)
|
||||
self.user_id = self._get_user_id()
|
||||
self.user_id = get_user_id()
|
||||
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
|
||||
|
||||
def set_query(self, query: Query) -> None:
|
||||
|
@ -111,12 +111,6 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
|
|||
limit = 0
|
||||
return limit
|
||||
|
||||
def _get_user_id(self) -> Optional[int]: # pylint: disable=no-self-use
|
||||
try:
|
||||
return g.user.get_id() if g.user else None
|
||||
except RuntimeError:
|
||||
return None
|
||||
|
||||
def is_run_asynchronous(self) -> bool:
|
||||
return self.async_flag
|
||||
|
||||
|
|
|
@ -21,7 +21,9 @@ from typing import Any, Dict, List, Optional, Tuple
|
|||
|
||||
import jwt
|
||||
import redis
|
||||
from flask import Flask, g, request, Request, Response, session
|
||||
from flask import Flask, request, Request, Response, session
|
||||
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -35,12 +37,12 @@ class AsyncQueryJobException(Exception):
|
|||
|
||||
|
||||
def build_job_metadata(
|
||||
channel_id: str, job_id: str, user_id: Optional[str], **kwargs: Any
|
||||
channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"channel_id": channel_id,
|
||||
"job_id": job_id,
|
||||
"user_id": int(user_id) if user_id else None,
|
||||
"user_id": user_id,
|
||||
"status": kwargs.get("status"),
|
||||
"errors": kwargs.get("errors", []),
|
||||
"result_url": kwargs.get("result_url"),
|
||||
|
@ -113,13 +115,7 @@ class AsyncQueryManager:
|
|||
|
||||
@app.after_request
|
||||
def validate_session(response: Response) -> Response:
|
||||
user_id = None
|
||||
|
||||
try:
|
||||
user_id = g.user.get_id()
|
||||
user_id = int(user_id)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
pass
|
||||
user_id = get_user_id()
|
||||
|
||||
reset_token = (
|
||||
not request.cookies.get(self._jwt_cookie_name)
|
||||
|
@ -161,7 +157,7 @@ class AsyncQueryManager:
|
|||
logger.warning("Parse jwt failed", exc_info=True)
|
||||
raise AsyncQueryTokenException("Failed to parse token") from ex
|
||||
|
||||
def init_job(self, channel_id: str, user_id: Optional[str]) -> Dict[str, Any]:
|
||||
def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
|
||||
job_id = str(uuid.uuid4())
|
||||
return build_job_metadata(
|
||||
channel_id, job_id, user_id, status=self.STATUS_PENDING
|
||||
|
|
|
@ -1422,13 +1422,36 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
|
|||
|
||||
|
||||
def get_username() -> Optional[str]:
|
||||
"""Get username if within the flask context, otherwise return noffin'"""
|
||||
"""
|
||||
Get username (if defined) associated with the current user.
|
||||
|
||||
:returns: The username
|
||||
"""
|
||||
|
||||
try:
|
||||
return g.user.username
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
|
||||
|
||||
def get_user_id() -> Optional[int]:
|
||||
"""
|
||||
Get the user identifier (if defined) associated with the current user.
|
||||
|
||||
Though the Flask-AppBuilder `User` and Flask-Login `AnonymousUserMixin` and
|
||||
`UserMixin` models provide a convenience `get_id` method, for generality, the
|
||||
identifier is encoded as a `str` whereas in Superset all identifiers are encoded as
|
||||
an `int`.
|
||||
|
||||
returns: The user identifier
|
||||
"""
|
||||
|
||||
try:
|
||||
return g.user.id
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def override_user(user: Optional[User]) -> Iterator[Any]:
|
||||
"""
|
||||
|
|
|
@ -41,6 +41,8 @@ from flask_appbuilder.const import API_URI_RIS_KEY
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from typing_extensions import Literal
|
||||
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
|
||||
|
@ -133,10 +135,7 @@ class AbstractEventLogger(ABC):
|
|||
duration_ms = int(duration.total_seconds() * 1000) if duration else None
|
||||
|
||||
# Initial try and grab user_id via flask.g.user
|
||||
try:
|
||||
user_id = g.user.get_id()
|
||||
except Exception: # pylint: disable=broad-except
|
||||
user_id = None
|
||||
user_id = get_user_id()
|
||||
|
||||
# Whenever a user is not bounded to a session we
|
||||
# need to add them back before logging to capture user_id
|
||||
|
@ -144,7 +143,7 @@ class AbstractEventLogger(ABC):
|
|||
try:
|
||||
session = current_app.appbuilder.get_session
|
||||
session.add(g.user)
|
||||
user_id = g.user.get_id()
|
||||
user_id = get_user_id()
|
||||
except Exception as ex: # pylint: disable=broad-except
|
||||
logging.warning(ex)
|
||||
user_id = None
|
||||
|
|
|
@ -76,6 +76,7 @@ from superset.models.reports import ReportRecipientType
|
|||
from superset.superset_typing import FlaskResponse
|
||||
from superset.translations.utils import get_language_pack
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
from .utils import bootstrap_user_data
|
||||
|
||||
|
@ -623,10 +624,7 @@ class DatasourceFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
|||
owner_ids_query = (
|
||||
db.session.query(models.SqlaTable.id)
|
||||
.join(models.SqlaTable.owners)
|
||||
.filter(
|
||||
security_manager.user_model.id
|
||||
== security_manager.user_model.get_user_id()
|
||||
)
|
||||
.filter(security_manager.user_model.id == get_user_id())
|
||||
)
|
||||
return query.filter(
|
||||
or_(
|
||||
|
|
|
@ -18,7 +18,7 @@ import functools
|
|||
import logging
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
|
||||
from flask import Blueprint, g, request, Response
|
||||
from flask import Blueprint, request, Response
|
||||
from flask_appbuilder import AppBuilder, Model, ModelRestApi
|
||||
from flask_appbuilder.api import expose, protect, rison, safe
|
||||
from flask_appbuilder.models.filters import BaseFilter, Filters
|
||||
|
@ -38,7 +38,7 @@ from superset.schemas import error_payload_content
|
|||
from superset.sql_lab import Query as SqllabQuery
|
||||
from superset.stats_logger import BaseStatsLogger
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils.core import time_function
|
||||
from superset.utils.core import get_user_id, time_function
|
||||
from superset.views.base import handle_api_exception
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -145,7 +145,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods
|
|||
return query
|
||||
users_favorite_query = db.session.query(FavStar.obj_id).filter(
|
||||
and_(
|
||||
FavStar.user_id == g.user.get_id(),
|
||||
FavStar.user_id == get_user_id(),
|
||||
FavStar.class_name == self.class_name,
|
||||
)
|
||||
)
|
||||
|
|
|
@ -21,6 +21,8 @@ from flask_appbuilder import Model
|
|||
from marshmallow import post_load, pre_load, Schema, ValidationError
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
|
||||
def validate_owner(value: int) -> None:
|
||||
try:
|
||||
|
@ -113,8 +115,9 @@ class BaseOwnedSchema(BaseSupersetSchema):
|
|||
@staticmethod
|
||||
def set_owners(instance: Model, owners: List[int]) -> None:
|
||||
owner_objs = []
|
||||
if g.user.get_id() not in owners:
|
||||
owners.append(g.user.get_id())
|
||||
user_id = get_user_id()
|
||||
if user_id and user_id not in owners:
|
||||
owners.append(user_id)
|
||||
for owner_id in owners:
|
||||
user = current_app.appbuilder.get_session.query(
|
||||
current_app.appbuilder.sm.user_model
|
||||
|
|
|
@ -132,6 +132,7 @@ from superset.utils.cache import etag_cache
|
|||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
DatasourceType,
|
||||
get_user_id,
|
||||
ReservedUrlParameters,
|
||||
)
|
||||
from superset.utils.dates import now_as_float
|
||||
|
@ -673,7 +674,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
request
|
||||
)["channel"]
|
||||
job_metadata = async_query_manager.init_job(
|
||||
async_channel_id, g.user.get_id()
|
||||
async_channel_id, get_user_id()
|
||||
)
|
||||
load_explore_json_into_cache.delay(
|
||||
job_metadata, form_data, response_type, force
|
||||
|
@ -1885,13 +1886,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
self, class_name: str, obj_id: int, action: str
|
||||
) -> FlaskResponse:
|
||||
"""Toggle favorite stars on Slices and Dashboard"""
|
||||
if not g.user.get_id():
|
||||
if not get_user_id():
|
||||
return json_error_response("ERROR: Favstar toggling denied", status=403)
|
||||
session = db.session()
|
||||
count = 0
|
||||
favs = (
|
||||
session.query(FavStar)
|
||||
.filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
|
||||
.filter_by(class_name=class_name, obj_id=obj_id, user_id=get_user_id())
|
||||
.all()
|
||||
)
|
||||
if action == "select":
|
||||
|
@ -1900,7 +1901,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
FavStar(
|
||||
class_name=class_name,
|
||||
obj_id=obj_id,
|
||||
user_id=g.user.get_id(),
|
||||
user_id=get_user_id(),
|
||||
dttm=datetime.now(),
|
||||
)
|
||||
)
|
||||
|
@ -2582,7 +2583,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@staticmethod
|
||||
def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse:
|
||||
stats_logger.incr("queries")
|
||||
if not g.user.get_id():
|
||||
if not get_user_id():
|
||||
return json_error_response(
|
||||
"Please login to access the queries.", status=403
|
||||
)
|
||||
|
@ -2592,9 +2593,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
|
||||
sql_queries = (
|
||||
db.session.query(Query)
|
||||
.filter(
|
||||
Query.user_id == g.user.get_id(), Query.changed_on >= last_updated_dt
|
||||
)
|
||||
.filter(Query.user_id == get_user_id(), Query.changed_on >= last_updated_dt)
|
||||
.all()
|
||||
)
|
||||
dict_queries = {q.client_id: q.to_dict() for q in sql_queries}
|
||||
|
@ -2620,10 +2619,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
search_user_id = int(cast(int, request.args.get("user_id")))
|
||||
except ValueError:
|
||||
return Response(status=400, mimetype="application/json")
|
||||
if search_user_id != g.user.get_user_id():
|
||||
if search_user_id != get_user_id():
|
||||
return Response(status=403, mimetype="application/json")
|
||||
else:
|
||||
search_user_id = g.user.get_user_id()
|
||||
search_user_id = get_user_id()
|
||||
database_id = request.args.get("database_id")
|
||||
search_text = request.args.get("search_text")
|
||||
status = request.args.get("status")
|
||||
|
@ -2676,14 +2675,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@expose("/welcome/")
|
||||
def welcome(self) -> FlaskResponse:
|
||||
"""Personalized welcome page"""
|
||||
if not g.user or not g.user.get_id():
|
||||
if not get_user_id():
|
||||
if conf["PUBLIC_ROLE_LIKE"]:
|
||||
return self.render_template("superset/public_welcome.html")
|
||||
return redirect(appbuilder.get_url_for_login)
|
||||
|
||||
welcome_dashboard_id = (
|
||||
db.session.query(UserAttribute.welcome_dashboard_id)
|
||||
.filter_by(user_id=g.user.get_id())
|
||||
.filter_by(user_id=get_user_id())
|
||||
.scalar()
|
||||
)
|
||||
if welcome_dashboard_id:
|
||||
|
@ -2728,7 +2727,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_sqllab_tabs(user_id: int) -> Dict[str, Any]:
|
||||
def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]:
|
||||
# send list of tab state ids
|
||||
tabs_state = (
|
||||
db.session.query(TabState.id, TabState.label)
|
||||
|
@ -2780,7 +2779,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
payload = {
|
||||
"defaultDbId": config["SQLLAB_DEFAULT_DBID"],
|
||||
"common": common_bootstrap_payload(),
|
||||
**self._get_sqllab_tabs(g.user.get_id()),
|
||||
**self._get_sqllab_tabs(get_user_id()),
|
||||
}
|
||||
|
||||
form_data = request.form.get("form_data")
|
||||
|
|
|
@ -29,6 +29,7 @@ from superset.constants import MODEL_VIEW_RW_METHOD_PERMISSION_MAP, RouteMethod
|
|||
from superset.models.sql_lab import Query, SavedQuery, TableSchema, TabState
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
from .base import BaseSupersetView, DeleteMixin, json_success, SupersetModelView
|
||||
|
||||
|
@ -136,7 +137,7 @@ class TabStateView(BaseSupersetView):
|
|||
def post(self) -> FlaskResponse: # pylint: disable=no-self-use
|
||||
query_editor = json.loads(request.form["queryEditor"])
|
||||
tab_state = TabState(
|
||||
user_id=g.user.get_id(),
|
||||
user_id=get_user_id(),
|
||||
label=query_editor.get("title", "Untitled Query"),
|
||||
active=True,
|
||||
database_id=query_editor["dbId"],
|
||||
|
@ -147,7 +148,7 @@ class TabStateView(BaseSupersetView):
|
|||
)
|
||||
(
|
||||
db.session.query(TabState)
|
||||
.filter_by(user_id=g.user.get_id())
|
||||
.filter_by(user_id=get_user_id())
|
||||
.update({"active": False})
|
||||
)
|
||||
db.session.add(tab_state)
|
||||
|
@ -157,7 +158,7 @@ class TabStateView(BaseSupersetView):
|
|||
@has_access_api
|
||||
@expose("/<int:tab_state_id>", methods=["DELETE"])
|
||||
def delete(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
|
||||
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
|
||||
if _get_owner_id(tab_state_id) != get_user_id():
|
||||
return Response(status=403)
|
||||
|
||||
db.session.query(TabState).filter(TabState.id == tab_state_id).delete(
|
||||
|
@ -172,7 +173,7 @@ class TabStateView(BaseSupersetView):
|
|||
@has_access_api
|
||||
@expose("/<int:tab_state_id>", methods=["GET"])
|
||||
def get(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
|
||||
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
|
||||
if _get_owner_id(tab_state_id) != get_user_id():
|
||||
return Response(status=403)
|
||||
|
||||
tab_state = db.session.query(TabState).filter_by(id=tab_state_id).first()
|
||||
|
@ -190,12 +191,12 @@ class TabStateView(BaseSupersetView):
|
|||
owner_id = _get_owner_id(tab_state_id)
|
||||
if owner_id is None:
|
||||
return Response(status=404)
|
||||
if owner_id != int(g.user.get_id()):
|
||||
if owner_id != get_user_id():
|
||||
return Response(status=403)
|
||||
|
||||
(
|
||||
db.session.query(TabState)
|
||||
.filter_by(user_id=g.user.get_id())
|
||||
.filter_by(user_id=get_user_id())
|
||||
.update({"active": TabState.id == tab_state_id})
|
||||
)
|
||||
db.session.commit()
|
||||
|
@ -204,7 +205,7 @@ class TabStateView(BaseSupersetView):
|
|||
@has_access_api
|
||||
@expose("<int:tab_state_id>", methods=["PUT"])
|
||||
def put(self, tab_state_id: int) -> FlaskResponse: # pylint: disable=no-self-use
|
||||
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
|
||||
if _get_owner_id(tab_state_id) != get_user_id():
|
||||
return Response(status=403)
|
||||
|
||||
fields = {k: json.loads(v) for k, v in request.form.to_dict().items()}
|
||||
|
@ -217,7 +218,7 @@ class TabStateView(BaseSupersetView):
|
|||
def migrate_query( # pylint: disable=no-self-use
|
||||
self, tab_state_id: int
|
||||
) -> FlaskResponse:
|
||||
if _get_owner_id(tab_state_id) != int(g.user.get_id()):
|
||||
if _get_owner_id(tab_state_id) != get_user_id():
|
||||
return Response(status=403)
|
||||
|
||||
client_id = json.loads(request.form["queryId"])
|
||||
|
@ -244,7 +245,7 @@ class TabStateView(BaseSupersetView):
|
|||
.filter(
|
||||
and_(
|
||||
Query.client_id != client_id,
|
||||
Query.user_id == g.user.get_id(),
|
||||
Query.user_id == get_user_id(),
|
||||
Query.sql_editor_id == str(tab_state_id),
|
||||
),
|
||||
)
|
||||
|
@ -257,7 +258,7 @@ class TabStateView(BaseSupersetView):
|
|||
|
||||
db.session.query(Query).filter_by(
|
||||
client_id=client_id,
|
||||
user_id=g.user.get_id(),
|
||||
user_id=get_user_id(),
|
||||
sql_editor_id=str(tab_state_id),
|
||||
).delete(synchronize_session=False)
|
||||
db.session.commit()
|
||||
|
@ -327,4 +328,4 @@ class SqlLab(BaseSupersetView):
|
|||
logger.warning(
|
||||
"This endpoint is deprecated and will be removed in the next major release"
|
||||
)
|
||||
return redirect("/savedqueryview/list/?_flt_0_user={}".format(g.user.get_id()))
|
||||
return redirect(f"/savedqueryview/list/?_flt_0_user={get_user_id()}")
|
||||
|
|
|
@ -18,11 +18,12 @@
|
|||
"""Unit tests for Superset"""
|
||||
import json
|
||||
import unittest
|
||||
from typing import Optional
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask import g
|
||||
from flask.ctx import AppContext
|
||||
from pytest_mock import MockFixture
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
|
@ -42,7 +43,7 @@ from superset import db, security_manager
|
|||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models import core as models
|
||||
from superset.models.datasource_access_request import DatasourceAccessRequest
|
||||
from superset.utils.core import get_username, override_user
|
||||
from superset.utils.core import get_user_id, get_username, override_user
|
||||
from superset.utils.database import get_example_database
|
||||
|
||||
from .base_tests import SupersetTestCase
|
||||
|
@ -524,18 +525,40 @@ class TestRequestAccess(SupersetTestCase):
|
|||
session.commit()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username,user_id",
|
||||
[
|
||||
(None, None),
|
||||
("alpha", 5),
|
||||
("gamma", 2),
|
||||
],
|
||||
)
|
||||
def test_get_user_id(
|
||||
app_context: AppContext,
|
||||
mocker: MockFixture,
|
||||
username: Optional[str],
|
||||
user_id: Optional[int],
|
||||
) -> None:
|
||||
mock_g = mocker.patch("superset.utils.core.g", spec={})
|
||||
mock_g.user = security_manager.find_user(username)
|
||||
assert get_user_id() == user_id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"username",
|
||||
[
|
||||
None,
|
||||
"alpha",
|
||||
"gamma",
|
||||
],
|
||||
)
|
||||
def test_get_username(app_context: AppContext, username: str) -> None:
|
||||
assert not hasattr(g, "user")
|
||||
assert get_username() is None
|
||||
|
||||
g.user = security_manager.find_user(username)
|
||||
def test_get_username(
|
||||
app_context: AppContext,
|
||||
mocker: MockFixture,
|
||||
username: Optional[str],
|
||||
) -> None:
|
||||
mock_g = mocker.patch("superset.utils.core.g", spec={})
|
||||
mock_g.user = security_manager.find_user(username)
|
||||
assert get_username() == username
|
||||
|
||||
|
||||
|
@ -543,26 +566,32 @@ def test_get_username(app_context: AppContext, username: str) -> None:
|
|||
"username",
|
||||
[
|
||||
None,
|
||||
"alpha",
|
||||
"gamma",
|
||||
],
|
||||
)
|
||||
def test_override_user(app_context: AppContext, username: str) -> None:
|
||||
def test_override_user(
|
||||
app_context: AppContext,
|
||||
mocker: MockFixture,
|
||||
username: str,
|
||||
) -> None:
|
||||
mock_g = mocker.patch("superset.utils.core.g", spec={})
|
||||
admin = security_manager.find_user(username="admin")
|
||||
user = security_manager.find_user(username)
|
||||
|
||||
assert not hasattr(g, "user")
|
||||
assert not hasattr(mock_g, "user")
|
||||
|
||||
with override_user(user):
|
||||
assert g.user == user
|
||||
assert mock_g.user == user
|
||||
|
||||
assert not hasattr(g, "user")
|
||||
assert not hasattr(mock_g, "user")
|
||||
|
||||
g.user = admin
|
||||
mock_g.user = admin
|
||||
|
||||
with override_user(user):
|
||||
assert g.user == user
|
||||
assert mock_g.user == user
|
||||
|
||||
assert g.user == admin
|
||||
assert mock_g.user == admin
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -112,7 +112,7 @@ class TestEventLogger(unittest.TestCase):
|
|||
)
|
||||
self.assertGreaterEqual(payload["duration_ms"], 100)
|
||||
|
||||
@patch("superset.utils.log.g", spec={})
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
@freeze_time("Jan 14th, 2020", auto_tick_seconds=15)
|
||||
def test_context_manager_log(self, mock_g):
|
||||
class DummyEventLogger(AbstractEventLogger):
|
||||
|
@ -144,12 +144,12 @@ class TestEventLogger(unittest.TestCase):
|
|||
assert logger.records == [
|
||||
{
|
||||
"records": [{"path": "/", "engine": "bar"}],
|
||||
"user_id": "2",
|
||||
"user_id": 2,
|
||||
"duration": 15000.0,
|
||||
}
|
||||
]
|
||||
|
||||
@patch("superset.utils.log.g", spec={})
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
def test_context_manager_log_with_context(self, mock_g):
|
||||
class DummyEventLogger(AbstractEventLogger):
|
||||
def __init__(self):
|
||||
|
@ -191,12 +191,12 @@ class TestEventLogger(unittest.TestCase):
|
|||
"payload_override": {"engine": "sqllite"},
|
||||
}
|
||||
],
|
||||
"user_id": "2",
|
||||
"user_id": 2,
|
||||
"duration": 5558756000,
|
||||
}
|
||||
]
|
||||
|
||||
@patch("superset.utils.log.g", spec={})
|
||||
@patch("superset.utils.core.g", spec={})
|
||||
def test_log_with_context_user_null(self, mock_g):
|
||||
class DummyEventLogger(AbstractEventLogger):
|
||||
def __init__(self):
|
||||
|
|
|
@ -412,8 +412,9 @@ class TestRolePermission(SupersetTestCase):
|
|||
# TODO test slice permission
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_schemas_accessible_by_user_admin(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_schemas_accessible_by_user_admin(self, mock_sm_g, mock_g):
|
||||
mock_g.user = mock_sm_g.user = security_manager.find_user("admin")
|
||||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
|
@ -422,10 +423,11 @@ class TestRolePermission(SupersetTestCase):
|
|||
self.assertEqual(schemas, ["1", "2", "3"]) # no changes
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_schemas_accessible_by_user_schema_access(self, mock_g):
|
||||
@patch("superset.utils.core.g")
|
||||
def test_schemas_accessible_by_user_schema_access(self, mock_sm_g, mock_g):
|
||||
# User has schema access to the schema 1
|
||||
create_schema_perm("[examples].[1]")
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
|
||||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
|
@ -436,9 +438,10 @@ class TestRolePermission(SupersetTestCase):
|
|||
delete_schema_perm("[examples].[1]")
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_schemas_accessible_by_user_datasource_access(self, mock_g):
|
||||
@patch("superset.utils.core.g")
|
||||
def test_schemas_accessible_by_user_datasource_access(self, mock_sm_g, mock_g):
|
||||
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
|
||||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
|
@ -447,10 +450,13 @@ class TestRolePermission(SupersetTestCase):
|
|||
self.assertEqual(schemas, ["temp_schema"])
|
||||
|
||||
@patch("superset.security.manager.g")
|
||||
def test_schemas_accessible_by_user_datasource_and_schema_access(self, mock_g):
|
||||
@patch("superset.utils.core.g")
|
||||
def test_schemas_accessible_by_user_datasource_and_schema_access(
|
||||
self, mock_sm_g, mock_g
|
||||
):
|
||||
# User has schema access to the datasource temp_schema.wb_health_population in examples DB.
|
||||
create_schema_perm("[examples].[2]")
|
||||
mock_g.user = security_manager.find_user("gamma")
|
||||
mock_g.user = mock_sm_g.user = security_manager.find_user("gamma")
|
||||
with self.client.application.test_request_context():
|
||||
database = get_example_database()
|
||||
schemas = security_manager.get_schemas_accessible_by_user(
|
||||
|
|
|
@ -32,6 +32,7 @@ from superset.tasks.async_queries import (
|
|||
load_chart_data_into_cache,
|
||||
load_explore_json_into_cache,
|
||||
)
|
||||
from superset.utils.core import get_user_id
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
@ -218,7 +219,7 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
ensure_user_is_set(1)
|
||||
self.assertTrue(hasattr(g, "user"))
|
||||
self.assertFalse(g.user.is_anonymous)
|
||||
self.assertEqual("1", g.user.get_id())
|
||||
self.assertEqual(1, get_user_id())
|
||||
|
||||
del g.user
|
||||
|
||||
|
@ -226,22 +227,22 @@ class TestAsyncQueries(SupersetTestCase):
|
|||
ensure_user_is_set(None)
|
||||
self.assertTrue(hasattr(g, "user"))
|
||||
self.assertTrue(g.user.is_anonymous)
|
||||
self.assertEqual(None, g.user.get_id())
|
||||
self.assertEqual(None, get_user_id())
|
||||
|
||||
del g.user
|
||||
|
||||
g.user = security_manager.get_user_by_id(2)
|
||||
self.assertEqual("2", g.user.get_id())
|
||||
self.assertEqual(2, get_user_id())
|
||||
|
||||
ensure_user_is_set(1)
|
||||
self.assertTrue(hasattr(g, "user"))
|
||||
self.assertFalse(g.user.is_anonymous)
|
||||
self.assertEqual("2", g.user.get_id())
|
||||
self.assertEqual(2, get_user_id())
|
||||
|
||||
ensure_user_is_set(None)
|
||||
self.assertTrue(hasattr(g, "user"))
|
||||
self.assertFalse(g.user.is_anonymous)
|
||||
self.assertEqual("2", g.user.get_id())
|
||||
self.assertEqual(2, get_user_id())
|
||||
|
||||
if g_user_is_set:
|
||||
g.user = original_g_user
|
||||
|
|
Loading…
Reference in New Issue