chore: Cleanup database sessions (#10427)

Co-authored-by: John Bodley <john.bodley@airbnb.com>
This commit is contained in:
John Bodley 2020-07-30 23:07:56 -07:00 committed by GitHub
parent 7ff1757448
commit 7645fc85c3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 488 additions and 637 deletions

View File

@ -197,10 +197,9 @@ def set_database_uri(database_name: str, uri: str) -> None:
)
def refresh_druid(datasource: str, merge: bool) -> None:
"""Refresh druid datasources"""
session = db.session()
from superset.connectors.druid.models import DruidCluster
for cluster in session.query(DruidCluster).all():
for cluster in db.session.query(DruidCluster).all():
try:
cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
except Exception as ex: # pylint: disable=broad-except
@ -208,7 +207,7 @@ def refresh_druid(datasource: str, merge: bool) -> None:
logger.exception(ex)
cluster.metadata_last_refreshed = datetime.now()
print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
session.commit()
db.session.commit()
@superset.command()
@ -250,7 +249,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None:
logger.info("Importing dashboard from file %s", file_)
try:
with file_.open() as data_stream:
dashboard_import_export.import_dashboards(db.session, data_stream)
dashboard_import_export.import_dashboards(data_stream)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing dashboard from file %s", file_)
logger.error(ex)
@ -268,7 +267,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
"""Export dashboards to JSON"""
from superset.utils import dashboard_import_export
data = dashboard_import_export.export_dashboards(db.session)
data = dashboard_import_export.export_dashboards()
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
@ -321,7 +320,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
try:
with file_.open() as data_stream:
dict_import_export.import_from_dict(
db.session, yaml.safe_load(data_stream), sync=sync_array
yaml.safe_load(data_stream), sync=sync_array
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Error when importing datasources from file %s", file_)
@ -360,7 +359,6 @@ def export_datasources(
from superset.utils import dict_import_export
data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,

View File

@ -25,7 +25,7 @@ from superset.commands.exceptions import (
)
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import db, security_manager
from superset.extensions import security_manager
def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:
@ -50,8 +50,6 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[
def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
try:
return ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
except (NoResultFound, KeyError):
raise DatasourceNotFoundValidationError()

View File

@ -23,7 +23,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
import numpy as np
import pandas as pd
from superset import app, cache, db, security_manager
from superset import app, cache, security_manager
from superset.common.query_object import QueryObject
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
@ -64,7 +64,7 @@ class QueryContext:
result_format: Optional[utils.ChartDataResultFormat] = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
str(datasource["type"]), int(datasource["id"])
)
self.queries = [QueryObject(**query_obj) for query_obj in queries]
self.force = force

View File

@ -17,7 +17,9 @@
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from sqlalchemy import or_
from sqlalchemy.orm import Session, subqueryload
from sqlalchemy.orm import subqueryload
from superset.extensions import db
if TYPE_CHECKING:
# pylint: disable=unused-import
@ -43,20 +45,20 @@ class ConnectorRegistry:
@classmethod
def get_datasource(
cls, datasource_type: str, datasource_id: int, session: Session
cls, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
return (
session.query(cls.sources[datasource_type])
db.session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one()
)
@classmethod
def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
def get_all_datasources(cls) -> List["BaseDatasource"]:
datasources: List["BaseDatasource"] = []
for source_type in ConnectorRegistry.sources:
source_class = ConnectorRegistry.sources[source_type]
qry = session.query(source_class)
qry = db.session.query(source_class)
qry = source_class.default_query(qry)
datasources.extend(qry.all())
return datasources
@ -64,7 +66,6 @@ class ConnectorRegistry:
@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: str,
datasource_name: str,
schema: str,
@ -72,21 +73,17 @@ class ConnectorRegistry:
) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
datasource_name, schema, database_name
)
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: "Database",
permissions: Set[str],
schema_perms: Set[str],
cls, database: "Database", permissions: Set[str], schema_perms: Set[str],
) -> List["BaseDatasource"]:
# TODO(bogdan): add unit test
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
db.session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
@ -99,12 +96,12 @@ class ConnectorRegistry:
@classmethod
def get_eager_datasource(
cls, session: Session, datasource_type: str, datasource_id: int
cls, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
session.query(datasource_class)
db.session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics),
@ -115,13 +112,9 @@ class ConnectorRegistry:
@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: "Database",
datasource_name: str,
schema: Optional[str] = None,
cls, database: "Database", datasource_name: str, schema: Optional[str] = None,
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
database, datasource_name, schema=schema
)

View File

@ -45,7 +45,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import backref, relationship, Session
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import expression
from sqlalchemy_utils import EncryptedType
@ -223,9 +223,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
Fetches metadata for the specified datasources and
merges to the Superset database
"""
session = db.session
ds_list = (
session.query(DruidDatasource)
db.session.query(DruidDatasource)
.filter(DruidDatasource.cluster_id == self.id)
.filter(DruidDatasource.datasource_name.in_(datasource_names))
)
@ -234,8 +233,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
datasource = ds_map.get(ds_name, None)
if not datasource:
datasource = DruidDatasource(datasource_name=ds_name)
with session.no_autoflush:
session.add(datasource)
with db.session.no_autoflush:
db.session.add(datasource)
flasher(_("Adding new datasource [{}]").format(ds_name), "success")
ds_map[ds_name] = datasource
elif refresh_all:
@ -245,7 +244,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
continue
datasource.cluster = self
datasource.merge_flag = merge_flag
session.flush()
db.session.flush()
# Prepare multithreaded executation
pool = ThreadPool()
@ -259,7 +258,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
cols = metadata[i]
if cols:
col_objs_list = (
session.query(DruidColumn)
db.session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(cols.keys()))
)
@ -272,15 +271,15 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
col_obj = DruidColumn(
datasource_id=datasource.id, column_name=col
)
with session.no_autoflush:
session.add(col_obj)
with db.session.no_autoflush:
db.session.add(col_obj)
col_obj.type = cols[col]["type"]
col_obj.datasource = datasource
if col_obj.type == "STRING":
col_obj.groupby = True
col_obj.filterable = True
datasource.refresh_metrics()
session.commit()
db.session.commit()
@hybrid_property
def perm(self) -> str:
@ -390,7 +389,7 @@ class DruidColumn(Model, BaseColumn):
.first()
)
return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
return import_datasource.import_simple_obj(i_column, lookup_obj)
class DruidMetric(Model, BaseMetric):
@ -459,7 +458,7 @@ class DruidMetric(Model, BaseMetric):
.first()
)
return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
return import_datasource.import_simple_obj(i_metric, lookup_obj)
druiddatasource_user = Table(
@ -635,7 +634,7 @@ class DruidDatasource(Model, BaseDatasource):
return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()
return import_datasource.import_datasource(
db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
i_datasource, lookup_cluster, lookup_datasource, import_time
)
def latest_metadata(self) -> Optional[Dict[str, Any]]:
@ -705,9 +704,10 @@ class DruidDatasource(Model, BaseDatasource):
refresh: bool = True,
) -> None:
"""Merges the ds config from druid_config into one stored in the db."""
session = db.session
datasource = (
session.query(cls).filter_by(datasource_name=druid_config["name"]).first()
db.session.query(cls)
.filter_by(datasource_name=druid_config["name"])
.first()
)
# Create a new datasource.
if not datasource:
@ -718,13 +718,13 @@ class DruidDatasource(Model, BaseDatasource):
changed_by_fk=user.id,
created_by_fk=user.id,
)
session.add(datasource)
db.session.add(datasource)
elif not refresh:
return
dimensions = druid_config["dimensions"]
col_objs = (
session.query(DruidColumn)
db.session.query(DruidColumn)
.filter(DruidColumn.datasource_id == datasource.id)
.filter(DruidColumn.column_name.in_(dimensions))
)
@ -741,10 +741,10 @@ class DruidDatasource(Model, BaseDatasource):
type="STRING",
datasource=datasource,
)
session.add(col_obj)
db.session.add(col_obj)
# Import Druid metrics
metric_objs = (
session.query(DruidMetric)
db.session.query(DruidMetric)
.filter(DruidMetric.datasource_id == datasource.id)
.filter(
DruidMetric.metric_name.in_(
@ -777,8 +777,8 @@ class DruidDatasource(Model, BaseDatasource):
% druid_config["name"]
),
)
session.add(metric_obj)
session.commit()
db.session.add(metric_obj)
db.session.commit()
@staticmethod
def time_offset(granularity: Granularity) -> int:
@ -788,10 +788,10 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
cls, datasource_name: str, schema: str, database_name: str
) -> Optional["DruidDatasource"]:
query = (
session.query(cls)
db.session.query(cls)
.join(DruidCluster)
.filter(cls.datasource_name == datasource_name)
.filter(DruidCluster.cluster_name == database_name)
@ -1724,11 +1724,7 @@ class DruidDatasource(Model, BaseDatasource):
@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
cls, database: Database, datasource_name: str, schema: Optional[str] = None,
) -> List["DruidDatasource"]:
return []

View File

@ -365,11 +365,10 @@ class Druid(BaseSupersetView):
self, refresh_all: bool = True
) -> FlaskResponse:
"""endpoint that refreshes druid datasources metadata"""
session = db.session()
DruidCluster = ConnectorRegistry.sources[ # pylint: disable=invalid-name
"druid"
].cluster_class
for cluster in session.query(DruidCluster).all():
for cluster in db.session.query(DruidCluster).all():
cluster_name = cluster.cluster_name
valid_cluster = True
try:
@ -391,7 +390,7 @@ class Druid(BaseSupersetView):
),
"info",
)
session.commit()
db.session.commit()
return redirect("/druiddatasourcemodelview/list/")
@has_access

View File

@ -41,7 +41,7 @@ from sqlalchemy import (
Text,
)
from sqlalchemy.exc import CompileError
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
@ -255,7 +255,7 @@ class TableColumn(Model, BaseColumn):
.first()
)
return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
return import_datasource.import_simple_obj(i_column, lookup_obj)
def dttm_sql_literal(
self,
@ -375,7 +375,7 @@ class SqlMetric(Model, BaseMetric):
.first()
)
return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
return import_datasource.import_simple_obj(i_metric, lookup_obj)
sqlatable_user = Table(
@ -503,15 +503,11 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
@classmethod
def get_datasource_by_name(
cls,
session: Session,
datasource_name: str,
schema: Optional[str],
database_name: str,
cls, datasource_name: str, schema: Optional[str], database_name: str,
) -> Optional["SqlaTable"]:
schema = schema or None
query = (
session.query(cls)
db.session.query(cls)
.join(Database)
.filter(cls.table_name == datasource_name)
.filter(Database.database_name == database_name)
@ -1296,24 +1292,15 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
)
return import_datasource.import_datasource(
db.session,
i_datasource,
lookup_database,
lookup_sqlatable,
import_time,
database_id,
i_datasource, lookup_database, lookup_sqlatable, import_time, database_id,
)
@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
cls, database: Database, datasource_name: str, schema: Optional[str] = None,
) -> List["SqlaTable"]:
query = (
session.query(cls)
db.session.query(cls)
.filter_by(database_id=database.id)
.filter_by(table_name=datasource_name)
)

View File

@ -99,9 +99,7 @@ class DashboardDAO(BaseDAO):
except KeyError:
pass
session = db.session()
current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
current_slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
dashboard.slices = current_slices
# update slice names. this assumes user has permissions to update the slice
@ -111,8 +109,8 @@ class DashboardDAO(BaseDAO):
new_name = slice_id_to_name[slc.id]
if slc.slice_name != new_name:
slc.slice_name = new_name
session.merge(slc)
session.flush()
db.session.merge(slc)
db.session.flush()
except KeyError:
pass

View File

@ -37,7 +37,7 @@ from sqlalchemy import (
UniqueConstraint,
)
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, sessionmaker, subqueryload
from sqlalchemy.orm import relationship, subqueryload
from sqlalchemy.orm.mapper import Mapper
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
@ -62,18 +62,17 @@ config = app.config
logger = logging.getLogger(__name__)
def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
# pylint: disable=unused-argument
def copy_dashboard( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: "Dashboard"
) -> None:
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
return
session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()
new_user = db.session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
template = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
@ -83,15 +82,15 @@ def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard")
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()
db.session.add(dashboard)
db.session.commit()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
db.session.add(extra_attributes)
db.session.commit()
sqla.event.listen(User, "after_insert", copy_dashboard)
@ -307,7 +306,6 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
logger.info(
"Started import of the dashboard: %s", dashboard_to_import.to_json()
)
session = db.session
logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
# copy slices object as Slice.import_slice will mutate the slice
# and will remove the existing dashboard - slice association
@ -324,7 +322,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
i_params_dict = dashboard_to_import.params_dict
remote_id_slice_map = {
slc.params_dict["remote_id"]: slc
for slc in session.query(Slice).all()
for slc in db.session.query(Slice).all()
if "remote_id" in slc.params_dict
}
for slc in slices:
@ -375,7 +373,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
# override the dashboard
existing_dashboard = None
for dash in session.query(Dashboard).all():
for dash in db.session.query(Dashboard).all():
if (
"remote_id" in dash.params_dict
and dash.params_dict["remote_id"] == dashboard_to_import.id
@ -402,7 +400,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
)
new_slices = (
session.query(Slice)
db.session.query(Slice)
.filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
.all()
)
@ -410,12 +408,12 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
if existing_dashboard:
existing_dashboard.override(dashboard_to_import)
existing_dashboard.slices = new_slices
session.flush()
db.session.flush()
return existing_dashboard.id
dashboard_to_import.slices = new_slices
session.add(dashboard_to_import)
session.flush()
db.session.add(dashboard_to_import)
db.session.flush()
return dashboard_to_import.id # type: ignore
@classmethod
@ -457,7 +455,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes
eager_datasources = []
for datasource_id, datasource_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
db.session, datasource_type, datasource_id
datasource_type, datasource_id
)
copied_datasource = eager_datasource.copy()
copied_datasource.alter_params(

View File

@ -34,9 +34,9 @@ from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, or_, UniqueConstraint
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import MultipleResultsFound
from superset.extensions import db
from superset.utils.core import QueryStatus
logger = logging.getLogger(__name__)
@ -127,7 +127,6 @@ class ImportMixin:
@classmethod
def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
cls,
session: Session,
dict_rep: Dict[Any, Any],
parent: Optional[Any] = None,
recursive: bool = True,
@ -178,7 +177,7 @@ class ImportMixin:
# Check if object already exists in DB, break if more than one is found
try:
obj_query = session.query(cls).filter(and_(*filters))
obj_query = db.session.query(cls).filter(and_(*filters))
obj = obj_query.one_or_none()
except MultipleResultsFound as ex:
logger.error(
@ -196,7 +195,7 @@ class ImportMixin:
logger.info("Importing new %s %s", obj.__tablename__, str(obj))
if cls.export_parent and parent:
setattr(obj, cls.export_parent, parent)
session.add(obj)
db.session.add(obj)
else:
is_new_obj = False
logger.info("Updating %s %s", obj.__tablename__, str(obj))
@ -214,7 +213,7 @@ class ImportMixin:
for c_obj in new_children.get(child, []):
added.append(
child_class.import_from_dict(
session=session, dict_rep=c_obj, parent=obj, sync=sync
dict_rep=c_obj, parent=obj, sync=sync
)
)
# If children should get synced, delete the ones that did not
@ -228,11 +227,11 @@ class ImportMixin:
for k in back_refs.keys()
]
to_delete = set(
session.query(child_class).filter(and_(*delete_filters))
db.session.query(child_class).filter(and_(*delete_filters))
).difference(set(added))
for o in to_delete:
logger.info("Deleting %s %s", child, str(obj))
session.delete(o)
db.session.delete(o)
return obj

View File

@ -300,7 +300,6 @@ class Slice(
:returns: The resulting id for the imported slice
:rtype: int
"""
session = db.session
make_transient(slc_to_import)
slc_to_import.dashboards = []
slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
@ -309,7 +308,6 @@ class Slice(
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = ConnectorRegistry.get_datasource_by_name(
session,
slc_to_import.datasource_type,
params["datasource_name"],
params["schema"],
@ -318,11 +316,11 @@ class Slice(
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:
slc_to_override.override(slc_to_import)
session.flush()
db.session.flush()
return slc_to_override.id
session.add(slc_to_import)
db.session.add(slc_to_import)
logger.info("Final slice: %s", str(slc_to_import.to_json()))
session.flush()
db.session.flush()
return slc_to_import.id
@property

View File

@ -22,10 +22,11 @@ from typing import List, Optional, TYPE_CHECKING, Union
from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship, Session, sessionmaker
from sqlalchemy.orm import relationship
from sqlalchemy.orm.exc import NoResultFound
from sqlalchemy.orm.mapper import Mapper
from superset.extensions import db
from superset.models.helpers import AuditMixinNullable
if TYPE_CHECKING:
@ -34,8 +35,6 @@ if TYPE_CHECKING:
from superset.models.slice import Slice # pylint: disable=unused-import
from superset.models.sql_lab import Query # pylint: disable=unused-import
Session = sessionmaker(autoflush=False)
class TagTypes(enum.Enum):
@ -88,13 +87,13 @@ class TaggedObject(Model, AuditMixinNullable):
tag = relationship("Tag", backref="objects")
def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
def get_tag(name: str, type_: TagTypes) -> Tag:
try:
tag = session.query(Tag).filter_by(name=name, type=type_).one()
tag = db.session.query(Tag).filter_by(name=name, type=type_).one()
except NoResultFound:
tag = Tag(name=name, type=type_)
session.add(tag)
session.commit()
db.session.add(tag)
db.session.commit()
return tag
@ -122,52 +121,43 @@ class ObjectUpdater:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
def _add_owners(
cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
) -> None:
def _add_owners(cls, target: Union["Dashboard", "FavStar", "Slice"]) -> None:
for owner_id in cls.get_owners_ids(target):
name = "owner:{0}".format(owner_id)
tag = get_tag(name, session, TagTypes.owner)
tag = get_tag(name, TagTypes.owner)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
db.session.add(tagged_object)
@classmethod
def after_insert(
def after_insert( # pylint: disable=unused-argument
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
# add `owner:` tags
cls._add_owners(session, target)
cls._add_owners(target)
# add `type:` tags
tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type)
tag = get_tag("type:{0}".format(cls.object_type), TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
db.session.add(tagged_object)
db.session.commit()
@classmethod
def after_update(
def after_update( # pylint: disable=unused-argument
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
db.session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
@ -176,32 +166,28 @@ class ObjectUpdater:
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
cls._add_owners(target)
db.session.commit()
@classmethod
def after_delete(
def after_delete( # pylint: disable=unused-argument
cls,
mapper: Mapper,
connection: Connection,
target: Union["Dashboard", "FavStar", "Slice"],
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
db.session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
session.commit()
db.session.commit()
class ChartUpdater(ObjectUpdater):
@ -233,31 +219,26 @@ class QueryUpdater(ObjectUpdater):
class FavStarUpdater:
@classmethod
def after_insert(
def after_insert( # pylint: disable=unused-argument
cls, mapper: Mapper, connection: Connection, target: "FavStar"
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
tag = get_tag(name, session, TagTypes.favorited_by)
tag = get_tag(name, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
db.session.add(tagged_object)
db.session.commit()
@classmethod
def after_delete(
cls, mapper: Mapper, connection: Connection, target: "FavStar"
def after_delete( # pylint: disable=unused-argument
cls, mapper: Mapper, connection: Connection, target: "FavStar",
) -> None:
# pylint: disable=unused-argument
session = Session(bind=connection)
name = "favorited_by:{0}".format(target.user_id)
query = (
session.query(TaggedObject.id)
db.session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
@ -266,8 +247,8 @@ class FavStarUpdater:
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
session.commit()
db.session.commit()

View File

@ -507,7 +507,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
self.get_session, database, user_perms, schema_perms
database, user_perms, schema_perms
)
if schema:
names = {d.table_name for d in user_datasources if d.schema == schema}
@ -568,7 +568,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.add_permission_view_menu(view_menu, perm)
logger.info("Creating missing datasource permissions.")
datasources = ConnectorRegistry.get_all_datasources(self.get_session)
datasources = ConnectorRegistry.get_all_datasources()
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())
@ -901,7 +901,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if not (schema_perm and self.can_access("schema_access", schema_perm)):
datasources = SqlaTable.query_datasources_by_name(
self.get_session, database, table_.table, schema=table_.schema
database, table_.table, schema=table_.schema
)
# Access to any datasource is suffice.

View File

@ -132,8 +132,7 @@ def session_scope(nullpool: bool) -> Iterator[Session]:
)
if nullpool:
engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
session_class = sessionmaker()
session_class.configure(bind=engine)
session_class = sessionmaker(bind=engine)
session = session_class()
else:
session = db.session()

View File

@ -134,8 +134,7 @@ class DummyStrategy(Strategy):
name = "dummy"
def get_urls(self) -> List[str]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
charts = db.session.query(Slice).all()
return [get_url(chart) for chart in charts]
@ -167,10 +166,9 @@ class TopNDashboardsStrategy(Strategy):
def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
@ -178,7 +176,9 @@ class TopNDashboardsStrategy(Strategy):
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
dashboards = (
db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
@ -211,14 +211,13 @@ class DashboardTagsStrategy(Strategy):
def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
@ -228,14 +227,16 @@ class DashboardTagsStrategy(Strategy):
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
tagged_dashboards = db.session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
urls.append(get_url(chart))
# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
db.session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
@ -245,7 +246,7 @@ class DashboardTagsStrategy(Strategy):
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
urls.append(get_url(chart))

View File

@ -48,7 +48,6 @@ from flask_login import login_user
from retry.api import retry_call
from selenium.common.exceptions import WebDriverException
from selenium.webdriver import chrome, firefox
from sqlalchemy.orm import Session
from werkzeug.http import parse_cookie
from superset import app, db, security_manager, thumbnail_cache
@ -543,8 +542,7 @@ def schedule_alert_query( # pylint: disable=unused-argument
is_test_alert: Optional[bool] = False,
) -> None:
model_cls = get_scheduler_model(report_type)
dbsession = db.create_scoped_session()
schedule = dbsession.query(model_cls).get(schedule_id)
schedule = db.session.query(model_cls).get(schedule_id)
# The user may have disabled the schedule. If so, ignore this
if not schedule or not schedule.active:
@ -556,7 +554,7 @@ def schedule_alert_query( # pylint: disable=unused-argument
deliver_alert(schedule, recipients)
return
if run_alert_query(schedule, dbsession):
if run_alert_query(schedule):
# deliver_dashboard OR deliver_slice
return
else:
@ -614,7 +612,7 @@ def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None:
_deliver_email(recipients, deliver_as_group, subject, body, data, images)
def run_alert_query(alert: Alert, dbsession: Session) -> Optional[bool]:
def run_alert_query(alert: Alert) -> Optional[bool]:
"""
Execute alert.sql and return value if any rows are returned
"""
@ -666,7 +664,7 @@ def run_alert_query(alert: Alert, dbsession: Session) -> Optional[bool]:
state=state,
)
)
dbsession.commit()
db.session.commit()
return None
@ -706,8 +704,7 @@ def schedule_window(
if not model_cls:
return None
dbsession = db.create_scoped_session()
schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
schedules = db.session.query(model_cls).filter(model_cls.active.is_(True))
for schedule in schedules:
logging.info("Processing schedule %s", schedule)

View File

@ -22,10 +22,10 @@ from io import BytesIO
from typing import Any, Dict, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.exceptions import DashboardImportException
from superset.extensions import db
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
@ -71,7 +71,6 @@ def decode_dashboards( # pylint: disable=too-many-return-statements
def import_dashboards(
session: Session,
data_stream: BytesIO,
database_id: Optional[int] = None,
import_time: Optional[int] = None,
@ -84,16 +83,16 @@ def import_dashboards(
raise DashboardImportException(_("No data in file"))
for table in data["datasources"]:
type(table).import_obj(table, database_id, import_time=import_time)
session.commit()
db.session.commit()
for dashboard in data["dashboards"]:
Dashboard.import_obj(dashboard, import_time=import_time)
session.commit()
db.session.commit()
def export_dashboards(session: Session) -> str:
def export_dashboards() -> str:
"""Returns all dashboards metadata as a json dump"""
logger.info("Starting export")
dashboards = session.query(Dashboard)
dashboards = db.session.query(Dashboard)
dashboard_ids = []
for dashboard in dashboards:
dashboard_ids.append(dashboard.id)

View File

@ -17,9 +17,8 @@
import logging
from typing import Any, Dict, List, Optional
from sqlalchemy.orm import Session
from superset.connectors.druid.models import DruidCluster
from superset.extensions import db
from superset.models.core import Database
DATABASES_KEY = "databases"
@ -44,11 +43,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
def export_to_dict(
session: Session, recursive: bool, back_references: bool, include_defaults: bool
recursive: bool, back_references: bool, include_defaults: bool
) -> Dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
dbs = session.query(Database)
dbs = db.session.query(Database)
databases = [
database.export_to_dict(
recursive=recursive,
@ -58,7 +57,7 @@ def export_to_dict(
for database in dbs
]
logger.info("Exported %d %s", len(databases), DATABASES_KEY)
cls = session.query(DruidCluster)
cls = db.session.query(DruidCluster)
clusters = [
cluster.export_to_dict(
recursive=recursive,
@ -76,22 +75,20 @@ def export_to_dict(
return data
def import_from_dict(
session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
) -> None:
def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
"""Imports databases and druid clusters from dictionary"""
if not sync:
sync = []
if isinstance(data, dict):
logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
for database in data.get(DATABASES_KEY, []):
Database.import_from_dict(session, database, sync=sync)
Database.import_from_dict(database, sync=sync)
logger.info(
"Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
)
for datasource in data.get(DRUID_CLUSTERS_KEY, []):
DruidCluster.import_from_dict(session, datasource, sync=sync)
session.commit()
DruidCluster.import_from_dict(datasource, sync=sync)
db.session.commit()
else:
logger.info("Supplied object is not a dictionary.")

View File

@ -18,14 +18,14 @@ import logging
from typing import Callable, Optional
from flask_appbuilder import Model
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import make_transient
from superset.extensions import db
logger = logging.getLogger(__name__)
def import_datasource( # pylint: disable=too-many-arguments
session: Session,
i_datasource: Model,
lookup_database: Callable[[Model], Model],
lookup_datasource: Callable[[Model], Model],
@ -52,11 +52,11 @@ def import_datasource( # pylint: disable=too-many-arguments
if datasource:
datasource.override(i_datasource)
session.flush()
db.session.flush()
else:
datasource = i_datasource.copy()
session.add(datasource)
session.flush()
db.session.add(datasource)
db.session.flush()
for metric in i_datasource.metrics:
new_m = metric.copy()
@ -81,13 +81,11 @@ def import_datasource( # pylint: disable=too-many-arguments
imported_c = i_datasource.column_class.import_obj(new_c)
if imported_c.column_name not in [c.column_name for c in datasource.columns]:
datasource.columns.append(imported_c)
session.flush()
db.session.flush()
return datasource.id
def import_simple_obj(
session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
) -> Model:
def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
make_transient(i_obj)
i_obj.id = None
i_obj.table = None
@ -97,9 +95,9 @@ def import_simple_obj(
i_obj.table = None
if existing_column:
existing_column.override(i_obj)
session.flush()
db.session.flush()
return existing_column
session.add(i_obj)
session.flush()
db.session.add(i_obj)
db.session.flush()
return i_obj

View File

@ -487,8 +487,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
roles = [r.name for r in get_user_roles()]
if "Admin" in roles:
return True
scoped_session = db.create_scoped_session()
orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
orig_obj = db.session.query(obj.__class__).filter_by(id=obj.id).first()
# Making a list of owners that works across ORM models
owners: List[User] = []

View File

@ -20,7 +20,7 @@ from flask_appbuilder import expose, has_access
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import lazy_gettext as _
from superset import app, db
from superset import app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import RouteMethod
from superset.models.slice import Slice
@ -56,7 +56,7 @@ class SliceModelView(
def add(self) -> FlaskResponse:
datasources = [
{"value": str(d.id) + "__" + d.type, "label": repr(d)}
for d in ConnectorRegistry.get_all_datasources(db.session)
for d in ConnectorRegistry.get_all_datasources()
]
return self.render_template(
"superset/add_slice.html",

View File

@ -40,7 +40,6 @@ from sqlalchemy.exc import (
OperationalError,
SQLAlchemyError,
)
from sqlalchemy.orm.session import Session
from werkzeug.urls import Href
import superset.models.core as models
@ -164,7 +163,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
sorted(
[
datasource.short_data
for datasource in ConnectorRegistry.get_all_datasources(db.session)
for datasource in ConnectorRegistry.get_all_datasources()
if datasource.short_data.get("name")
],
key=lambda datasource: datasource["name"],
@ -203,7 +202,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
db_ds_names.add(fullname)
existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
existing_datasources = ConnectorRegistry.get_all_datasources()
datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
role = security_manager.find_role(role_name)
# remove all permissions
@ -270,15 +269,15 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@has_access
@expose("/approve")
def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use
def clean_fulfilled_requests(session: Session) -> None:
for dar in session.query(DAR).all():
def clean_fulfilled_requests() -> None:
for dar in db.session.query(DAR).all():
datasource = ConnectorRegistry.get_datasource(
dar.datasource_type, dar.datasource_id, session
dar.datasource_type, dar.datasource_id
)
if not datasource or security_manager.can_access_datasource(datasource):
# datasource does not exist anymore
session.delete(dar)
session.commit()
db.session.delete(dar)
db.session.commit()
datasource_type = request.args["datasource_type"]
datasource_id = request.args["datasource_id"]
@ -286,10 +285,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
role_to_grant = request.args.get("role_to_grant")
role_to_extend = request.args.get("role_to_extend")
session = db.session
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, session
)
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
if not datasource:
flash(DATASOURCE_MISSING_ERR, "alert")
@ -301,7 +297,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_error_response(USER_MISSING_ERR)
requests = (
session.query(DAR)
db.session.query(DAR)
.filter(
DAR.datasource_id == datasource_id,
DAR.datasource_type == datasource_type,
@ -361,13 +357,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
app.config,
)
flash(msg, "info")
clean_fulfilled_requests(session)
clean_fulfilled_requests()
else:
flash(__("You have no permission to approve this request"), "danger")
return redirect("/accessrequestsmodelview/list/")
for request_ in requests:
session.delete(request_)
session.commit()
db.session.delete(request_)
db.session.commit()
return redirect("/accessrequestsmodelview/list/")
@has_access
@ -548,7 +544,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
database_id = request.form.get("db_id")
try:
dashboard_import_export.import_dashboards(
db.session, import_file.stream, database_id
import_file.stream, database_id
)
success = True
except DatabaseNotFound as ex:
@ -630,7 +626,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return redirect(error_redirect)
datasource = ConnectorRegistry.get_datasource(
cast(str, datasource_type), datasource_id, db.session
cast(str, datasource_type), datasource_id
)
if not datasource:
flash(DATASOURCE_MISSING_ERR, "danger")
@ -749,9 +745,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
:raises SupersetSecurityException: If the user cannot access the resource
"""
# TODO: Cache endpoint by user, datasource and column
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
if not datasource:
return json_error_response(DATASOURCE_MISSING_ERR)
@ -1015,10 +1009,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id: int
) -> FlaskResponse:
"""Copy dashboard"""
session = db.session()
data = json.loads(request.form["data"])
dash = models.Dashboard()
original_dash = session.query(Dashboard).get(dashboard_id)
original_dash = db.session.query(Dashboard).get(dashboard_id)
dash.owners = [g.user] if g.user else []
dash.dashboard_title = data["dashboard_title"]
@ -1029,8 +1022,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
for slc in original_dash.slices:
new_slice = slc.clone()
new_slice.owners = [g.user] if g.user else []
session.add(new_slice)
session.flush()
db.session.add(new_slice)
db.session.flush()
new_slice.dashboards.append(dash)
old_to_new_slice_ids[slc.id] = new_slice.id
@ -1046,10 +1039,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
dash.params = original_dash.params
DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids)
session.add(dash)
session.commit()
db.session.add(dash)
db.session.commit()
dash_json = json.dumps(dash.data)
session.close()
return json_success(dash_json)
@api
@ -1059,14 +1051,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id: int
) -> FlaskResponse:
"""Save a dashboard's metadata"""
session = db.session()
dash = session.query(Dashboard).get(dashboard_id)
dash = db.session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
data = json.loads(request.form["data"])
DashboardDAO.set_dash_metadata(dash, data)
session.merge(dash)
session.commit()
session.close()
db.session.merge(dash)
db.session.commit()
return json_success(json.dumps({"status": "SUCCESS"}))
@api
@ -1077,14 +1067,12 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
) -> FlaskResponse:
"""Add and save slices to a dashboard"""
data = json.loads(request.form["data"])
session = db.session()
dash = session.query(Dashboard).get(dashboard_id)
dash = db.session.query(Dashboard).get(dashboard_id)
check_ownership(dash, raise_if_false=True)
new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
new_slices = db.session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
dash.slices += new_slices
session.merge(dash)
session.commit()
session.close()
db.session.merge(dash)
db.session.commit()
return "SLICES ADDED"
@api
@ -1431,7 +1419,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
Note for slices a force refresh occurs.
"""
session = db.session()
slice_id = request.args.get("slice_id")
dashboard_id = request.args.get("dashboard_id")
table_name = request.args.get("table_name")
@ -1446,14 +1433,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
status=400,
)
if slice_id:
slices = session.query(Slice).filter_by(id=slice_id).all()
slices = db.session.query(Slice).filter_by(id=slice_id).all()
if not slices:
return json_error_response(
__("Chart %(id)s not found", id=slice_id), status=404
)
elif table_name and db_name:
table = (
session.query(SqlaTable)
db.session.query(SqlaTable)
.join(models.Database)
.filter(
models.Database.database_name == db_name
@ -1470,7 +1457,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
status=404,
)
slices = (
session.query(Slice)
db.session.query(Slice)
.filter_by(datasource_id=table.id, datasource_type=table.type)
.all()
)
@ -1513,17 +1500,16 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, class_name: str, obj_id: int, action: str
) -> FlaskResponse:
"""Toggle favorite stars on Slices and Dashboard"""
session = db.session()
FavStar = models.FavStar
count = 0
favs = (
session.query(FavStar)
db.session.query(FavStar)
.filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
.all()
)
if action == "select":
if not favs:
session.add(
db.session.add(
FavStar(
class_name=class_name,
obj_id=obj_id,
@ -1534,10 +1520,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
count = 1
elif action == "unselect":
for fav in favs:
session.delete(fav)
db.session.delete(fav)
else:
count = len(favs)
session.commit()
db.session.commit()
return json_success(json.dumps({"count": count}))
@api
@ -1550,12 +1536,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
logger.warning(
"This API endpoint is deprecated and will be removed in version 1.0.0"
)
session = db.session()
Role = ab_models.Role
dash = (
session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none()
db.session.query(Dashboard)
.filter(Dashboard.id == dashboard_id)
.one_or_none()
)
admin_role = session.query(Role).filter(Role.name == "Admin").one_or_none()
admin_role = db.session.query(Role).filter(Role.name == "Admin").one_or_none()
if request.method == "GET":
if dash:
@ -1574,7 +1561,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
dash.published = str(request.form["published"]).lower() == "true"
session.commit()
db.session.commit()
return json_success(json.dumps({"published": dash.published}))
@has_access
@ -1583,8 +1570,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self, dashboard_id_or_slug: str
) -> FlaskResponse:
"""Server side rendering for a dashboard"""
session = db.session()
qry = session.query(Dashboard)
qry = db.session.query(Dashboard)
if dashboard_id_or_slug.isdigit():
qry = qry.filter_by(id=int(dashboard_id_or_slug))
else:
@ -2042,8 +2028,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"SQL validation does not support template parameters", status=400
)
session = db.session()
mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
mydb = db.session.query(models.Database).filter_by(id=database_id).one_or_none()
if not mydb:
return json_error_response(
"Database with id {} is missing.".format(database_id), status=400
@ -2092,7 +2077,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod
def _sql_json_async( # pylint: disable=too-many-arguments
session: Session,
rendered_query: str,
query: Query,
expand_data: bool,
@ -2101,7 +2085,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
Send SQL JSON query to celery workers.
:param session: SQLAlchemy session object
:param rendered_query: the rendered query to perform by workers
:param query: The query (SQLAlchemy) object
:return: A Flask Response
@ -2132,7 +2115,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
query.status = QueryStatus.FAILED
query.error_message = msg
session.commit()
db.session.commit()
return json_error_response("{}".format(msg))
resp = json_success(
json.dumps(
@ -2142,12 +2125,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
),
status=202,
)
session.commit()
db.session.commit()
return resp
@staticmethod
def _sql_json_sync(
_session: Session,
rendered_query: str,
query: Query,
expand_data: bool,
@ -2241,8 +2223,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
tab_name: str = cast(str, query_params.get("tab"))
status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
session = db.session()
mydb = session.query(models.Database).get(database_id)
mydb = db.session.query(models.Database).get(database_id)
if not mydb:
return json_error_response("Database with id %i is missing.", database_id)
@ -2273,13 +2254,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
client_id=client_id,
)
try:
session.add(query)
session.flush()
db.session.add(query)
db.session.flush()
query_id = query.id
session.commit() # shouldn't be necessary
db.session.commit() # shouldn't be necessary
except SQLAlchemyError as ex:
logger.error("Errors saving query details %s", str(ex))
session.rollback()
db.session.rollback()
raise Exception(_("Query record was not created as expected."))
if not query_id:
raise Exception(_("Query record was not created as expected."))
@ -2290,7 +2271,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
query.raise_for_access()
except SupersetSecurityException as ex:
query.status = QueryStatus.FAILED
session.commit()
db.session.commit()
return json_errors_response([ex.error], status=403)
try:
@ -2323,13 +2304,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# Async request.
if async_flag:
return self._sql_json_async(
session, rendered_query, query, expand_data, log_params
)
return self._sql_json_async(rendered_query, query, expand_data, log_params)
# Sync request.
return self._sql_json_sync(
session, rendered_query, query, expand_data, log_params
)
return self._sql_json_sync(rendered_query, query, expand_data, log_params)
@has_access
@expose("/csv/<client_id>")
@ -2398,9 +2375,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
datasource_id, datasource_type = request.args["datasourceKey"].split("__")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
# Check if datasource exists
if not datasource:
return json_error_response(DATASOURCE_MISSING_ERR)

View File

@ -47,7 +47,7 @@ class Datasource(BaseSupersetView):
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource_type, datasource_id
)
orm_datasource.database_id = database_id
@ -82,7 +82,7 @@ class Datasource(BaseSupersetView):
def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
try:
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource_type, datasource_id
)
if not orm_datasource.data:
return json_error_response(
@ -102,7 +102,7 @@ class Datasource(BaseSupersetView):
"""Gets column info from the source system"""
if datasource_type == "druid":
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource_type, datasource_id
)
elif datasource_type == "table":
database = (

View File

@ -105,9 +105,7 @@ def get_viz(
form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
return viz_obj
@ -293,8 +291,7 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
def get_dashboard_extra_filters(
slice_id: int, dashboard_id: int
) -> List[Dict[str, Any]]:
session = db.session()
dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
# is chart in this dashboard?
if (

View File

@ -71,13 +71,17 @@ DB_ACCESS_ROLE = "db_access_role"
SCHEMA_ACCESS_ROLE = "schema_access_role"
def create_access_request(session, ds_type, ds_name, role_name, user_name):
def create_access_request(ds_type, ds_name, role_name, user_name):
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == "table":
ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
ds = db.session.query(ds_class).filter(ds_class.table_name == ds_name).first()
else:
ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
ds = (
db.session.query(ds_class)
.filter(ds_class.datasource_name == ds_name)
.first()
)
ds_perm_view = security_manager.find_permission_view_menu(
"datasource_access", ds.perm
)
@ -89,8 +93,8 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
datasource_type=ds_type,
created_by_fk=security_manager.find_user(username=user_name).id,
)
session.add(access_request)
session.commit()
db.session.add(access_request)
db.session.commit()
return access_request
@ -126,7 +130,6 @@ class TestRequestAccess(SupersetTestCase):
override_me = security_manager.find_role("override_me")
override_me.permissions = []
db.session.commit()
db.session.close()
def test_override_role_permissions_is_admin_only(self):
self.logout()
@ -211,7 +214,6 @@ class TestRequestAccess(SupersetTestCase):
)
def test_clean_requests_after_role_extend(self):
session = db.session
# Case 1. Gamma and gamma2 requested test_role1 on energy_usage access
# Gamma already has role test_role1
@ -221,12 +223,10 @@ class TestRequestAccess(SupersetTestCase):
# gamma2 and gamma request table_role on energy usage
if app.config["ENABLE_ACCESS_REQUEST"]:
access_request1 = create_access_request(
session, "table", "random_time_series", TEST_ROLE_1, "gamma2"
"table", "random_time_series", TEST_ROLE_1, "gamma2"
)
ds_1_id = access_request1.datasource_id
create_access_request(
session, "table", "random_time_series", TEST_ROLE_1, "gamma"
)
create_access_request("table", "random_time_series", TEST_ROLE_1, "gamma")
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
# gamma gets test_role1
@ -244,22 +244,20 @@ class TestRequestAccess(SupersetTestCase):
gamma_user.roles.remove(security_manager.find_role("test_role1"))
def test_clean_requests_after_alpha_grant(self):
session = db.session
# Case 2. Two access requests from gamma and gamma2
# Gamma becomes alpha, gamma2 gets granted
# Check if request by gamma has been deleted
access_request1 = create_access_request(
session, "table", "birth_names", TEST_ROLE_1, "gamma"
"table", "birth_names", TEST_ROLE_1, "gamma"
)
create_access_request(session, "table", "birth_names", TEST_ROLE_2, "gamma2")
create_access_request("table", "birth_names", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
# gamma becomes alpha
alpha_role = security_manager.find_role("Alpha")
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(alpha_role)
session.commit()
db.session.commit()
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
self.client.get(
@ -270,23 +268,21 @@ class TestRequestAccess(SupersetTestCase):
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("Alpha"))
session.commit()
db.session.commit()
def test_clean_requests_after_db_grant(self):
session = db.session
# Case 3. Two access requests from gamma and gamma2
# Gamma gets database access, gamma2 access request granted
# Check if request by gamma has been deleted
gamma_user = security_manager.find_user(username="gamma")
access_request1 = create_access_request(
session, "table", "energy_usage", TEST_ROLE_1, "gamma"
"table", "energy_usage", TEST_ROLE_1, "gamma"
)
create_access_request(session, "table", "energy_usage", TEST_ROLE_2, "gamma2")
create_access_request("table", "energy_usage", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
# gamma gets granted database access
database = session.query(models.Database).first()
database = db.session.query(models.Database).first()
security_manager.add_permission_view_menu("database_access", database.perm)
ds_perm_view = security_manager.find_permission_view_menu(
@ -296,7 +292,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view
)
gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE))
session.commit()
db.session.commit()
access_requests = self.get_access_requests("gamma", "table", ds_1_id)
self.assertTrue(access_requests)
# gamma2 request gets fulfilled
@ -308,25 +304,21 @@ class TestRequestAccess(SupersetTestCase):
self.assertFalse(access_requests)
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE))
session.commit()
db.session.commit()
def test_clean_requests_after_schema_grant(self):
session = db.session
# Case 4. Two access requests from gamma and gamma2
# Gamma gets schema access, gamma2 access request granted
# Check if request by gamma has been deleted
gamma_user = security_manager.find_user(username="gamma")
access_request1 = create_access_request(
session, "table", "wb_health_population", TEST_ROLE_1, "gamma"
)
create_access_request(
session, "table", "wb_health_population", TEST_ROLE_2, "gamma2"
"table", "wb_health_population", TEST_ROLE_1, "gamma"
)
create_access_request("table", "wb_health_population", TEST_ROLE_2, "gamma2")
ds_1_id = access_request1.datasource_id
ds = (
session.query(SqlaTable)
db.session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
@ -340,7 +332,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view
)
gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
session.commit()
db.session.commit()
# gamma2 request gets fulfilled
self.client.get(
EXTEND_ROLE_REQUEST.format("table", ds_1_id, "gamma2", TEST_ROLE_2)
@ -351,25 +343,24 @@ class TestRequestAccess(SupersetTestCase):
gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE))
ds = (
session.query(SqlaTable)
db.session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
ds.schema = None
session.commit()
db.session.commit()
@mock.patch("superset.utils.core.send_mime_email")
def test_approve(self, mock_send_mime):
if app.config["ENABLE_ACCESS_REQUEST"]:
session = db.session
TEST_ROLE_NAME = "table_role"
security_manager.add_role(TEST_ROLE_NAME)
# Case 1. Grant new role to the user.
access_request1 = create_access_request(
session, "table", "unicode_test", TEST_ROLE_NAME, "gamma"
"table", "unicode_test", TEST_ROLE_NAME, "gamma"
)
ds_1_id = access_request1.datasource_id
self.get_resp(
@ -404,7 +395,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 2. Extend the role to have access to the table
access_request2 = create_access_request(
session, "table", "energy_usage", TEST_ROLE_NAME, "gamma"
"table", "energy_usage", TEST_ROLE_NAME, "gamma"
)
ds_2_id = access_request2.datasource_id
energy_usage_perm = access_request2.datasource.perm
@ -448,7 +439,7 @@ class TestRequestAccess(SupersetTestCase):
security_manager.add_role("druid_role")
access_request3 = create_access_request(
session, "druid", "druid_ds_1", "druid_role", "gamma"
"druid", "druid_ds_1", "druid_role", "gamma"
)
self.get_resp(
GRANT_ROLE_REQUEST.format(
@ -463,7 +454,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 4. Extend the role to have access to the druid datasource
access_request4 = create_access_request(
session, "druid", "druid_ds_2", "druid_role", "gamma"
"druid", "druid_ds_2", "druid_role", "gamma"
)
druid_ds_2_perm = access_request4.datasource.perm
@ -483,19 +474,18 @@ class TestRequestAccess(SupersetTestCase):
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("druid_role"))
gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
session.delete(security_manager.find_role("druid_role"))
session.delete(security_manager.find_role(TEST_ROLE_NAME))
session.commit()
db.session.delete(security_manager.find_role("druid_role"))
db.session.delete(security_manager.find_role(TEST_ROLE_NAME))
db.session.commit()
def test_request_access(self):
if app.config["ENABLE_ACCESS_REQUEST"]:
session = db.session
self.logout()
self.login(username="gamma")
gamma_user = security_manager.find_user(username="gamma")
security_manager.add_role("dummy_role")
gamma_user.roles.append(security_manager.find_role("dummy_role"))
session.commit()
db.session.commit()
ACCESS_REQUEST = (
"/superset/request_access?"
@ -511,7 +501,7 @@ class TestRequestAccess(SupersetTestCase):
# Request table access, there are no roles have this table.
table1 = (
session.query(SqlaTable)
db.session.query(SqlaTable)
.filter_by(table_name="random_time_series")
.first()
)
@ -526,7 +516,7 @@ class TestRequestAccess(SupersetTestCase):
# Request access, roles exist that contains the table.
# add table to the existing roles
table3 = (
session.query(SqlaTable).filter_by(table_name="energy_usage").first()
db.session.query(SqlaTable).filter_by(table_name="energy_usage").first()
)
table_3_id = table3.id
table3_perm = table3.perm
@ -545,7 +535,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", table3_perm
),
)
session.commit()
db.session.commit()
self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go"))
access_request3 = self.get_access_requests("gamma", "table", table_3_id)
@ -559,7 +549,7 @@ class TestRequestAccess(SupersetTestCase):
# Request druid access, there are no roles have this table.
druid_ds_4 = (
session.query(DruidDatasource)
db.session.query(DruidDatasource)
.filter_by(datasource_name="druid_ds_1")
.first()
)
@ -574,7 +564,7 @@ class TestRequestAccess(SupersetTestCase):
# Case 5. Roles exist that contains the druid datasource.
# add druid ds to the existing roles
druid_ds_5 = (
session.query(DruidDatasource)
db.session.query(DruidDatasource)
.filter_by(datasource_name="druid_ds_2")
.first()
)
@ -595,7 +585,7 @@ class TestRequestAccess(SupersetTestCase):
"datasource_access", druid_ds_5_perm
),
)
session.commit()
db.session.commit()
self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go"))
access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id)
@ -610,7 +600,7 @@ class TestRequestAccess(SupersetTestCase):
# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("dummy_role"))
session.commit()
db.session.commit()
if __name__ == "__main__":

View File

@ -32,112 +32,118 @@ logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@pytest.yield_fixture(scope="module")
def setup_database():
def setup_module():
with app.app_context():
slice_id = db.session.query(Slice).all()[0].id
database_id = utils.get_example_database().id
alert1 = Alert(
id=1,
label="alert_1",
active=True,
crontab="*/1 * * * *",
sql="SELECT 0",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
)
alert2 = Alert(
id=2,
label="alert_2",
active=True,
crontab="*/1 * * * *",
sql="SELECT 55",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
)
alert3 = Alert(
id=3,
label="alert_3",
active=False,
crontab="*/1 * * * *",
sql="UPDATE 55",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
)
alert4 = Alert(id=4, active=False, label="alert_4", database_id=-1)
alert5 = Alert(id=5, active=False, label="alert_5", database_id=database_id)
alerts = [
Alert(
id=1,
label="alert_1",
active=True,
crontab="*/1 * * * *",
sql="SELECT 0",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
),
Alert(
id=2,
label="alert_2",
active=True,
crontab="*/1 * * * *",
sql="SELECT 55",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
),
Alert(
id=3,
label="alert_3",
active=False,
crontab="*/1 * * * *",
sql="UPDATE 55",
alert_type="email",
slice_id=slice_id,
database_id=database_id,
),
Alert(id=4, active=False, label="alert_4", database_id=-1),
Alert(id=5, active=False, label="alert_5", database_id=database_id),
]
for num in range(1, 6):
eval(f"db.session.add(alert{num})")
db.session.bulk_save_objects(alerts)
db.session.commit()
yield db.session
def teardown_module():
with app.app_context():
db.session.query(AlertLog).delete()
db.session.query(Alert).delete()
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.logging.Logger.error")
def test_run_alert_query(mock_error, mock_deliver, setup_database):
database = setup_database
run_alert_query(database.query(Alert).filter_by(id=1).one(), database)
alert1 = database.query(Alert).filter_by(id=1).one()
assert mock_deliver.call_count == 0
assert len(alert1.logs) == 1
assert alert1.logs[0].alert_id == 1
assert alert1.logs[0].state == "pass"
def test_run_alert_query(mock_error, mock_deliver_alert):
with app.app_context():
run_alert_query(db.session.query(Alert).filter_by(id=1).one())
alert1 = db.session.query(Alert).filter_by(id=1).one()
assert mock_deliver_alert.call_count == 0
assert len(alert1.logs) == 1
assert alert1.logs[0].alert_id == 1
assert alert1.logs[0].state == "pass"
run_alert_query(database.query(Alert).filter_by(id=2).one(), database)
alert2 = database.query(Alert).filter_by(id=2).one()
assert mock_deliver.call_count == 1
assert len(alert2.logs) == 1
assert alert2.logs[0].alert_id == 2
assert alert2.logs[0].state == "trigger"
run_alert_query(db.session.query(Alert).filter_by(id=2).one())
alert2 = db.session.query(Alert).filter_by(id=2).one()
assert mock_deliver_alert.call_count == 1
assert len(alert2.logs) == 1
assert alert2.logs[0].alert_id == 2
assert alert2.logs[0].state == "trigger"
run_alert_query(database.query(Alert).filter_by(id=3).one(), database)
alert3 = database.query(Alert).filter_by(id=3).one()
assert mock_deliver.call_count == 1
assert mock_error.call_count == 2
assert len(alert3.logs) == 1
assert alert3.logs[0].alert_id == 3
assert alert3.logs[0].state == "error"
run_alert_query(db.session.query(Alert).filter_by(id=3).one())
alert3 = db.session.query(Alert).filter_by(id=3).one()
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 2
assert len(alert3.logs) == 1
assert alert3.logs[0].alert_id == 3
assert alert3.logs[0].state == "error"
run_alert_query(database.query(Alert).filter_by(id=4).one(), database)
assert mock_deliver.call_count == 1
assert mock_error.call_count == 3
run_alert_query(db.session.query(Alert).filter_by(id=4).one())
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 3
run_alert_query(database.query(Alert).filter_by(id=5).one(), database)
assert mock_deliver.call_count == 1
assert mock_error.call_count == 4
run_alert_query(db.session.query(Alert).filter_by(id=5).one())
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 4
@patch("superset.tasks.schedules.deliver_alert")
@patch("superset.tasks.schedules.run_alert_query")
def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database):
database = setup_database
active_alert = database.query(Alert).filter_by(id=1).one()
inactive_alert = database.query(Alert).filter_by(id=3).one()
def test_schedule_alert_query(mock_run_alert, mock_deliver_alert):
with app.app_context():
active_alert = db.session.query(Alert).filter_by(id=1).one()
inactive_alert = db.session.query(Alert).filter_by(id=3).one()
# Test that inactive alerts are no processed
schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id)
assert mock_run_alert.call_count == 0
assert mock_deliver_alert.call_count == 0
# Test that inactive alerts are no processed
schedule_alert_query(
report_type=ScheduleType.alert, schedule_id=inactive_alert.id
)
assert mock_run_alert.call_count == 0
assert mock_deliver_alert.call_count == 0
# Test that active alerts with no recipients passed in are processed regularly
schedule_alert_query(report_type=ScheduleType.alert, schedule_id=active_alert.id)
assert mock_run_alert.call_count == 1
assert mock_deliver_alert.call_count == 0
# Test that active alerts with no recipients passed in are processed regularly
schedule_alert_query(
report_type=ScheduleType.alert, schedule_id=active_alert.id
)
assert mock_run_alert.call_count == 1
assert mock_deliver_alert.call_count == 0
# Test that active alerts sent as a test are delivered immediately
schedule_alert_query(
report_type=ScheduleType.alert,
schedule_id=active_alert.id,
recipients="testing@email.com",
is_test_alert=True,
)
assert mock_run_alert.call_count == 1
assert mock_deliver_alert.call_count == 1
# Test that active alerts sent as a test are delivered immediately
schedule_alert_query(
report_type=ScheduleType.alert,
schedule_id=active_alert.id,
recipients="testing@email.com",
is_test_alert=True,
)
assert mock_run_alert.call_count == 1
assert mock_deliver_alert.call_count == 1

View File

@ -25,7 +25,6 @@ import pandas as pd
from flask import Response
from flask_appbuilder.security.sqla import models as ab_models
from flask_testing import TestCase
from sqlalchemy.orm import Session
from tests.test_app import app
from superset.sql_parse import CtasMethod
@ -103,24 +102,25 @@ class SupersetTestCase(TestCase):
# create druid cluster and druid datasources
with app.app_context():
session = db.session
cluster = (
session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
db.session.query(DruidCluster)
.filter_by(cluster_name="druid_test")
.first()
)
if not cluster:
cluster = DruidCluster(cluster_name="druid_test")
session.add(cluster)
session.commit()
db.session.add(cluster)
db.session.commit()
druid_datasource1 = DruidDatasource(
datasource_name="druid_ds_1", cluster=cluster
)
session.add(druid_datasource1)
db.session.add(druid_datasource1)
druid_datasource2 = DruidDatasource(
datasource_name="druid_ds_2", cluster=cluster
)
session.add(druid_datasource2)
session.commit()
db.session.add(druid_datasource2)
db.session.commit()
@staticmethod
def get_table_by_id(table_id: int) -> SqlaTable:
@ -134,25 +134,23 @@ class SupersetTestCase(TestCase):
except ImportError:
return False
def get_or_create(self, cls, criteria, session, **kwargs):
obj = session.query(cls).filter_by(**criteria).first()
def get_or_create(self, cls, criteria, **kwargs):
obj = db.session.query(cls).filter_by(**criteria).first()
if not obj:
obj = cls(**criteria)
obj.__dict__.update(**kwargs)
session.add(obj)
session.commit()
db.session.add(obj)
db.session.commit()
return obj
def login(self, username="admin", password="general"):
resp = self.get_resp("/login/", data=dict(username=username, password=password))
self.assertNotIn("User confirmation needed", resp)
def get_slice(
self, slice_name: str, session: Session, expunge_from_session: bool = True
) -> Slice:
slc = session.query(Slice).filter_by(slice_name=slice_name).one()
def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
if expunge_from_session:
session.expunge_all()
db.session.expunge_all()
return slc
@staticmethod
@ -301,7 +299,6 @@ class SupersetTestCase(TestCase):
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
sqlalchemy_uri="sqlite:///:memory:",
id=db_id,
extra=extra,
@ -323,7 +320,6 @@ class SupersetTestCase(TestCase):
return self.get_or_create(
cls=models.Database,
criteria={"database_name": database_name},
session=db.session,
sqlalchemy_uri="presto://user@host:8080/hive",
id=db_id,
)

View File

@ -97,15 +97,13 @@ CTAS_SCHEMA_NAME = "sqllab_test_db"
class TestCelery(SupersetTestCase):
def get_query_by_name(self, sql):
session = db.session
query = session.query(Query).filter_by(sql=sql).first()
session.close()
query = db.session.query(Query).filter_by(sql=sql).first()
db.session.close()
return query
def get_query_by_id(self, id):
session = db.session
query = session.query(Query).filter_by(id=id).first()
session.close()
query = db.session.query(Query).filter_by(id=id).first()
db.session.close()
return query
@classmethod

View File

@ -57,9 +57,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
for owner in owners:
user = db.session.query(security_manager.user_model).get(owner)
obj_owners.append(user)
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
)
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
slice = Slice(
slice_name=slice_name,
datasource_id=datasource.id,

View File

@ -99,7 +99,7 @@ class TestCore(SupersetTestCase):
def test_slice_endpoint(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
slc = self.get_slice("Girls")
resp = self.get_resp("/superset/slice/{}/".format(slc.id))
assert "Time Column" in resp
assert "List Roles" in resp
@ -113,7 +113,7 @@ class TestCore(SupersetTestCase):
def test_viz_cache_key(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
slc = self.get_slice("Girls")
viz = slc.viz
qobj = viz.query_obj()
@ -229,7 +229,7 @@ class TestCore(SupersetTestCase):
def test_save_slice(self):
self.login(username="admin")
slice_name = f"Energy Sankey"
slice_id = self.get_slice(slice_name, db.session).id
slice_id = self.get_slice(slice_name).id
copy_name_prefix = "Test Sankey"
copy_name = f"{copy_name_prefix}[save]{random.random()}"
tbl_id = self.table_ids.get("energy_usage")
@ -295,7 +295,7 @@ class TestCore(SupersetTestCase):
def test_filter_endpoint(self):
self.login(username="admin")
slice_name = "Energy Sankey"
slice_id = self.get_slice(slice_name, db.session).id
slice_id = self.get_slice(slice_name).id
db.session.commit()
tbl_id = self.table_ids.get("energy_usage")
table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)
@ -315,9 +315,7 @@ class TestCore(SupersetTestCase):
def test_slice_data(self):
# slice data should have some required attributes
self.login(username="admin")
slc = self.get_slice(
slice_name="Girls", session=db.session, expunge_from_session=False
)
slc = self.get_slice(slice_name="Girls", expunge_from_session=False)
slc_data_attributes = slc.data.keys()
assert "changed_on" in slc_data_attributes
assert "modified" in slc_data_attributes
@ -368,9 +366,7 @@ class TestCore(SupersetTestCase):
self.assertEqual(data, [])
# make user owner of slice and verify that endpoint returns said slice
slc = self.get_slice(
slice_name=slice_name, session=db.session, expunge_from_session=False
)
slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
slc.owners = [user]
db.session.merge(slc)
db.session.commit()
@ -381,9 +377,7 @@ class TestCore(SupersetTestCase):
self.assertEqual(data[0]["title"], slice_name)
# remove ownership and ensure user no longer gets slice
slc = self.get_slice(
slice_name=slice_name, session=db.session, expunge_from_session=False
)
slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
slc.owners = []
db.session.merge(slc)
db.session.commit()
@ -561,7 +555,7 @@ class TestCore(SupersetTestCase):
db.session.commit()
def test_warm_up_cache(self):
slc = self.get_slice("Girls", db.session)
slc = self.get_slice("Girls")
data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
@ -769,7 +763,7 @@ class TestCore(SupersetTestCase):
def test_user_profile(self, username="admin"):
self.login(username=username)
slc = self.get_slice("Girls", db.session)
slc = self.get_slice("Girls")
# Setting some faves
url = f"/superset/favstar/Slice/{slc.id}/select/"

View File

@ -178,12 +178,11 @@ class TestDatabaseApi(SupersetTestCase):
"""
Database API: Test get select star with datasource access
"""
session = db.session
table = SqlaTable(
schema="main", table_name="ab_permission", database=get_example_database()
)
session.add(table)
session.commit()
db.session.add(table)
db.session.commit()
tmp_table_perm = security_manager.find_permission_view_menu(
"datasource_access", table.get_perm()

View File

@ -156,7 +156,7 @@ class TestDatasetApi(SupersetTestCase):
"template_params": None,
}
for key, value in expected_result.items():
self.assertEqual(response["result"][key], expected_result[key])
self.assertEqual(response["result"][key], value)
self.assertEqual(len(response["result"]["columns"]), 8)
self.assertEqual(len(response["result"]["metrics"]), 2)
@ -717,10 +717,7 @@ class TestDatasetApi(SupersetTestCase):
)
cli_export = export_to_dict(
session=db.session,
recursive=True,
back_references=False,
include_defaults=False,
recursive=True, back_references=False, include_defaults=False,
)
cli_export_tables = cli_export["databases"][0]["tables"]
expected_response = []

View File

@ -47,14 +47,13 @@ class TestDictImportExport(SupersetTestCase):
def delete_imports(cls):
with app.app_context():
# Imported data clean up
session = db.session
for table in session.query(SqlaTable):
for table in db.session.query(SqlaTable):
if DBREF in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
db.session.delete(table)
for datasource in db.session.query(DruidDatasource):
if DBREF in datasource.params_dict:
session.delete(datasource)
session.commit()
db.session.delete(datasource)
db.session.commit()
@classmethod
def setUpClass(cls):
@ -90,9 +89,7 @@ class TestDictImportExport(SupersetTestCase):
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
cluster_name = "druid_test"
cluster = self.get_or_create(
DruidCluster, {"cluster_name": cluster_name}, db.session
)
cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
name = "{0}{1}".format(NAME_PREFIX, name)
params = {DBREF: id, "database_name": cluster_name}
@ -159,7 +156,7 @@ class TestDictImportExport(SupersetTestCase):
def test_import_table_no_metadata(self):
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
new_table = SqlaTable.import_from_dict(db.session, dict_table)
new_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported_id = new_table.id
imported = self.get_table_by_id(imported_id)
@ -173,7 +170,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["col1"],
metric_names=["metric1"],
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@ -189,7 +186,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
imported = self.get_table_by_id(imported_table.id)
self.assert_table_equals(table, imported)
@ -199,7 +196,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@ -207,7 +204,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
imported_over_table = SqlaTable.import_from_dict(dict_table_over)
db.session.commit()
imported_over = self.get_table_by_id(imported_over_table.id)
@ -227,7 +224,7 @@ class TestDictImportExport(SupersetTestCase):
table, dict_table = self.create_table(
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
table_over, dict_table_over = self.create_table(
"table_override",
@ -236,7 +233,7 @@ class TestDictImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
imported_over_table = SqlaTable.import_from_dict(
session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
dict_rep=dict_table_over, sync=["metrics", "columns"]
)
db.session.commit()
@ -260,7 +257,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
imported_table = SqlaTable.import_from_dict(dict_table)
db.session.commit()
copy_table, dict_copy_table = self.create_table(
"copy_cat",
@ -268,7 +265,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
db.session.commit()
self.assertEqual(imported_table.id, imported_copy_table.id)
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
@ -281,10 +278,7 @@ class TestDictImportExport(SupersetTestCase):
self.delete_fake_db()
cli_export = export_to_dict(
session=db.session,
recursive=True,
back_references=False,
include_defaults=False,
recursive=True, back_references=False, include_defaults=False,
)
self.get_resp("/login/", data=dict(username="admin", password="general"))
resp = self.get_resp(
@ -303,7 +297,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"pure_druid", id=ID_PREFIX + 1
)
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@ -315,7 +309,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["col1"],
metric_names=["metric1"],
)
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@ -331,7 +325,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["c1", "c2"],
metric_names=["m1", "m2"],
)
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
db.session.commit()
imported = self.get_datasource(imported_cluster.id)
self.assert_datasource_equals(datasource, imported)
@ -340,7 +334,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
db.session.commit()
table_over, table_over_dict = self.create_druid_datasource(
"druid_override",
@ -348,9 +342,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported_over_cluster = DruidDatasource.import_from_dict(
db.session, table_over_dict
)
imported_over_cluster = DruidDatasource.import_from_dict(table_over_dict)
db.session.commit()
imported_over = self.get_datasource(imported_over_cluster.id)
self.assertEqual(imported_cluster.id, imported_over.id)
@ -366,7 +358,7 @@ class TestDictImportExport(SupersetTestCase):
datasource, dict_datasource = self.create_druid_datasource(
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
)
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
db.session.commit()
table_over, table_over_dict = self.create_druid_datasource(
"druid_override",
@ -375,7 +367,7 @@ class TestDictImportExport(SupersetTestCase):
metric_names=["new_metric1"],
)
imported_over_cluster = DruidDatasource.import_from_dict(
session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"]
dict_rep=table_over_dict, sync=["metrics", "columns"]
) # syncing metrics and columns
db.session.commit()
imported_over = self.get_datasource(imported_over_cluster.id)
@ -395,9 +387,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported = DruidDatasource.import_from_dict(
session=db.session, dict_rep=dict_datasource
)
imported = DruidDatasource.import_from_dict(dict_rep=dict_datasource)
db.session.commit()
copy_datasource, dict_cp_datasource = self.create_druid_datasource(
"copy_cat",
@ -405,7 +395,7 @@ class TestDictImportExport(SupersetTestCase):
cols_names=["new_col1", "col2", "col3"],
metric_names=["new_metric1"],
)
imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource)
imported_copy = DruidDatasource.import_from_dict(dict_cp_datasource)
db.session.commit()
self.assertEqual(imported.id, imported_copy.id)

View File

@ -212,9 +212,7 @@ class TestDruid(SupersetTestCase):
def test_druid_sync_from_config(self):
CLUSTER_NAME = "new_druid"
self.login()
cluster = self.get_or_create(
DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
)
cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
db.session.merge(cluster)
db.session.commit()
@ -302,15 +300,12 @@ class TestDruid(SupersetTestCase):
@unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false")
def test_filter_druid_datasource(self):
CLUSTER_NAME = "new_druid"
cluster = self.get_or_create(
DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
)
cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
db.session.merge(cluster)
gamma_ds = self.get_or_create(
DruidDatasource,
{"datasource_name": "datasource_for_gamma", "cluster": cluster},
db.session,
)
gamma_ds.cluster = cluster
db.session.merge(gamma_ds)
@ -318,7 +313,6 @@ class TestDruid(SupersetTestCase):
no_gamma_ds = self.get_or_create(
DruidDatasource,
{"datasource_name": "datasource_not_for_gamma", "cluster": cluster},
db.session,
)
no_gamma_ds.cluster = cluster
db.session.merge(no_gamma_ds)

View File

@ -46,20 +46,19 @@ class TestImportExport(SupersetTestCase):
def delete_imports(cls):
with app.app_context():
# Imported data clean up
session = db.session
for slc in session.query(Slice):
for slc in db.session.query(Slice):
if "remote_id" in slc.params_dict:
session.delete(slc)
for dash in session.query(Dashboard):
db.session.delete(slc)
for dash in db.session.query(Dashboard):
if "remote_id" in dash.params_dict:
session.delete(dash)
for table in session.query(SqlaTable):
db.session.delete(dash)
for table in db.session.query(SqlaTable):
if "remote_id" in table.params_dict:
session.delete(table)
for datasource in session.query(DruidDatasource):
db.session.delete(table)
for datasource in db.session.query(DruidDatasource):
if "remote_id" in datasource.params_dict:
session.delete(datasource)
session.commit()
db.session.delete(datasource)
db.session.commit()
@classmethod
def setUpClass(cls):
@ -126,9 +125,7 @@ class TestImportExport(SupersetTestCase):
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
cluster_name = "druid_test"
cluster = self.get_or_create(
DruidCluster, {"cluster_name": cluster_name}, db.session
)
cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
params = {"remote_id": id, "database_name": cluster_name}
datasource = DruidDatasource(

View File

@ -83,7 +83,6 @@ class TestQueryContext(SupersetTestCase):
datasource = ConnectorRegistry.get_datasource(
datasource_type=payload["datasource"]["type"],
datasource_id=payload["datasource"]["id"],
session=db.session,
)
description_original = datasource.description
datasource.description = "temporary description"

View File

@ -69,9 +69,8 @@ class TestRolePermission(SupersetTestCase):
"""Testing export role permissions."""
def setUp(self):
session = db.session
security_manager.add_role(SCHEMA_ACCESS_ROLE)
session.commit()
db.session.commit()
ds = (
db.session.query(SqlaTable)
@ -82,7 +81,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema_perm = ds.get_schema_perm()
ds_slices = (
session.query(Slice)
db.session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_id=ds.id)
.all()
@ -92,12 +91,11 @@ class TestRolePermission(SupersetTestCase):
create_schema_perm("[examples].[temp_schema]")
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
session.commit()
db.session.commit()
def tearDown(self):
session = db.session
ds = (
session.query(SqlaTable)
db.session.query(SqlaTable)
.filter_by(table_name="wb_health_population")
.first()
)
@ -105,7 +103,7 @@ class TestRolePermission(SupersetTestCase):
ds.schema = None
ds.schema_perm = None
ds_slices = (
session.query(Slice)
db.session.query(Slice)
.filter_by(datasource_type="table")
.filter_by(datasource_id=ds.id)
.all()
@ -114,21 +112,20 @@ class TestRolePermission(SupersetTestCase):
s.schema_perm = None
delete_schema_perm(schema_perm)
session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
session.commit()
db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
db.session.commit()
def test_set_perm_sqla_table(self):
session = db.session
table = SqlaTable(
schema="tmp_schema",
table_name="tmp_perm_table",
database=get_example_database(),
)
session.add(table)
session.commit()
db.session.add(table)
db.session.commit()
stored_table = (
session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})"
@ -147,9 +144,9 @@ class TestRolePermission(SupersetTestCase):
# table name change
stored_table.table_name = "tmp_perm_table_v2"
session.commit()
db.session.commit()
stored_table = (
session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@ -169,9 +166,9 @@ class TestRolePermission(SupersetTestCase):
# schema name change
stored_table.schema = "tmp_schema_v2"
session.commit()
db.session.commit()
stored_table = (
session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"
@ -191,13 +188,13 @@ class TestRolePermission(SupersetTestCase):
# database change
new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
session.add(new_db)
db.session.add(new_db)
stored_table.database = (
session.query(Database).filter_by(database_name="tmp_db").one()
db.session.query(Database).filter_by(database_name="tmp_db").one()
)
session.commit()
db.session.commit()
stored_table = (
session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@ -217,9 +214,9 @@ class TestRolePermission(SupersetTestCase):
# no schema
stored_table.schema = None
session.commit()
db.session.commit()
stored_table = (
session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
)
self.assertEqual(
stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"
@ -231,26 +228,25 @@ class TestRolePermission(SupersetTestCase):
)
self.assertIsNone(stored_table.schema_perm)
session.delete(new_db)
session.delete(stored_table)
session.commit()
db.session.delete(new_db)
db.session.delete(stored_table)
db.session.commit()
def test_set_perm_druid_datasource(self):
session = db.session
druid_cluster = (
session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
db.session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
)
datasource = DruidDatasource(
datasource_name="tmp_datasource",
cluster=druid_cluster,
cluster_id=druid_cluster.id,
)
session.add(datasource)
session.commit()
db.session.add(datasource)
db.session.commit()
# store without a schema
stored_datasource = (
session.query(DruidDatasource)
db.session.query(DruidDatasource)
.filter_by(datasource_name="tmp_datasource")
.one()
)
@ -267,7 +263,7 @@ class TestRolePermission(SupersetTestCase):
# store with a schema
stored_datasource.datasource_name = "tmp_schema.tmp_datasource"
session.commit()
db.session.commit()
self.assertEqual(
stored_datasource.perm,
f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})",
@ -284,16 +280,15 @@ class TestRolePermission(SupersetTestCase):
)
)
session.delete(stored_datasource)
session.commit()
db.session.delete(stored_datasource)
db.session.commit()
def test_set_perm_druid_cluster(self):
session = db.session
cluster = DruidCluster(cluster_name="tmp_druid_cluster")
session.add(cluster)
db.session.add(cluster)
stored_cluster = (
session.query(DruidCluster)
db.session.query(DruidCluster)
.filter_by(cluster_name="tmp_druid_cluster")
.one()
)
@ -307,7 +302,7 @@ class TestRolePermission(SupersetTestCase):
)
stored_cluster.cluster_name = "tmp_druid_cluster2"
session.commit()
db.session.commit()
self.assertEqual(
stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})"
)
@ -317,18 +312,17 @@ class TestRolePermission(SupersetTestCase):
)
)
session.delete(stored_cluster)
session.commit()
db.session.delete(stored_cluster)
db.session.commit()
def test_set_perm_database(self):
session = db.session
database = Database(
database_name="tmp_database", sqlalchemy_uri="sqlite://test"
)
session.add(database)
db.session.add(database)
stored_db = (
session.query(Database).filter_by(database_name="tmp_database").one()
db.session.query(Database).filter_by(database_name="tmp_database").one()
)
self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
self.assertIsNotNone(
@ -338,9 +332,9 @@ class TestRolePermission(SupersetTestCase):
)
stored_db.database_name = "tmp_database2"
session.commit()
db.session.commit()
stored_db = (
session.query(Database).filter_by(database_name="tmp_database2").one()
db.session.query(Database).filter_by(database_name="tmp_database2").one()
)
self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})")
self.assertIsNotNone(
@ -349,8 +343,8 @@ class TestRolePermission(SupersetTestCase):
)
)
session.delete(stored_db)
session.commit()
db.session.delete(stored_db)
db.session.commit()
def test_hybrid_perm_druid_cluster(self):
cluster = DruidCluster(cluster_name="tmp_druid_cluster3")
@ -400,14 +394,13 @@ class TestRolePermission(SupersetTestCase):
db.session.commit()
def test_set_perm_slice(self):
session = db.session
database = Database(
database_name="tmp_database", sqlalchemy_uri="sqlite://test"
)
table = SqlaTable(table_name="tmp_perm_table", database=database)
session.add(database)
session.add(table)
session.commit()
db.session.add(database)
db.session.add(table)
db.session.commit()
# no schema permission
slice = Slice(
@ -416,10 +409,10 @@ class TestRolePermission(SupersetTestCase):
datasource_name="tmp_perm_table",
slice_name="slice_name",
)
session.add(slice)
session.commit()
db.session.add(slice)
db.session.commit()
slice = session.query(Slice).filter_by(slice_name="slice_name").one()
slice = db.session.query(Slice).filter_by(slice_name="slice_name").one()
self.assertEqual(slice.perm, table.perm)
self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
self.assertEqual(slice.schema_perm, table.schema_perm)
@ -427,7 +420,7 @@ class TestRolePermission(SupersetTestCase):
table.schema = "tmp_perm_schema"
table.table_name = "tmp_perm_table_v2"
session.commit()
db.session.commit()
# TODO(bogdan): modify slice permissions on the table update.
self.assertNotEquals(slice.perm, table.perm)
self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
@ -440,7 +433,7 @@ class TestRolePermission(SupersetTestCase):
# updating slice refreshes the permissions
slice.slice_name = "slice_name_v2"
session.commit()
db.session.commit()
self.assertEqual(slice.perm, table.perm)
self.assertEqual(
slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"
@ -448,11 +441,10 @@ class TestRolePermission(SupersetTestCase):
self.assertEqual(slice.schema_perm, table.schema_perm)
self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]")
session.delete(slice)
session.delete(table)
session.delete(database)
session.commit()
db.session.delete(slice)
db.session.delete(table)
db.session.delete(database)
db.session.commit()
# TODO test slice permission
@ -532,11 +524,11 @@ class TestRolePermission(SupersetTestCase):
self.assertNotIn("Girl Name Cloud", data) # birth_names slice, no access
def test_sqllab_gamma_user_schema_access_to_sqllab(self):
session = db.session
example_db = session.query(Database).filter_by(database_name="examples").one()
example_db = (
db.session.query(Database).filter_by(database_name="examples").one()
)
example_db.expose_in_sqllab = True
session.commit()
db.session.commit()
arguments = {
"keys": ["none"],
@ -959,12 +951,10 @@ class TestRowLevelSecurity(SupersetTestCase):
rls_entry = None
def setUp(self):
session = db.session
# Create the RowLevelSecurityFilter
self.rls_entry = RowLevelSecurityFilter()
self.rls_entry.tables.extend(
session.query(SqlaTable)
db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
.all()
)
@ -974,13 +964,11 @@ class TestRowLevelSecurity(SupersetTestCase):
) # db.session.query(Role).filter_by(name="Gamma").first())
self.rls_entry.roles.append(security_manager.find_role("Alpha"))
db.session.add(self.rls_entry)
db.session.commit()
def tearDown(self):
session = db.session
session.delete(self.rls_entry)
session.commit()
db.session.delete(self.rls_entry)
db.session.commit()
# Do another test to make sure it doesn't alter another query
def test_rls_filter_alters_query(self):

View File

@ -55,7 +55,6 @@ class TestSqlLab(SupersetTestCase):
self.logout()
db.session.query(Query).delete()
db.session.commit()
db.session.close()
def test_sql_json(self):
self.login("admin")
@ -433,7 +432,6 @@ class TestSqlLab(SupersetTestCase):
Test query api with can_access_all_queries perm added to
gamma and make sure all queries show up.
"""
session = db.session
# Add all_query_access perm to Gamma user
all_queries_view = security_manager.find_permission_view_menu(
@ -443,7 +441,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.add_permission_role(
security_manager.find_role("gamma_sqllab"), all_queries_view
)
session.commit()
db.session.commit()
# Test search_queries for Admin user
self.run_some_queries()
@ -460,7 +458,7 @@ class TestSqlLab(SupersetTestCase):
security_manager.find_role("gamma_sqllab"), all_queries_view
)
session.commit()
db.session.commit()
def test_query_admin_can_access_all_queries(self) -> None:
"""

View File

@ -194,7 +194,7 @@ class TestCacheWarmUp(SupersetTestCase):
db.session.commit()
def test_dashboard_tags(self):
tag1 = get_tag("tag1", db.session, TagTypes.custom)
tag1 = get_tag("tag1", TagTypes.custom)
# delete first to make test idempotent
self.reset_tag(tag1)
@ -204,7 +204,7 @@ class TestCacheWarmUp(SupersetTestCase):
self.assertEqual(result, expected)
# tag dashboard 'births' with `tag1`
tag1 = get_tag("tag1", db.session, TagTypes.custom)
tag1 = get_tag("tag1", TagTypes.custom)
dash = self.get_dash_by_slug("births")
tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
tagged_object = TaggedObject(
@ -216,7 +216,7 @@ class TestCacheWarmUp(SupersetTestCase):
self.assertEqual(sorted(strategy.get_urls()), tag1_urls)
strategy = DashboardTagsStrategy(["tag2"])
tag2 = get_tag("tag2", db.session, TagTypes.custom)
tag2 = get_tag("tag2", TagTypes.custom)
self.reset_tag(tag2)
result = sorted(strategy.get_urls())