mirror of https://github.com/apache/superset.git

chore: Cleanup database sessions (#10427)

Co-authored-by: John Bodley <john.bodley@airbnb.com>

parent 7ff1757448
commit 7645fc85c3
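Note for context: every hunk below applies one pattern. Helpers stop threading an explicit SQLAlchemy `Session` argument through their signatures (and stop spinning up ad-hoc sessions via `db.session()`, `db.create_scoped_session()`, or `sessionmaker`), and instead use the Flask-SQLAlchemy scoped session `db.session` directly. A minimal self-contained sketch of the before/after shape, illustrative only (the `User` model and `get_user_*` helpers are made up, not code from this commit):

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
db = SQLAlchemy(app)


class User(db.Model):
    id = db.Column(db.Integer, primary_key=True)
    name = db.Column(db.String(50))


# Before: helpers threaded an explicit session through every signature.
def get_user_before(session, user_id):
    return session.query(User).filter_by(id=user_id).one()


# After: helpers reach for the scoped session directly.
def get_user_after(user_id):
    return db.session.query(User).filter_by(id=user_id).one()


with app.app_context():
    db.create_all()
    db.session.add(User(name="alice"))
    db.session.commit()
    assert get_user_before(db.session, 1).name == get_user_after(1).name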
@@ -197,10 +197,9 @@ def set_database_uri(database_name: str, uri: str) -> None:
     )
 def refresh_druid(datasource: str, merge: bool) -> None:
     """Refresh druid datasources"""
-    session = db.session()
     from superset.connectors.druid.models import DruidCluster

-    for cluster in session.query(DruidCluster).all():
+    for cluster in db.session.query(DruidCluster).all():
         try:
             cluster.refresh_datasources(datasource_name=datasource, merge_flag=merge)
         except Exception as ex:  # pylint: disable=broad-except
@@ -208,7 +207,7 @@ def refresh_druid(datasource: str, merge: bool) -> None:
             logger.exception(ex)
         cluster.metadata_last_refreshed = datetime.now()
         print("Refreshed metadata from cluster " "[" + cluster.cluster_name + "]")
-    session.commit()
+    db.session.commit()


 @superset.command()
@@ -250,7 +249,7 @@ def import_dashboards(path: str, recursive: bool, username: str) -> None:
         logger.info("Importing dashboard from file %s", file_)
         try:
             with file_.open() as data_stream:
-                dashboard_import_export.import_dashboards(db.session, data_stream)
+                dashboard_import_export.import_dashboards(data_stream)
         except Exception as ex:  # pylint: disable=broad-except
             logger.error("Error when importing dashboard from file %s", file_)
             logger.error(ex)
@@ -268,7 +267,7 @@ def export_dashboards(dashboard_file: str, print_stdout: bool) -> None:
     """Export dashboards to JSON"""
     from superset.utils import dashboard_import_export

-    data = dashboard_import_export.export_dashboards(db.session)
+    data = dashboard_import_export.export_dashboards()
     if print_stdout or not dashboard_file:
         print(data)
     if dashboard_file:
@@ -321,7 +320,7 @@ def import_datasources(path: str, sync: str, recursive: bool) -> None:
         try:
             with file_.open() as data_stream:
                 dict_import_export.import_from_dict(
-                    db.session, yaml.safe_load(data_stream), sync=sync_array
+                    yaml.safe_load(data_stream), sync=sync_array
                 )
         except Exception as ex:  # pylint: disable=broad-except
             logger.error("Error when importing datasources from file %s", file_)
@@ -360,7 +359,6 @@ def export_datasources(
     from superset.utils import dict_import_export

     data = dict_import_export.export_to_dict(
-        session=db.session,
         recursive=True,
         back_references=back_references,
         include_defaults=include_defaults,

@@ -25,7 +25,7 @@ from superset.commands.exceptions import (
 )
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
-from superset.extensions import db, security_manager
+from superset.extensions import security_manager


 def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]:
@@ -50,8 +50,6 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[

 def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
     try:
-        return ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
-        )
+        return ConnectorRegistry.get_datasource(datasource_type, datasource_id)
     except (NoResultFound, KeyError):
         raise DatasourceNotFoundValidationError()

@@ -23,7 +23,7 @@ from typing import Any, ClassVar, Dict, List, Optional, Union
 import numpy as np
 import pandas as pd

-from superset import app, cache, db, security_manager
+from superset import app, cache, security_manager
 from superset.common.query_object import QueryObject
 from superset.connectors.base.models import BaseDatasource
 from superset.connectors.connector_registry import ConnectorRegistry
@@ -64,7 +64,7 @@ class QueryContext:
         result_format: Optional[utils.ChartDataResultFormat] = None,
     ) -> None:
         self.datasource = ConnectorRegistry.get_datasource(
-            str(datasource["type"]), int(datasource["id"]), db.session
+            str(datasource["type"]), int(datasource["id"])
         )
         self.queries = [QueryObject(**query_obj) for query_obj in queries]
         self.force = force

@@ -17,7 +17,9 @@
 from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING

 from sqlalchemy import or_
-from sqlalchemy.orm import Session, subqueryload
+from sqlalchemy.orm import subqueryload
+
+from superset.extensions import db

 if TYPE_CHECKING:
     # pylint: disable=unused-import
@@ -43,20 +45,20 @@ class ConnectorRegistry:

     @classmethod
     def get_datasource(
-        cls, datasource_type: str, datasource_id: int, session: Session
+        cls, datasource_type: str, datasource_id: int
     ) -> "BaseDatasource":
         return (
-            session.query(cls.sources[datasource_type])
+            db.session.query(cls.sources[datasource_type])
             .filter_by(id=datasource_id)
             .one()
         )

     @classmethod
-    def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
+    def get_all_datasources(cls) -> List["BaseDatasource"]:
         datasources: List["BaseDatasource"] = []
         for source_type in ConnectorRegistry.sources:
             source_class = ConnectorRegistry.sources[source_type]
-            qry = session.query(source_class)
+            qry = db.session.query(source_class)
             qry = source_class.default_query(qry)
             datasources.extend(qry.all())
         return datasources
@@ -64,7 +66,6 @@ class ConnectorRegistry:
     @classmethod
     def get_datasource_by_name(  # pylint: disable=too-many-arguments
         cls,
-        session: Session,
         datasource_type: str,
         datasource_name: str,
         schema: str,
@@ -72,21 +73,17 @@ class ConnectorRegistry:
     ) -> Optional["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return datasource_class.get_datasource_by_name(
-            session, datasource_name, schema, database_name
+            datasource_name, schema, database_name
         )

     @classmethod
     def query_datasources_by_permissions(  # pylint: disable=invalid-name
-        cls,
-        session: Session,
-        database: "Database",
-        permissions: Set[str],
-        schema_perms: Set[str],
+        cls, database: "Database", permissions: Set[str], schema_perms: Set[str],
     ) -> List["BaseDatasource"]:
         # TODO(bogdan): add unit test
         datasource_class = ConnectorRegistry.sources[database.type]
         return (
-            session.query(datasource_class)
+            db.session.query(datasource_class)
             .filter_by(database_id=database.id)
             .filter(
                 or_(
@@ -99,12 +96,12 @@ class ConnectorRegistry:

     @classmethod
     def get_eager_datasource(
-        cls, session: Session, datasource_type: str, datasource_id: int
+        cls, datasource_type: str, datasource_id: int
     ) -> "BaseDatasource":
         """Returns datasource with columns and metrics."""
         datasource_class = ConnectorRegistry.sources[datasource_type]
         return (
-            session.query(datasource_class)
+            db.session.query(datasource_class)
             .options(
                 subqueryload(datasource_class.columns),
                 subqueryload(datasource_class.metrics),
@@ -115,13 +112,9 @@ class ConnectorRegistry:

     @classmethod
     def query_datasources_by_name(
-        cls,
-        session: Session,
-        database: "Database",
-        datasource_name: str,
-        schema: Optional[str] = None,
+        cls, database: "Database", datasource_name: str, schema: Optional[str] = None,
     ) -> List["BaseDatasource"]:
         datasource_class = ConnectorRegistry.sources[database.type]
         return datasource_class.query_datasources_by_name(
-            session, database, datasource_name, schema=schema
+            database, datasource_name, schema=schema
         )

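The registry methods above can drop their `session` parameters because `db.session` is a scoped session: roughly, a thread-local registry that hands every caller in the same thread/app context the same underlying `Session`. A plain-SQLAlchemy sketch of that property (assumes SQLAlchemy 1.4+ for `sqlalchemy.orm.declarative_base`; the `Datasource` model is a stand-in, not Superset code):

from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base, scoped_session, sessionmaker

engine = create_engine("sqlite://")
Session = scoped_session(sessionmaker(bind=engine))
Base = declarative_base()


class Datasource(Base):  # stand-in for a registered connector model
    __tablename__ = "datasource"
    id = Column(Integer, primary_key=True)


Base.metadata.create_all(engine)


def get_datasource(datasource_id):
    # mirrors the new ConnectorRegistry.get_datasource: no session argument
    return Session.query(Datasource).filter_by(id=datasource_id).one()


Session.add(Datasource(id=42))
Session.commit()
assert get_datasource(42).id == 42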
@@ -45,7 +45,7 @@ from sqlalchemy import (
     UniqueConstraint,
 )
 from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy.orm import backref, relationship, Session
+from sqlalchemy.orm import backref, relationship
 from sqlalchemy.sql import expression
 from sqlalchemy_utils import EncryptedType

@@ -223,9 +223,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
         Fetches metadata for the specified datasources and
         merges to the Superset database
         """
-        session = db.session
         ds_list = (
-            session.query(DruidDatasource)
+            db.session.query(DruidDatasource)
             .filter(DruidDatasource.cluster_id == self.id)
             .filter(DruidDatasource.datasource_name.in_(datasource_names))
         )
@@ -234,8 +233,8 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
             datasource = ds_map.get(ds_name, None)
             if not datasource:
                 datasource = DruidDatasource(datasource_name=ds_name)
-                with session.no_autoflush:
-                    session.add(datasource)
+                with db.session.no_autoflush:
+                    db.session.add(datasource)
                 flasher(_("Adding new datasource [{}]").format(ds_name), "success")
                 ds_map[ds_name] = datasource
             elif refresh_all:
@@ -245,7 +244,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
                 continue
             datasource.cluster = self
             datasource.merge_flag = merge_flag
-        session.flush()
+        db.session.flush()

         # Prepare multithreaded executation
         pool = ThreadPool()
@@ -259,7 +258,7 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
             cols = metadata[i]
             if cols:
                 col_objs_list = (
-                    session.query(DruidColumn)
+                    db.session.query(DruidColumn)
                     .filter(DruidColumn.datasource_id == datasource.id)
                     .filter(DruidColumn.column_name.in_(cols.keys()))
                 )
@@ -272,15 +271,15 @@ class DruidCluster(Model, AuditMixinNullable, ImportMixin):
                         col_obj = DruidColumn(
                             datasource_id=datasource.id, column_name=col
                         )
-                        with session.no_autoflush:
-                            session.add(col_obj)
+                        with db.session.no_autoflush:
+                            db.session.add(col_obj)
                     col_obj.type = cols[col]["type"]
                     col_obj.datasource = datasource
                     if col_obj.type == "STRING":
                         col_obj.groupby = True
                         col_obj.filterable = True
             datasource.refresh_metrics()
-        session.commit()
+        db.session.commit()

     @hybrid_property
     def perm(self) -> str:
@@ -390,7 +389,7 @@ class DruidColumn(Model, BaseColumn):
             .first()
         )

-        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
+        return import_datasource.import_simple_obj(i_column, lookup_obj)


 class DruidMetric(Model, BaseMetric):
@@ -459,7 +458,7 @@ class DruidMetric(Model, BaseMetric):
             .first()
         )

-        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
+        return import_datasource.import_simple_obj(i_metric, lookup_obj)


 druiddatasource_user = Table(
@@ -635,7 +634,7 @@ class DruidDatasource(Model, BaseDatasource):
             return db.session.query(DruidCluster).filter_by(id=d.cluster_id).first()

         return import_datasource.import_datasource(
-            db.session, i_datasource, lookup_cluster, lookup_datasource, import_time
+            i_datasource, lookup_cluster, lookup_datasource, import_time
         )

     def latest_metadata(self) -> Optional[Dict[str, Any]]:
@@ -705,9 +704,10 @@ class DruidDatasource(Model, BaseDatasource):
         refresh: bool = True,
     ) -> None:
         """Merges the ds config from druid_config into one stored in the db."""
-        session = db.session
         datasource = (
-            session.query(cls).filter_by(datasource_name=druid_config["name"]).first()
+            db.session.query(cls)
+            .filter_by(datasource_name=druid_config["name"])
+            .first()
         )
         # Create a new datasource.
         if not datasource:
@@ -718,13 +718,13 @@ class DruidDatasource(Model, BaseDatasource):
                 changed_by_fk=user.id,
                 created_by_fk=user.id,
             )
-            session.add(datasource)
+            db.session.add(datasource)
         elif not refresh:
             return

         dimensions = druid_config["dimensions"]
         col_objs = (
-            session.query(DruidColumn)
+            db.session.query(DruidColumn)
             .filter(DruidColumn.datasource_id == datasource.id)
             .filter(DruidColumn.column_name.in_(dimensions))
         )
@@ -741,10 +741,10 @@ class DruidDatasource(Model, BaseDatasource):
                 type="STRING",
                 datasource=datasource,
             )
-            session.add(col_obj)
+            db.session.add(col_obj)
         # Import Druid metrics
         metric_objs = (
-            session.query(DruidMetric)
+            db.session.query(DruidMetric)
             .filter(DruidMetric.datasource_id == datasource.id)
             .filter(
                 DruidMetric.metric_name.in_(
@@ -777,8 +777,8 @@ class DruidDatasource(Model, BaseDatasource):
                     % druid_config["name"]
                 ),
             )
-            session.add(metric_obj)
-        session.commit()
+            db.session.add(metric_obj)
+        db.session.commit()

     @staticmethod
     def time_offset(granularity: Granularity) -> int:
@@ -788,10 +788,10 @@ class DruidDatasource(Model, BaseDatasource):

     @classmethod
     def get_datasource_by_name(
-        cls, session: Session, datasource_name: str, schema: str, database_name: str
+        cls, datasource_name: str, schema: str, database_name: str
     ) -> Optional["DruidDatasource"]:
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .join(DruidCluster)
             .filter(cls.datasource_name == datasource_name)
             .filter(DruidCluster.cluster_name == database_name)
@@ -1724,11 +1724,7 @@ class DruidDatasource(Model, BaseDatasource):

     @classmethod
     def query_datasources_by_name(
-        cls,
-        session: Session,
-        database: Database,
-        datasource_name: str,
-        schema: Optional[str] = None,
+        cls, database: Database, datasource_name: str, schema: Optional[str] = None,
     ) -> List["DruidDatasource"]:
         return []

@@ -365,11 +365,10 @@ class Druid(BaseSupersetView):
         self, refresh_all: bool = True
     ) -> FlaskResponse:
         """endpoint that refreshes druid datasources metadata"""
-        session = db.session()
         DruidCluster = ConnectorRegistry.sources[  # pylint: disable=invalid-name
             "druid"
         ].cluster_class
-        for cluster in session.query(DruidCluster).all():
+        for cluster in db.session.query(DruidCluster).all():
             cluster_name = cluster.cluster_name
             valid_cluster = True
             try:
@@ -391,7 +390,7 @@ class Druid(BaseSupersetView):
                 ),
                 "info",
             )
-        session.commit()
+        db.session.commit()
         return redirect("/druiddatasourcemodelview/list/")

     @has_access

@@ -41,7 +41,7 @@ from sqlalchemy import (
     Text,
 )
 from sqlalchemy.exc import CompileError
-from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty, Session
+from sqlalchemy.orm import backref, Query, relationship, RelationshipProperty
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.schema import UniqueConstraint
 from sqlalchemy.sql import column, ColumnElement, literal_column, table, text
@@ -255,7 +255,7 @@ class TableColumn(Model, BaseColumn):
             .first()
         )

-        return import_datasource.import_simple_obj(db.session, i_column, lookup_obj)
+        return import_datasource.import_simple_obj(i_column, lookup_obj)

     def dttm_sql_literal(
         self,
@@ -375,7 +375,7 @@ class SqlMetric(Model, BaseMetric):
             .first()
         )

-        return import_datasource.import_simple_obj(db.session, i_metric, lookup_obj)
+        return import_datasource.import_simple_obj(i_metric, lookup_obj)


 sqlatable_user = Table(
@@ -503,15 +503,11 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-attributes

     @classmethod
     def get_datasource_by_name(
-        cls,
-        session: Session,
-        datasource_name: str,
-        schema: Optional[str],
-        database_name: str,
+        cls, datasource_name: str, schema: Optional[str], database_name: str,
     ) -> Optional["SqlaTable"]:
         schema = schema or None
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .join(Database)
             .filter(cls.table_name == datasource_name)
             .filter(Database.database_name == database_name)
@@ -1296,24 +1292,15 @@ class SqlaTable(  # pylint: disable=too-many-public-methods,too-many-instance-attributes
         )

         return import_datasource.import_datasource(
-            db.session,
-            i_datasource,
-            lookup_database,
-            lookup_sqlatable,
-            import_time,
-            database_id,
+            i_datasource, lookup_database, lookup_sqlatable, import_time, database_id,
         )

     @classmethod
     def query_datasources_by_name(
-        cls,
-        session: Session,
-        database: Database,
-        datasource_name: str,
-        schema: Optional[str] = None,
+        cls, database: Database, datasource_name: str, schema: Optional[str] = None,
     ) -> List["SqlaTable"]:
         query = (
-            session.query(cls)
+            db.session.query(cls)
             .filter_by(database_id=database.id)
             .filter_by(table_name=datasource_name)
         )

@@ -99,9 +99,7 @@ class DashboardDAO(BaseDAO):
         except KeyError:
             pass

-        session = db.session()
-        current_slices = session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
-
+        current_slices = db.session.query(Slice).filter(Slice.id.in_(slice_ids)).all()
         dashboard.slices = current_slices

         # update slice names. this assumes user has permissions to update the slice
@@ -111,8 +109,8 @@ class DashboardDAO(BaseDAO):
                     new_name = slice_id_to_name[slc.id]
                     if slc.slice_name != new_name:
                         slc.slice_name = new_name
-                        session.merge(slc)
-                        session.flush()
+                        db.session.merge(slc)
+                        db.session.flush()
             except KeyError:
                 pass

@@ -37,7 +37,7 @@ from sqlalchemy import (
     UniqueConstraint,
 )
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, sessionmaker, subqueryload
+from sqlalchemy.orm import relationship, subqueryload
 from sqlalchemy.orm.mapper import Mapper

 from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
@@ -62,18 +62,17 @@ config = app.config
 logger = logging.getLogger(__name__)


-def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard") -> None:
-    # pylint: disable=unused-argument
+def copy_dashboard(  # pylint: disable=unused-argument
+    mapper: Mapper, connection: Connection, target: "Dashboard"
+) -> None:
     dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
     if dashboard_id is None:
         return

-    session_class = sessionmaker(autoflush=False)
-    session = session_class(bind=connection)
-    new_user = session.query(User).filter_by(id=target.id).first()
+    new_user = db.session.query(User).filter_by(id=target.id).first()

     # copy template dashboard to user
-    template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
+    template = db.session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
     dashboard = Dashboard(
         dashboard_title=template.dashboard_title,
         position_json=template.position_json,
@@ -83,15 +82,15 @@ def copy_dashboard(mapper: Mapper, connection: Connection, target: "Dashboard")
         slices=template.slices,
         owners=[new_user],
     )
-    session.add(dashboard)
-    session.commit()
+    db.session.add(dashboard)
+    db.session.commit()

     # set dashboard as the welcome dashboard
     extra_attributes = UserAttribute(
         user_id=target.id, welcome_dashboard_id=dashboard.id
     )
-    session.add(extra_attributes)
-    session.commit()
+    db.session.add(extra_attributes)
+    db.session.commit()


 sqla.event.listen(User, "after_insert", copy_dashboard)
@@ -307,7 +306,6 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         logger.info(
             "Started import of the dashboard: %s", dashboard_to_import.to_json()
         )
-        session = db.session
         logger.info("Dashboard has %d slices", len(dashboard_to_import.slices))
         # copy slices object as Slice.import_slice will mutate the slice
         # and will remove the existing dashboard - slice association
@@ -324,7 +322,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         i_params_dict = dashboard_to_import.params_dict
         remote_id_slice_map = {
             slc.params_dict["remote_id"]: slc
-            for slc in session.query(Slice).all()
+            for slc in db.session.query(Slice).all()
             if "remote_id" in slc.params_dict
         }
         for slc in slices:
@@ -375,7 +373,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes

         # override the dashboard
         existing_dashboard = None
-        for dash in session.query(Dashboard).all():
+        for dash in db.session.query(Dashboard).all():
             if (
                 "remote_id" in dash.params_dict
                 and dash.params_dict["remote_id"] == dashboard_to_import.id
@@ -402,7 +400,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         )

         new_slices = (
-            session.query(Slice)
+            db.session.query(Slice)
             .filter(Slice.id.in_(old_to_new_slc_id_dict.values()))
             .all()
         )
@@ -410,12 +408,12 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         if existing_dashboard:
             existing_dashboard.override(dashboard_to_import)
             existing_dashboard.slices = new_slices
-            session.flush()
+            db.session.flush()
             return existing_dashboard.id

         dashboard_to_import.slices = new_slices
-        session.add(dashboard_to_import)
-        session.flush()
+        db.session.add(dashboard_to_import)
+        db.session.flush()
         return dashboard_to_import.id  # type: ignore

     @classmethod
@@ -457,7 +455,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
         eager_datasources = []
         for datasource_id, datasource_type in datasource_ids:
             eager_datasource = ConnectorRegistry.get_eager_datasource(
-                db.session, datasource_type, datasource_id
+                datasource_type, datasource_id
             )
             copied_datasource = eager_datasource.copy()
             copied_datasource.alter_params(

@@ -34,9 +34,9 @@ from flask_appbuilder.models.mixins import AuditMixin
 from flask_appbuilder.security.sqla.models import User
 from sqlalchemy import and_, or_, UniqueConstraint
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import MultipleResultsFound

+from superset.extensions import db
 from superset.utils.core import QueryStatus

 logger = logging.getLogger(__name__)
@@ -127,7 +127,6 @@ class ImportMixin:
     @classmethod
     def import_from_dict(  # pylint: disable=too-many-arguments,too-many-branches,too-many-locals
         cls,
-        session: Session,
         dict_rep: Dict[Any, Any],
         parent: Optional[Any] = None,
         recursive: bool = True,
@@ -178,7 +177,7 @@ class ImportMixin:

         # Check if object already exists in DB, break if more than one is found
         try:
-            obj_query = session.query(cls).filter(and_(*filters))
+            obj_query = db.session.query(cls).filter(and_(*filters))
             obj = obj_query.one_or_none()
         except MultipleResultsFound as ex:
             logger.error(
@@ -196,7 +195,7 @@ class ImportMixin:
             logger.info("Importing new %s %s", obj.__tablename__, str(obj))
             if cls.export_parent and parent:
                 setattr(obj, cls.export_parent, parent)
-            session.add(obj)
+            db.session.add(obj)
         else:
             is_new_obj = False
             logger.info("Updating %s %s", obj.__tablename__, str(obj))
@@ -214,7 +213,7 @@ class ImportMixin:
                 for c_obj in new_children.get(child, []):
                     added.append(
                         child_class.import_from_dict(
-                            session=session, dict_rep=c_obj, parent=obj, sync=sync
+                            dict_rep=c_obj, parent=obj, sync=sync
                         )
                     )
                 # If children should get synced, delete the ones that did not
@@ -228,11 +227,11 @@ class ImportMixin:
                     for k in back_refs.keys()
                 ]
                 to_delete = set(
-                    session.query(child_class).filter(and_(*delete_filters))
+                    db.session.query(child_class).filter(and_(*delete_filters))
                 ).difference(set(added))
                 for o in to_delete:
                     logger.info("Deleting %s %s", child, str(obj))
-                    session.delete(o)
+                    db.session.delete(o)

         return obj

@@ -300,7 +300,6 @@ class Slice(
         :returns: The resulting id for the imported slice
         :rtype: int
         """
-        session = db.session
         make_transient(slc_to_import)
         slc_to_import.dashboards = []
         slc_to_import.alter_params(remote_id=slc_to_import.id, import_time=import_time)
@@ -309,7 +308,6 @@ class Slice(
         slc_to_import.reset_ownership()
         params = slc_to_import.params_dict
         datasource = ConnectorRegistry.get_datasource_by_name(
-            session,
             slc_to_import.datasource_type,
             params["datasource_name"],
             params["schema"],
@@ -318,11 +316,11 @@ class Slice(
         slc_to_import.datasource_id = datasource.id  # type: ignore
         if slc_to_override:
             slc_to_override.override(slc_to_import)
-            session.flush()
+            db.session.flush()
             return slc_to_override.id
-        session.add(slc_to_import)
+        db.session.add(slc_to_import)
         logger.info("Final slice: %s", str(slc_to_import.to_json()))
-        session.flush()
+        db.session.flush()
         return slc_to_import.id

     @property

@@ -22,10 +22,11 @@ from typing import List, Optional, TYPE_CHECKING, Union
 from flask_appbuilder import Model
 from sqlalchemy import Column, Enum, ForeignKey, Integer, String
 from sqlalchemy.engine.base import Connection
-from sqlalchemy.orm import relationship, Session, sessionmaker
+from sqlalchemy.orm import relationship
 from sqlalchemy.orm.exc import NoResultFound
 from sqlalchemy.orm.mapper import Mapper

+from superset.extensions import db
 from superset.models.helpers import AuditMixinNullable

 if TYPE_CHECKING:
@@ -34,8 +35,6 @@ if TYPE_CHECKING:
     from superset.models.slice import Slice  # pylint: disable=unused-import
     from superset.models.sql_lab import Query  # pylint: disable=unused-import

-Session = sessionmaker(autoflush=False)
-

 class TagTypes(enum.Enum):

@@ -88,13 +87,13 @@ class TaggedObject(Model, AuditMixinNullable):
     tag = relationship("Tag", backref="objects")


-def get_tag(name: str, session: Session, type_: TagTypes) -> Tag:
+def get_tag(name: str, type_: TagTypes) -> Tag:
     try:
-        tag = session.query(Tag).filter_by(name=name, type=type_).one()
+        tag = db.session.query(Tag).filter_by(name=name, type=type_).one()
     except NoResultFound:
         tag = Tag(name=name, type=type_)
-        session.add(tag)
-        session.commit()
+        db.session.add(tag)
+        db.session.commit()

     return tag

@@ -122,52 +121,43 @@ class ObjectUpdater:
         raise NotImplementedError("Subclass should implement `get_owners_ids`")

     @classmethod
-    def _add_owners(
-        cls, session: Session, target: Union["Dashboard", "FavStar", "Slice"]
-    ) -> None:
+    def _add_owners(cls, target: Union["Dashboard", "FavStar", "Slice"]) -> None:
         for owner_id in cls.get_owners_ids(target):
             name = "owner:{0}".format(owner_id)
-            tag = get_tag(name, session, TagTypes.owner)
+            tag = get_tag(name, TagTypes.owner)
             tagged_object = TaggedObject(
                 tag_id=tag.id, object_id=target.id, object_type=cls.object_type
             )
-            session.add(tagged_object)
+            db.session.add(tagged_object)

     @classmethod
-    def after_insert(
+    def after_insert(  # pylint: disable=unused-argument
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
-        # pylint: disable=unused-argument
-        session = Session(bind=connection)
-
         # add `owner:` tags
-        cls._add_owners(session, target)
+        cls._add_owners(target)

         # add `type:` tags
-        tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type)
+        tag = get_tag("type:{0}".format(cls.object_type), TagTypes.type)
         tagged_object = TaggedObject(
             tag_id=tag.id, object_id=target.id, object_type=cls.object_type
         )
-        session.add(tagged_object)
-
-        session.commit()
+        db.session.add(tagged_object)
+        db.session.commit()

     @classmethod
-    def after_update(
+    def after_update(  # pylint: disable=unused-argument
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
-        # pylint: disable=unused-argument
-        session = Session(bind=connection)
-
         # delete current `owner:` tags
         query = (
-            session.query(TaggedObject.id)
+            db.session.query(TaggedObject.id)
             .join(Tag)
             .filter(
                 TaggedObject.object_type == cls.object_type,
@@ -176,32 +166,28 @@ class ObjectUpdater:
             )
         )
         ids = [row[0] for row in query]
-        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+        db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
             synchronize_session=False
         )

         # add `owner:` tags
-        cls._add_owners(session, target)
-
-        session.commit()
+        cls._add_owners(target)
+        db.session.commit()

     @classmethod
-    def after_delete(
+    def after_delete(  # pylint: disable=unused-argument
         cls,
         mapper: Mapper,
         connection: Connection,
         target: Union["Dashboard", "FavStar", "Slice"],
     ) -> None:
-        # pylint: disable=unused-argument
-        session = Session(bind=connection)
-
         # delete row from `tagged_objects`
-        session.query(TaggedObject).filter(
+        db.session.query(TaggedObject).filter(
             TaggedObject.object_type == cls.object_type,
             TaggedObject.object_id == target.id,
         ).delete()

-        session.commit()
+        db.session.commit()


 class ChartUpdater(ObjectUpdater):
@@ -233,31 +219,26 @@ class QueryUpdater(ObjectUpdater):

 class FavStarUpdater:
     @classmethod
-    def after_insert(
+    def after_insert(  # pylint: disable=unused-argument
         cls, mapper: Mapper, connection: Connection, target: "FavStar"
     ) -> None:
-        # pylint: disable=unused-argument
-        session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
-        tag = get_tag(name, session, TagTypes.favorited_by)
+        tag = get_tag(name, TagTypes.favorited_by)
         tagged_object = TaggedObject(
             tag_id=tag.id,
             object_id=target.obj_id,
             object_type=get_object_type(target.class_name),
         )
-        session.add(tagged_object)
-
-        session.commit()
+        db.session.add(tagged_object)
+        db.session.commit()

     @classmethod
-    def after_delete(
-        cls, mapper: Mapper, connection: Connection, target: "FavStar"
+    def after_delete(  # pylint: disable=unused-argument
+        cls, mapper: Mapper, connection: Connection, target: "FavStar",
     ) -> None:
-        # pylint: disable=unused-argument
-        session = Session(bind=connection)
         name = "favorited_by:{0}".format(target.user_id)
         query = (
-            session.query(TaggedObject.id)
+            db.session.query(TaggedObject.id)
             .join(Tag)
             .filter(
                 TaggedObject.object_id == target.obj_id,
@@ -266,8 +247,8 @@ class FavStarUpdater:
             )
         )
         ids = [row[0] for row in query]
-        session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
+        db.session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
             synchronize_session=False
         )

-        session.commit()
+        db.session.commit()

@@ -507,7 +507,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
         user_perms = self.user_view_menu_names("datasource_access")
         schema_perms = self.user_view_menu_names("schema_access")
         user_datasources = ConnectorRegistry.query_datasources_by_permissions(
-            self.get_session, database, user_perms, schema_perms
+            database, user_perms, schema_perms
         )
         if schema:
             names = {d.table_name for d in user_datasources if d.schema == schema}
@@ -568,7 +568,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods
             self.add_permission_view_menu(view_menu, perm)

         logger.info("Creating missing datasource permissions.")
-        datasources = ConnectorRegistry.get_all_datasources(self.get_session)
+        datasources = ConnectorRegistry.get_all_datasources()
         for datasource in datasources:
             merge_pv("datasource_access", datasource.get_perm())
             merge_pv("schema_access", datasource.get_schema_perm())
@@ -901,7 +901,7 @@ class SupersetSecurityManager(  # pylint: disable=too-many-public-methods

         if not (schema_perm and self.can_access("schema_access", schema_perm)):
             datasources = SqlaTable.query_datasources_by_name(
-                self.get_session, database, table_.table, schema=table_.schema
+                database, table_.table, schema=table_.schema
             )

             # Access to any datasource is suffice.

@@ -132,8 +132,7 @@ def session_scope(nullpool: bool) -> Iterator[Session]:
     )
     if nullpool:
         engine = sqlalchemy.create_engine(database_uri, poolclass=NullPool)
-        session_class = sessionmaker()
-        session_class.configure(bind=engine)
+        session_class = sessionmaker(bind=engine)
         session = session_class()
     else:
         session = db.session()

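Aside: the `sessionmaker` change in this hunk is a pure simplification. `sessionmaker(bind=engine)` and the two-step `configure()` form produce identically configured session factories, as this quick sketch shows:

from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

engine = create_engine("sqlite://")

# before: two-step configuration
session_class = sessionmaker()
session_class.configure(bind=engine)

# after: equivalent one-liner
session_class = sessionmaker(bind=engine)
session = session_class()
session.close()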
@@ -134,8 +134,7 @@ class DummyStrategy(Strategy):
     name = "dummy"

     def get_urls(self) -> List[str]:
-        session = db.create_scoped_session()
-        charts = session.query(Slice).all()
+        charts = db.session.query(Slice).all()

         return [get_url(chart) for chart in charts]

@@ -167,10 +166,9 @@ class TopNDashboardsStrategy(Strategy):

     def get_urls(self) -> List[str]:
         urls = []
-        session = db.create_scoped_session()

         records = (
-            session.query(Log.dashboard_id, func.count(Log.dashboard_id))
+            db.session.query(Log.dashboard_id, func.count(Log.dashboard_id))
             .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
             .group_by(Log.dashboard_id)
             .order_by(func.count(Log.dashboard_id).desc())
@@ -178,7 +176,9 @@ class TopNDashboardsStrategy(Strategy):
             .all()
         )
         dash_ids = [record.dashboard_id for record in records]
-        dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
+        dashboards = (
+            db.session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
+        )
         for dashboard in dashboards:
             for chart in dashboard.slices:
                 form_data_with_filters = get_form_data(chart.id, dashboard)
@@ -211,14 +211,13 @@ class DashboardTagsStrategy(Strategy):

     def get_urls(self) -> List[str]:
         urls = []
-        session = db.create_scoped_session()

-        tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
+        tags = db.session.query(Tag).filter(Tag.name.in_(self.tags)).all()
         tag_ids = [tag.id for tag in tags]

         # add dashboards that are tagged
         tagged_objects = (
-            session.query(TaggedObject)
+            db.session.query(TaggedObject)
             .filter(
                 and_(
                     TaggedObject.object_type == "dashboard",
@@ -228,14 +227,16 @@ class DashboardTagsStrategy(Strategy):
             .all()
         )
         dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
+        tagged_dashboards = db.session.query(Dashboard).filter(
+            Dashboard.id.in_(dash_ids)
+        )
         for dashboard in tagged_dashboards:
             for chart in dashboard.slices:
                 urls.append(get_url(chart))

         # add charts that are tagged
         tagged_objects = (
-            session.query(TaggedObject)
+            db.session.query(TaggedObject)
             .filter(
                 and_(
                     TaggedObject.object_type == "chart",
@@ -245,7 +246,7 @@ class DashboardTagsStrategy(Strategy):
             .all()
         )
         chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
-        tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
+        tagged_charts = db.session.query(Slice).filter(Slice.id.in_(chart_ids))
         for chart in tagged_charts:
             urls.append(get_url(chart))

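The cache-warmup strategies above stop calling `db.create_scoped_session()`. That API (present in Flask-SQLAlchemy 2.x, removed in 3.0) builds a brand-new scoped-session registry on every call, distinct from the `db.session` the app already manages, so each call left behind a session whose lifecycle nothing cleaned up. A sketch of the distinction, assuming Flask-SQLAlchemy 2.x:

from flask import Flask
from flask_sqlalchemy import SQLAlchemy

app = Flask(__name__)
app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite://"
db = SQLAlchemy(app)

with app.app_context():
    extra = db.create_scoped_session()  # new registry, new Session
    assert extra is not db.session      # not the managed scoped session
    extra.remove()                      # caller owns this lifecycle;
    # db.session needs no such bookkeeping: app teardown handles it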
@@ -48,7 +48,6 @@ from flask_login import login_user
 from retry.api import retry_call
 from selenium.common.exceptions import WebDriverException
 from selenium.webdriver import chrome, firefox
-from sqlalchemy.orm import Session
 from werkzeug.http import parse_cookie

 from superset import app, db, security_manager, thumbnail_cache
@@ -543,8 +542,7 @@ def schedule_alert_query(  # pylint: disable=unused-argument
     is_test_alert: Optional[bool] = False,
 ) -> None:
     model_cls = get_scheduler_model(report_type)
-    dbsession = db.create_scoped_session()
-    schedule = dbsession.query(model_cls).get(schedule_id)
+    schedule = db.session.query(model_cls).get(schedule_id)

     # The user may have disabled the schedule. If so, ignore this
     if not schedule or not schedule.active:
@@ -556,7 +554,7 @@ def schedule_alert_query(  # pylint: disable=unused-argument
             deliver_alert(schedule, recipients)
             return

-        if run_alert_query(schedule, dbsession):
+        if run_alert_query(schedule):
             # deliver_dashboard OR deliver_slice
             return
     else:
@@ -614,7 +612,7 @@ def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None:
     _deliver_email(recipients, deliver_as_group, subject, body, data, images)


-def run_alert_query(alert: Alert, dbsession: Session) -> Optional[bool]:
+def run_alert_query(alert: Alert) -> Optional[bool]:
     """
     Execute alert.sql and return value if any rows are returned
     """
@@ -666,7 +664,7 @@ def run_alert_query(alert: Alert, dbsession: Session) -> Optional[bool]:
             state=state,
         )
     )
-    dbsession.commit()
+    db.session.commit()

     return None

@@ -706,8 +704,7 @@ def schedule_window(
     if not model_cls:
         return None

-    dbsession = db.create_scoped_session()
-    schedules = dbsession.query(model_cls).filter(model_cls.active.is_(True))
+    schedules = db.session.query(model_cls).filter(model_cls.active.is_(True))

     for schedule in schedules:
         logging.info("Processing schedule %s", schedule)

@@ -22,10 +22,10 @@ from io import BytesIO
 from typing import Any, Dict, Optional

 from flask_babel import lazy_gettext as _
-from sqlalchemy.orm import Session

 from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.exceptions import DashboardImportException
+from superset.extensions import db
 from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice

@@ -71,7 +71,6 @@ def decode_dashboards(  # pylint: disable=too-many-return-statements


 def import_dashboards(
-    session: Session,
     data_stream: BytesIO,
     database_id: Optional[int] = None,
     import_time: Optional[int] = None,
@@ -84,16 +83,16 @@ def import_dashboards(
         raise DashboardImportException(_("No data in file"))
     for table in data["datasources"]:
         type(table).import_obj(table, database_id, import_time=import_time)
-    session.commit()
+    db.session.commit()
     for dashboard in data["dashboards"]:
         Dashboard.import_obj(dashboard, import_time=import_time)
-    session.commit()
+    db.session.commit()


-def export_dashboards(session: Session) -> str:
+def export_dashboards() -> str:
     """Returns all dashboards metadata as a json dump"""
     logger.info("Starting export")
-    dashboards = session.query(Dashboard)
+    dashboards = db.session.query(Dashboard)
     dashboard_ids = []
     for dashboard in dashboards:
         dashboard_ids.append(dashboard.id)

@@ -17,9 +17,8 @@
 import logging
 from typing import Any, Dict, List, Optional

-from sqlalchemy.orm import Session
-
 from superset.connectors.druid.models import DruidCluster
+from superset.extensions import db
 from superset.models.core import Database

 DATABASES_KEY = "databases"
@@ -44,11 +43,11 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:


 def export_to_dict(
-    session: Session, recursive: bool, back_references: bool, include_defaults: bool
+    recursive: bool, back_references: bool, include_defaults: bool
 ) -> Dict[str, Any]:
     """Exports databases and druid clusters to a dictionary"""
     logger.info("Starting export")
-    dbs = session.query(Database)
+    dbs = db.session.query(Database)
     databases = [
         database.export_to_dict(
             recursive=recursive,
@@ -58,7 +57,7 @@ def export_to_dict(
         for database in dbs
     ]
     logger.info("Exported %d %s", len(databases), DATABASES_KEY)
-    cls = session.query(DruidCluster)
+    cls = db.session.query(DruidCluster)
     clusters = [
         cluster.export_to_dict(
             recursive=recursive,
@@ -76,22 +75,20 @@ def export_to_dict(
     return data


-def import_from_dict(
-    session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
-) -> None:
+def import_from_dict(data: Dict[str, Any], sync: Optional[List[str]] = None) -> None:
     """Imports databases and druid clusters from dictionary"""
     if not sync:
         sync = []
     if isinstance(data, dict):
         logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY)
         for database in data.get(DATABASES_KEY, []):
-            Database.import_from_dict(session, database, sync=sync)
+            Database.import_from_dict(database, sync=sync)

         logger.info(
             "Importing %d %s", len(data.get(DRUID_CLUSTERS_KEY, [])), DRUID_CLUSTERS_KEY
         )
         for datasource in data.get(DRUID_CLUSTERS_KEY, []):
-            DruidCluster.import_from_dict(session, datasource, sync=sync)
-        session.commit()
+            DruidCluster.import_from_dict(datasource, sync=sync)
+        db.session.commit()
     else:
         logger.info("Supplied object is not a dictionary.")

@@ -18,14 +18,14 @@ import logging
 from typing import Callable, Optional

 from flask_appbuilder import Model
-from sqlalchemy.orm import Session
 from sqlalchemy.orm.session import make_transient

+from superset.extensions import db
+
 logger = logging.getLogger(__name__)


 def import_datasource(  # pylint: disable=too-many-arguments
-    session: Session,
     i_datasource: Model,
     lookup_database: Callable[[Model], Model],
     lookup_datasource: Callable[[Model], Model],
@@ -52,11 +52,11 @@ def import_datasource(  # pylint: disable=too-many-arguments

     if datasource:
         datasource.override(i_datasource)
-        session.flush()
+        db.session.flush()
     else:
         datasource = i_datasource.copy()
-        session.add(datasource)
-        session.flush()
+        db.session.add(datasource)
+        db.session.flush()

     for metric in i_datasource.metrics:
         new_m = metric.copy()
@@ -81,13 +81,11 @@ def import_datasource(  # pylint: disable=too-many-arguments
         imported_c = i_datasource.column_class.import_obj(new_c)
         if imported_c.column_name not in [c.column_name for c in datasource.columns]:
             datasource.columns.append(imported_c)
-    session.flush()
+    db.session.flush()
     return datasource.id


-def import_simple_obj(
-    session: Session, i_obj: Model, lookup_obj: Callable[[Model], Model]
-) -> Model:
+def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model:
     make_transient(i_obj)
     i_obj.id = None
     i_obj.table = None
@@ -97,9 +95,9 @@ def import_simple_obj(
     i_obj.table = None
     if existing_column:
         existing_column.override(i_obj)
-        session.flush()
+        db.session.flush()
         return existing_column

-    session.add(i_obj)
-    session.flush()
+    db.session.add(i_obj)
+    db.session.flush()
     return i_obj

@@ -487,8 +487,7 @@ def check_ownership(obj: Any, raise_if_false: bool = True) -> bool:
     roles = [r.name for r in get_user_roles()]
     if "Admin" in roles:
         return True
-    scoped_session = db.create_scoped_session()
-    orig_obj = scoped_session.query(obj.__class__).filter_by(id=obj.id).first()
+    orig_obj = db.session.query(obj.__class__).filter_by(id=obj.id).first()

     # Making a list of owners that works across ORM models
     owners: List[User] = []

@@ -20,7 +20,7 @@ from flask_appbuilder import expose, has_access
 from flask_appbuilder.models.sqla.interface import SQLAInterface
 from flask_babel import lazy_gettext as _

-from superset import app, db
+from superset import app
 from superset.connectors.connector_registry import ConnectorRegistry
 from superset.constants import RouteMethod
 from superset.models.slice import Slice
@@ -56,7 +56,7 @@ class SliceModelView(
     def add(self) -> FlaskResponse:
         datasources = [
             {"value": str(d.id) + "__" + d.type, "label": repr(d)}
-            for d in ConnectorRegistry.get_all_datasources(db.session)
+            for d in ConnectorRegistry.get_all_datasources()
         ]
         return self.render_template(
             "superset/add_slice.html",

@@ -40,7 +40,6 @@ from sqlalchemy.exc import (
     OperationalError,
     SQLAlchemyError,
 )
-from sqlalchemy.orm.session import Session
 from werkzeug.urls import Href

 import superset.models.core as models
@@ -164,7 +163,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             sorted(
                 [
                     datasource.short_data
-                    for datasource in ConnectorRegistry.get_all_datasources(db.session)
+                    for datasource in ConnectorRegistry.get_all_datasources()
                     if datasource.short_data.get("name")
                 ],
                 key=lambda datasource: datasource["name"],
@@ -203,7 +202,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             )
             db_ds_names.add(fullname)

-        existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
+        existing_datasources = ConnectorRegistry.get_all_datasources()
         datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
         role = security_manager.find_role(role_name)
         # remove all permissions
@@ -270,15 +269,15 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     @has_access
     @expose("/approve")
     def approve(self) -> FlaskResponse:  # pylint: disable=too-many-locals,no-self-use
-        def clean_fulfilled_requests(session: Session) -> None:
-            for dar in session.query(DAR).all():
+        def clean_fulfilled_requests() -> None:
+            for dar in db.session.query(DAR).all():
                 datasource = ConnectorRegistry.get_datasource(
-                    dar.datasource_type, dar.datasource_id, session
+                    dar.datasource_type, dar.datasource_id
                 )
                 if not datasource or security_manager.can_access_datasource(datasource):
                     # datasource does not exist anymore
-                    session.delete(dar)
-            session.commit()
+                    db.session.delete(dar)
+            db.session.commit()

         datasource_type = request.args["datasource_type"]
         datasource_id = request.args["datasource_id"]
@@ -286,10 +285,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         role_to_grant = request.args.get("role_to_grant")
         role_to_extend = request.args.get("role_to_extend")

-        session = db.session
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, session
-        )
+        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)

         if not datasource:
             flash(DATASOURCE_MISSING_ERR, "alert")
@@ -301,7 +297,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return json_error_response(USER_MISSING_ERR)

         requests = (
-            session.query(DAR)
+            db.session.query(DAR)
             .filter(
                 DAR.datasource_id == datasource_id,
                 DAR.datasource_type == datasource_type,
@@ -361,13 +357,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 app.config,
             )
             flash(msg, "info")
-            clean_fulfilled_requests(session)
+            clean_fulfilled_requests()
         else:
             flash(__("You have no permission to approve this request"), "danger")
             return redirect("/accessrequestsmodelview/list/")
         for request_ in requests:
-            session.delete(request_)
-        session.commit()
+            db.session.delete(request_)
+        db.session.commit()
         return redirect("/accessrequestsmodelview/list/")

     @has_access
@@ -548,7 +544,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             database_id = request.form.get("db_id")
             try:
                 dashboard_import_export.import_dashboards(
-                    db.session, import_file.stream, database_id
+                    import_file.stream, database_id
                 )
                 success = True
             except DatabaseNotFound as ex:
@@ -630,7 +626,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             return redirect(error_redirect)

         datasource = ConnectorRegistry.get_datasource(
-            cast(str, datasource_type), datasource_id, db.session
+            cast(str, datasource_type), datasource_id
         )
         if not datasource:
             flash(DATASOURCE_MISSING_ERR, "danger")
@@ -749,9 +745,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         :raises SupersetSecurityException: If the user cannot access the resource
         """
         # TODO: Cache endpoint by user, datasource and column
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
-        )
+        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
         if not datasource:
             return json_error_response(DATASOURCE_MISSING_ERR)

@@ -1015,10 +1009,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id: int
     ) -> FlaskResponse:
         """Copy dashboard"""
-        session = db.session()
         data = json.loads(request.form["data"])
         dash = models.Dashboard()
-        original_dash = session.query(Dashboard).get(dashboard_id)
+        original_dash = db.session.query(Dashboard).get(dashboard_id)

         dash.owners = [g.user] if g.user else []
         dash.dashboard_title = data["dashboard_title"]
@@ -1029,8 +1022,8 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         for slc in original_dash.slices:
             new_slice = slc.clone()
             new_slice.owners = [g.user] if g.user else []
-            session.add(new_slice)
-            session.flush()
+            db.session.add(new_slice)
+            db.session.flush()
             new_slice.dashboards.append(dash)
             old_to_new_slice_ids[slc.id] = new_slice.id

@@ -1046,10 +1039,9 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         dash.params = original_dash.params

         DashboardDAO.set_dash_metadata(dash, data, old_to_new_slice_ids)
-        session.add(dash)
-        session.commit()
+        db.session.add(dash)
+        db.session.commit()
         dash_json = json.dumps(dash.data)
-        session.close()
         return json_success(dash_json)

     @api
@@ -1059,14 +1051,12 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id: int
     ) -> FlaskResponse:
         """Save a dashboard's metadata"""
-        session = db.session()
-        dash = session.query(Dashboard).get(dashboard_id)
+        dash = db.session.query(Dashboard).get(dashboard_id)
         check_ownership(dash, raise_if_false=True)
         data = json.loads(request.form["data"])
         DashboardDAO.set_dash_metadata(dash, data)
-        session.merge(dash)
-        session.commit()
-        session.close()
+        db.session.merge(dash)
+        db.session.commit()
         return json_success(json.dumps({"status": "SUCCESS"}))

     @api
@@ -1077,14 +1067,12 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
     ) -> FlaskResponse:
         """Add and save slices to a dashboard"""
         data = json.loads(request.form["data"])
-        session = db.session()
-        dash = session.query(Dashboard).get(dashboard_id)
+        dash = db.session.query(Dashboard).get(dashboard_id)
         check_ownership(dash, raise_if_false=True)
-        new_slices = session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
+        new_slices = db.session.query(Slice).filter(Slice.id.in_(data["slice_ids"]))
         dash.slices += new_slices
-        session.merge(dash)
-        session.commit()
-        session.close()
+        db.session.merge(dash)
+        db.session.commit()
         return "SLICES ADDED"

     @api
@@ -1431,7 +1419,6 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods

         Note for slices a force refresh occurs.
         """
-        session = db.session()
         slice_id = request.args.get("slice_id")
         dashboard_id = request.args.get("dashboard_id")
         table_name = request.args.get("table_name")
@@ -1446,14 +1433,14 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 status=400,
             )
         if slice_id:
-            slices = session.query(Slice).filter_by(id=slice_id).all()
+            slices = db.session.query(Slice).filter_by(id=slice_id).all()
             if not slices:
                 return json_error_response(
                     __("Chart %(id)s not found", id=slice_id), status=404
                 )
         elif table_name and db_name:
             table = (
-                session.query(SqlaTable)
+                db.session.query(SqlaTable)
                 .join(models.Database)
                 .filter(
                     models.Database.database_name == db_name
@@ -1470,7 +1457,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                     status=404,
                 )
             slices = (
-                session.query(Slice)
+                db.session.query(Slice)
                 .filter_by(datasource_id=table.id, datasource_type=table.type)
                 .all()
             )
@@ -1513,17 +1500,16 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, class_name: str, obj_id: int, action: str
     ) -> FlaskResponse:
         """Toggle favorite stars on Slices and Dashboard"""
-        session = db.session()
         FavStar = models.FavStar
         count = 0
         favs = (
-            session.query(FavStar)
+            db.session.query(FavStar)
             .filter_by(class_name=class_name, obj_id=obj_id, user_id=g.user.get_id())
             .all()
         )
         if action == "select":
             if not favs:
-                session.add(
+                db.session.add(
                     FavStar(
                         class_name=class_name,
                         obj_id=obj_id,
@@ -1534,10 +1520,10 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             count = 1
         elif action == "unselect":
             for fav in favs:
-                session.delete(fav)
+                db.session.delete(fav)
         else:
             count = len(favs)
-        session.commit()
+        db.session.commit()
         return json_success(json.dumps({"count": count}))

     @api
@@ -1550,12 +1536,13 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         logger.warning(
             "This API endpoint is deprecated and will be removed in version 1.0.0"
         )
-        session = db.session()
         Role = ab_models.Role
         dash = (
-            session.query(Dashboard).filter(Dashboard.id == dashboard_id).one_or_none()
+            db.session.query(Dashboard)
+            .filter(Dashboard.id == dashboard_id)
+            .one_or_none()
         )
-        admin_role = session.query(Role).filter(Role.name == "Admin").one_or_none()
+        admin_role = db.session.query(Role).filter(Role.name == "Admin").one_or_none()

         if request.method == "GET":
             if dash:
@@ -1574,7 +1561,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
             )

         dash.published = str(request.form["published"]).lower() == "true"
-        session.commit()
+        db.session.commit()
         return json_success(json.dumps({"published": dash.published}))

     @has_access
@@ -1583,8 +1570,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
         self, dashboard_id_or_slug: str
     ) -> FlaskResponse:
         """Server side rendering for a dashboard"""
-        session = db.session()
-        qry = session.query(Dashboard)
+        qry = db.session.query(Dashboard)
         if dashboard_id_or_slug.isdigit():
             qry = qry.filter_by(id=int(dashboard_id_or_slug))
         else:
@@ -2042,8 +2028,7 @@ class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
                 "SQL validation does not support template parameters", status=400
            )

-        session = db.session()
-        mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
+        mydb = db.session.query(models.Database).filter_by(id=database_id).one_or_none()
         if not mydb:
             return json_error_response(
                 "Database with id {} is missing.".format(database_id), status=400
@@ -2092,7 +2077,6 @@
@ -2092,7 +2077,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
|
||||
@staticmethod
|
||||
def _sql_json_async( # pylint: disable=too-many-arguments
|
||||
session: Session,
|
||||
rendered_query: str,
|
||||
query: Query,
|
||||
expand_data: bool,
|
||||
|
@ -2101,7 +2085,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
Send SQL JSON query to celery workers.
|
||||
|
||||
:param session: SQLAlchemy session object
|
||||
:param rendered_query: the rendered query to perform by workers
|
||||
:param query: The query (SQLAlchemy) object
|
||||
:return: A Flask Response
|
||||
|
@ -2132,7 +2115,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
query.status = QueryStatus.FAILED
|
||||
query.error_message = msg
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
return json_error_response("{}".format(msg))
|
||||
resp = json_success(
|
||||
json.dumps(
|
||||
|
@ -2142,12 +2125,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
),
|
||||
status=202,
|
||||
)
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
return resp
|
||||
|
||||
@staticmethod
|
||||
def _sql_json_sync(
|
||||
_session: Session,
|
||||
rendered_query: str,
|
||||
query: Query,
|
||||
expand_data: bool,
|
||||
|
@ -2241,8 +2223,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
tab_name: str = cast(str, query_params.get("tab"))
|
||||
status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
|
||||
|
||||
session = db.session()
|
||||
mydb = session.query(models.Database).get(database_id)
|
||||
mydb = db.session.query(models.Database).get(database_id)
|
||||
if not mydb:
|
||||
return json_error_response("Database with id %i is missing.", database_id)
|
||||
|
||||
|
@ -2273,13 +2254,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
client_id=client_id,
|
||||
)
|
||||
try:
|
||||
session.add(query)
|
||||
session.flush()
|
||||
db.session.add(query)
|
||||
db.session.flush()
|
||||
query_id = query.id
|
||||
session.commit() # shouldn't be necessary
|
||||
db.session.commit() # shouldn't be necessary
|
||||
except SQLAlchemyError as ex:
|
||||
logger.error("Errors saving query details %s", str(ex))
|
||||
session.rollback()
|
||||
db.session.rollback()
|
||||
raise Exception(_("Query record was not created as expected."))
|
||||
if not query_id:
|
||||
raise Exception(_("Query record was not created as expected."))
|
||||
|
@ -2290,7 +2271,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
query.raise_for_access()
|
||||
except SupersetSecurityException as ex:
|
||||
query.status = QueryStatus.FAILED
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
return json_errors_response([ex.error], status=403)
|
||||
|
||||
try:
|
||||
|
@ -2323,13 +2304,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
|
||||
# Async request.
|
||||
if async_flag:
|
||||
return self._sql_json_async(
|
||||
session, rendered_query, query, expand_data, log_params
|
||||
)
|
||||
return self._sql_json_async(rendered_query, query, expand_data, log_params)
|
||||
# Sync request.
|
||||
return self._sql_json_sync(
|
||||
session, rendered_query, query, expand_data, log_params
|
||||
)
|
||||
return self._sql_json_sync(rendered_query, query, expand_data, log_params)
|
||||
|
||||
@has_access
|
||||
@expose("/csv/<client_id>")
|
||||
|
@ -2398,9 +2375,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
|
||||
datasource_id, datasource_type = request.args["datasourceKey"].split("__")
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
)
|
||||
datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
|
||||
# Check if datasource exists
|
||||
if not datasource:
|
||||
return json_error_response(DATASOURCE_MISSING_ERR)
|
||||
|
|
|
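
The change above is the template for the whole view layer: handlers stop materializing `session = db.session()` and stop calling `session.close()`. Under Flask-SQLAlchemy, `db.session` is a `scoped_session`, so calling it just returns the same thread-local Session the property already proxies, and removal happens automatically at request teardown. A minimal standalone sketch of that semantics (plain SQLAlchemy, not Superset code):

from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session, sessionmaker

engine = create_engine("sqlite://")
db_session = scoped_session(sessionmaker(bind=engine))  # stands in for `db.session`

# Calling the registry and proxying attribute access both resolve to the
# same thread-local Session, so `db.session()` never bought anything over
# plain `db.session`.
assert db_session() is db_session()

db_session.remove()  # Flask-SQLAlchemy issues the equivalent at request teardown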

@@ -47,7 +47,7 @@ class Datasource(BaseSupersetView):
         datasource_type = datasource_dict.get("type")
         database_id = datasource_dict["database"].get("id")
         orm_datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
+            datasource_type, datasource_id
         )
         orm_datasource.database_id = database_id

@@ -82,7 +82,7 @@ class Datasource(BaseSupersetView):
     def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
         try:
             orm_datasource = ConnectorRegistry.get_datasource(
-                datasource_type, datasource_id, db.session
+                datasource_type, datasource_id
             )
             if not orm_datasource.data:
                 return json_error_response(

@@ -102,7 +102,7 @@ class Datasource(BaseSupersetView):
         """Gets column info from the source system"""
         if datasource_type == "druid":
             datasource = ConnectorRegistry.get_datasource(
-                datasource_type, datasource_id, db.session
+                datasource_type, datasource_id
             )
         elif datasource_type == "table":
             database = (
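
With `ConnectorRegistry.get_datasource` reduced to the type and id, views no longer own a Session just to look a datasource up. A hypothetical sketch of what the registry side of this refactor can look like (the body here is illustrative, not the actual implementation):

from superset.extensions import db  # the shared Flask-SQLAlchemy handle

class ConnectorRegistry:
    sources = {}  # datasource_type -> model class, registered at app init

    @classmethod
    def get_datasource(cls, datasource_type, datasource_id):
        # Resolve the request-scoped session internally; callers stop
        # threading `db.session` through every signature.
        return (
            db.session.query(cls.sources[datasource_type])
            .filter_by(id=datasource_id)
            .one_or_none()  # views still guard against a missing datasource
        )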

@@ -105,9 +105,7 @@ def get_viz(
     form_data: FormData, datasource_type: str, datasource_id: int, force: bool = False
 ) -> BaseViz:
     viz_type = form_data.get("viz_type", "table")
-    datasource = ConnectorRegistry.get_datasource(
-        datasource_type, datasource_id, db.session
-    )
+    datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
     viz_obj = viz.viz_types[viz_type](datasource, form_data=form_data, force=force)
     return viz_obj

@@ -293,8 +291,7 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
 def get_dashboard_extra_filters(
     slice_id: int, dashboard_id: int
 ) -> List[Dict[str, Any]]:
-    session = db.session()
-    dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
+    dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()

     # is chart in this dashboard?
     if (

@@ -71,13 +71,17 @@ DB_ACCESS_ROLE = "db_access_role"
 SCHEMA_ACCESS_ROLE = "schema_access_role"


-def create_access_request(session, ds_type, ds_name, role_name, user_name):
+def create_access_request(ds_type, ds_name, role_name, user_name):
     ds_class = ConnectorRegistry.sources[ds_type]
     # TODO: generalize datasource names
     if ds_type == "table":
-        ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
+        ds = db.session.query(ds_class).filter(ds_class.table_name == ds_name).first()
     else:
-        ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
+        ds = (
+            db.session.query(ds_class)
+            .filter(ds_class.datasource_name == ds_name)
+            .first()
+        )
     ds_perm_view = security_manager.find_permission_view_menu(
         "datasource_access", ds.perm
     )

@@ -89,8 +93,8 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name):
         datasource_type=ds_type,
         created_by_fk=security_manager.find_user(username=user_name).id,
     )
-    session.add(access_request)
-    session.commit()
+    db.session.add(access_request)
+    db.session.commit()
     return access_request
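
With the `session` parameter gone, call sites of this helper shrink to the datasource coordinates alone. A usage sketch with names taken from the tests below (assuming the examples data is loaded):

access_request = create_access_request("table", "energy_usage", TEST_ROLE_1, "gamma")
ds_id = access_request.datasource_id  # populated: the helper commits via db.session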

@@ -126,7 +130,6 @@ class TestRequestAccess(SupersetTestCase):
         override_me = security_manager.find_role("override_me")
         override_me.permissions = []
         db.session.commit()
-        db.session.close()

     def test_override_role_permissions_is_admin_only(self):
         self.logout()

@@ -211,7 +214,6 @@ class TestRequestAccess(SupersetTestCase):
         )

     def test_clean_requests_after_role_extend(self):
-        session = db.session

         # Case 1. Gamma and gamma2 requested test_role1 on energy_usage access
         # Gamma already has role test_role1

@@ -221,12 +223,10 @@ class TestRequestAccess(SupersetTestCase):
         # gamma2 and gamma request table_role on energy usage
         if app.config["ENABLE_ACCESS_REQUEST"]:
             access_request1 = create_access_request(
-                session, "table", "random_time_series", TEST_ROLE_1, "gamma2"
+                "table", "random_time_series", TEST_ROLE_1, "gamma2"
             )
             ds_1_id = access_request1.datasource_id
-            create_access_request(
-                session, "table", "random_time_series", TEST_ROLE_1, "gamma"
-            )
+            create_access_request("table", "random_time_series", TEST_ROLE_1, "gamma")
             access_requests = self.get_access_requests("gamma", "table", ds_1_id)
             self.assertTrue(access_requests)
             # gamma gets test_role1

@@ -244,22 +244,20 @@ class TestRequestAccess(SupersetTestCase):
             gamma_user.roles.remove(security_manager.find_role("test_role1"))

     def test_clean_requests_after_alpha_grant(self):
-        session = db.session

         # Case 2. Two access requests from gamma and gamma2
         # Gamma becomes alpha, gamma2 gets granted
         # Check if request by gamma has been deleted

         access_request1 = create_access_request(
-            session, "table", "birth_names", TEST_ROLE_1, "gamma"
+            "table", "birth_names", TEST_ROLE_1, "gamma"
         )
-        create_access_request(session, "table", "birth_names", TEST_ROLE_2, "gamma2")
+        create_access_request("table", "birth_names", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         # gamma becomes alpha
         alpha_role = security_manager.find_role("Alpha")
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.append(alpha_role)
-        session.commit()
+        db.session.commit()
         access_requests = self.get_access_requests("gamma", "table", ds_1_id)
         self.assertTrue(access_requests)
         self.client.get(

@@ -270,23 +268,21 @@ class TestRequestAccess(SupersetTestCase):

         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.remove(security_manager.find_role("Alpha"))
-        session.commit()
+        db.session.commit()

     def test_clean_requests_after_db_grant(self):
-        session = db.session

         # Case 3. Two access requests from gamma and gamma2
         # Gamma gets database access, gamma2 access request granted
         # Check if request by gamma has been deleted

         gamma_user = security_manager.find_user(username="gamma")
         access_request1 = create_access_request(
-            session, "table", "energy_usage", TEST_ROLE_1, "gamma"
+            "table", "energy_usage", TEST_ROLE_1, "gamma"
         )
-        create_access_request(session, "table", "energy_usage", TEST_ROLE_2, "gamma2")
+        create_access_request("table", "energy_usage", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         # gamma gets granted database access
-        database = session.query(models.Database).first()
+        database = db.session.query(models.Database).first()

         security_manager.add_permission_view_menu("database_access", database.perm)
         ds_perm_view = security_manager.find_permission_view_menu(

@@ -296,7 +292,7 @@ class TestRequestAccess(SupersetTestCase):
             security_manager.find_role(DB_ACCESS_ROLE), ds_perm_view
         )
         gamma_user.roles.append(security_manager.find_role(DB_ACCESS_ROLE))
-        session.commit()
+        db.session.commit()
         access_requests = self.get_access_requests("gamma", "table", ds_1_id)
         self.assertTrue(access_requests)
         # gamma2 request gets fulfilled

@@ -308,25 +304,21 @@ class TestRequestAccess(SupersetTestCase):
         self.assertFalse(access_requests)
         gamma_user = security_manager.find_user(username="gamma")
         gamma_user.roles.remove(security_manager.find_role(DB_ACCESS_ROLE))
-        session.commit()
+        db.session.commit()

     def test_clean_requests_after_schema_grant(self):
-        session = db.session

         # Case 4. Two access requests from gamma and gamma2
         # Gamma gets schema access, gamma2 access request granted
         # Check if request by gamma has been deleted

         gamma_user = security_manager.find_user(username="gamma")
         access_request1 = create_access_request(
-            session, "table", "wb_health_population", TEST_ROLE_1, "gamma"
-        )
-        create_access_request(
-            session, "table", "wb_health_population", TEST_ROLE_2, "gamma2"
+            "table", "wb_health_population", TEST_ROLE_1, "gamma"
         )
+        create_access_request("table", "wb_health_population", TEST_ROLE_2, "gamma2")
         ds_1_id = access_request1.datasource_id
         ds = (
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )

@@ -340,7 +332,7 @@ class TestRequestAccess(SupersetTestCase):
             security_manager.find_role(SCHEMA_ACCESS_ROLE), schema_perm_view
         )
         gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
-        session.commit()
+        db.session.commit()
         # gamma2 request gets fulfilled
         self.client.get(
             EXTEND_ROLE_REQUEST.format("table", ds_1_id, "gamma2", TEST_ROLE_2)

@@ -351,25 +343,24 @@ class TestRequestAccess(SupersetTestCase):
         gamma_user.roles.remove(security_manager.find_role(SCHEMA_ACCESS_ROLE))

         ds = (
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter_by(table_name="wb_health_population")
             .first()
         )
         ds.schema = None

-        session.commit()
+        db.session.commit()

     @mock.patch("superset.utils.core.send_mime_email")
     def test_approve(self, mock_send_mime):
         if app.config["ENABLE_ACCESS_REQUEST"]:
-            session = db.session
             TEST_ROLE_NAME = "table_role"
             security_manager.add_role(TEST_ROLE_NAME)

             # Case 1. Grant new role to the user.

             access_request1 = create_access_request(
-                session, "table", "unicode_test", TEST_ROLE_NAME, "gamma"
+                "table", "unicode_test", TEST_ROLE_NAME, "gamma"
             )
             ds_1_id = access_request1.datasource_id
             self.get_resp(

@@ -404,7 +395,7 @@ class TestRequestAccess(SupersetTestCase):
             # Case 2. Extend the role to have access to the table

             access_request2 = create_access_request(
-                session, "table", "energy_usage", TEST_ROLE_NAME, "gamma"
+                "table", "energy_usage", TEST_ROLE_NAME, "gamma"
             )
             ds_2_id = access_request2.datasource_id
             energy_usage_perm = access_request2.datasource.perm

@@ -448,7 +439,7 @@ class TestRequestAccess(SupersetTestCase):

             security_manager.add_role("druid_role")
             access_request3 = create_access_request(
-                session, "druid", "druid_ds_1", "druid_role", "gamma"
+                "druid", "druid_ds_1", "druid_role", "gamma"
             )
             self.get_resp(
                 GRANT_ROLE_REQUEST.format(

@@ -463,7 +454,7 @@ class TestRequestAccess(SupersetTestCase):
             # Case 4. Extend the role to have access to the druid datasource

             access_request4 = create_access_request(
-                session, "druid", "druid_ds_2", "druid_role", "gamma"
+                "druid", "druid_ds_2", "druid_role", "gamma"
             )
             druid_ds_2_perm = access_request4.datasource.perm

@@ -483,19 +474,18 @@ class TestRequestAccess(SupersetTestCase):
             gamma_user = security_manager.find_user(username="gamma")
             gamma_user.roles.remove(security_manager.find_role("druid_role"))
             gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
-            session.delete(security_manager.find_role("druid_role"))
-            session.delete(security_manager.find_role(TEST_ROLE_NAME))
-            session.commit()
+            db.session.delete(security_manager.find_role("druid_role"))
+            db.session.delete(security_manager.find_role(TEST_ROLE_NAME))
+            db.session.commit()

     def test_request_access(self):
         if app.config["ENABLE_ACCESS_REQUEST"]:
-            session = db.session
             self.logout()
             self.login(username="gamma")
             gamma_user = security_manager.find_user(username="gamma")
             security_manager.add_role("dummy_role")
             gamma_user.roles.append(security_manager.find_role("dummy_role"))
-            session.commit()
+            db.session.commit()

             ACCESS_REQUEST = (
                 "/superset/request_access?"

@@ -511,7 +501,7 @@ class TestRequestAccess(SupersetTestCase):
            # Request table access, there are no roles that have this table.

            table1 = (
-                session.query(SqlaTable)
+                db.session.query(SqlaTable)
                .filter_by(table_name="random_time_series")
                .first()
            )

@@ -526,7 +516,7 @@ class TestRequestAccess(SupersetTestCase):
            # Request access, roles exist that contain the table.
            # add table to the existing roles
            table3 = (
-                session.query(SqlaTable).filter_by(table_name="energy_usage").first()
+                db.session.query(SqlaTable).filter_by(table_name="energy_usage").first()
            )
            table_3_id = table3.id
            table3_perm = table3.perm

@@ -545,7 +535,7 @@ class TestRequestAccess(SupersetTestCase):
                    "datasource_access", table3_perm
                ),
            )
-            session.commit()
+            db.session.commit()

            self.get_resp(ACCESS_REQUEST.format("table", table_3_id, "go"))
            access_request3 = self.get_access_requests("gamma", "table", table_3_id)

@@ -559,7 +549,7 @@ class TestRequestAccess(SupersetTestCase):

            # Request druid access, there are no roles that have this table.
            druid_ds_4 = (
-                session.query(DruidDatasource)
+                db.session.query(DruidDatasource)
                .filter_by(datasource_name="druid_ds_1")
                .first()
            )

@@ -574,7 +564,7 @@ class TestRequestAccess(SupersetTestCase):
            # Case 5. Roles exist that contain the druid datasource.
            # add druid ds to the existing roles
            druid_ds_5 = (
-                session.query(DruidDatasource)
+                db.session.query(DruidDatasource)
                .filter_by(datasource_name="druid_ds_2")
                .first()
            )

@@ -595,7 +585,7 @@ class TestRequestAccess(SupersetTestCase):
                    "datasource_access", druid_ds_5_perm
                ),
            )
-            session.commit()
+            db.session.commit()

            self.get_resp(ACCESS_REQUEST.format("druid", druid_ds_5_id, "go"))
            access_request5 = self.get_access_requests("gamma", "druid", druid_ds_5_id)

@@ -610,7 +600,7 @@ class TestRequestAccess(SupersetTestCase):
            # cleanup
            gamma_user = security_manager.find_user(username="gamma")
            gamma_user.roles.remove(security_manager.find_role("dummy_role"))
-            session.commit()
+            db.session.commit()


 if __name__ == "__main__":

@@ -32,112 +32,118 @@ logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)


-@pytest.yield_fixture(scope="module")
-def setup_database():
+def setup_module():
     with app.app_context():
         slice_id = db.session.query(Slice).all()[0].id
         database_id = utils.get_example_database().id

-        alert1 = Alert(
-            id=1,
-            label="alert_1",
-            active=True,
-            crontab="*/1 * * * *",
-            sql="SELECT 0",
-            alert_type="email",
-            slice_id=slice_id,
-            database_id=database_id,
-        )
-        alert2 = Alert(
-            id=2,
-            label="alert_2",
-            active=True,
-            crontab="*/1 * * * *",
-            sql="SELECT 55",
-            alert_type="email",
-            slice_id=slice_id,
-            database_id=database_id,
-        )
-        alert3 = Alert(
-            id=3,
-            label="alert_3",
-            active=False,
-            crontab="*/1 * * * *",
-            sql="UPDATE 55",
-            alert_type="email",
-            slice_id=slice_id,
-            database_id=database_id,
-        )
-        alert4 = Alert(id=4, active=False, label="alert_4", database_id=-1)
-        alert5 = Alert(id=5, active=False, label="alert_5", database_id=database_id)
+        alerts = [
+            Alert(
+                id=1,
+                label="alert_1",
+                active=True,
+                crontab="*/1 * * * *",
+                sql="SELECT 0",
+                alert_type="email",
+                slice_id=slice_id,
+                database_id=database_id,
+            ),
+            Alert(
+                id=2,
+                label="alert_2",
+                active=True,
+                crontab="*/1 * * * *",
+                sql="SELECT 55",
+                alert_type="email",
+                slice_id=slice_id,
+                database_id=database_id,
+            ),
+            Alert(
+                id=3,
+                label="alert_3",
+                active=False,
+                crontab="*/1 * * * *",
+                sql="UPDATE 55",
+                alert_type="email",
+                slice_id=slice_id,
+                database_id=database_id,
+            ),
+            Alert(id=4, active=False, label="alert_4", database_id=-1),
+            Alert(id=5, active=False, label="alert_5", database_id=database_id),
+        ]

-        for num in range(1, 6):
-            eval(f"db.session.add(alert{num})")
+        db.session.bulk_save_objects(alerts)
         db.session.commit()
-        yield db.session
+
+
+def teardown_module():
+    with app.app_context():
         db.session.query(AlertLog).delete()
         db.session.query(Alert).delete()


 @patch("superset.tasks.schedules.deliver_alert")
 @patch("superset.tasks.schedules.logging.Logger.error")
-def test_run_alert_query(mock_error, mock_deliver, setup_database):
-    database = setup_database
-    run_alert_query(database.query(Alert).filter_by(id=1).one(), database)
-    alert1 = database.query(Alert).filter_by(id=1).one()
-    assert mock_deliver.call_count == 0
-    assert len(alert1.logs) == 1
-    assert alert1.logs[0].alert_id == 1
-    assert alert1.logs[0].state == "pass"
+def test_run_alert_query(mock_error, mock_deliver_alert):
+    with app.app_context():
+        run_alert_query(db.session.query(Alert).filter_by(id=1).one())
+        alert1 = db.session.query(Alert).filter_by(id=1).one()
+        assert mock_deliver_alert.call_count == 0
+        assert len(alert1.logs) == 1
+        assert alert1.logs[0].alert_id == 1
+        assert alert1.logs[0].state == "pass"

-    run_alert_query(database.query(Alert).filter_by(id=2).one(), database)
-    alert2 = database.query(Alert).filter_by(id=2).one()
-    assert mock_deliver.call_count == 1
-    assert len(alert2.logs) == 1
-    assert alert2.logs[0].alert_id == 2
-    assert alert2.logs[0].state == "trigger"
+        run_alert_query(db.session.query(Alert).filter_by(id=2).one())
+        alert2 = db.session.query(Alert).filter_by(id=2).one()
+        assert mock_deliver_alert.call_count == 1
+        assert len(alert2.logs) == 1
+        assert alert2.logs[0].alert_id == 2
+        assert alert2.logs[0].state == "trigger"

-    run_alert_query(database.query(Alert).filter_by(id=3).one(), database)
-    alert3 = database.query(Alert).filter_by(id=3).one()
-    assert mock_deliver.call_count == 1
-    assert mock_error.call_count == 2
-    assert len(alert3.logs) == 1
-    assert alert3.logs[0].alert_id == 3
-    assert alert3.logs[0].state == "error"
+        run_alert_query(db.session.query(Alert).filter_by(id=3).one())
+        alert3 = db.session.query(Alert).filter_by(id=3).one()
+        assert mock_deliver_alert.call_count == 1
+        assert mock_error.call_count == 2
+        assert len(alert3.logs) == 1
+        assert alert3.logs[0].alert_id == 3
+        assert alert3.logs[0].state == "error"

-    run_alert_query(database.query(Alert).filter_by(id=4).one(), database)
-    assert mock_deliver.call_count == 1
-    assert mock_error.call_count == 3
+        run_alert_query(db.session.query(Alert).filter_by(id=4).one())
+        assert mock_deliver_alert.call_count == 1
+        assert mock_error.call_count == 3

-    run_alert_query(database.query(Alert).filter_by(id=5).one(), database)
-    assert mock_deliver.call_count == 1
-    assert mock_error.call_count == 4
+        run_alert_query(db.session.query(Alert).filter_by(id=5).one())
+        assert mock_deliver_alert.call_count == 1
+        assert mock_error.call_count == 4


 @patch("superset.tasks.schedules.deliver_alert")
 @patch("superset.tasks.schedules.run_alert_query")
-def test_schedule_alert_query(mock_run_alert, mock_deliver_alert, setup_database):
-    database = setup_database
-    active_alert = database.query(Alert).filter_by(id=1).one()
-    inactive_alert = database.query(Alert).filter_by(id=3).one()
+def test_schedule_alert_query(mock_run_alert, mock_deliver_alert):
+    with app.app_context():
+        active_alert = db.session.query(Alert).filter_by(id=1).one()
+        inactive_alert = db.session.query(Alert).filter_by(id=3).one()

-    # Test that inactive alerts are not processed
-    schedule_alert_query(report_type=ScheduleType.alert, schedule_id=inactive_alert.id)
-    assert mock_run_alert.call_count == 0
-    assert mock_deliver_alert.call_count == 0
+        # Test that inactive alerts are not processed
+        schedule_alert_query(
+            report_type=ScheduleType.alert, schedule_id=inactive_alert.id
+        )
+        assert mock_run_alert.call_count == 0
+        assert mock_deliver_alert.call_count == 0

-    # Test that active alerts with no recipients passed in are processed regularly
-    schedule_alert_query(report_type=ScheduleType.alert, schedule_id=active_alert.id)
-    assert mock_run_alert.call_count == 1
-    assert mock_deliver_alert.call_count == 0
+        # Test that active alerts with no recipients passed in are processed regularly
+        schedule_alert_query(
+            report_type=ScheduleType.alert, schedule_id=active_alert.id
+        )
+        assert mock_run_alert.call_count == 1
+        assert mock_deliver_alert.call_count == 0

-    # Test that active alerts sent as a test are delivered immediately
-    schedule_alert_query(
-        report_type=ScheduleType.alert,
-        schedule_id=active_alert.id,
-        recipients="testing@email.com",
-        is_test_alert=True,
-    )
-    assert mock_run_alert.call_count == 1
-    assert mock_deliver_alert.call_count == 1
+        # Test that active alerts sent as a test are delivered immediately
+        schedule_alert_query(
+            report_type=ScheduleType.alert,
+            schedule_id=active_alert.id,
+            recipients="testing@email.com",
+            is_test_alert=True,
+        )
+        assert mock_run_alert.call_count == 1
+        assert mock_deliver_alert.call_count == 1
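
Two shapes change at once in the alert tests: the module-scoped `pytest.yield_fixture` becomes plain `setup_module`/`teardown_module` hooks, and the `eval`-driven inserts collapse into a single `bulk_save_objects` call. A standalone sketch of that combination using a stand-in model (not the real `Alert`):

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

Base = declarative_base()

class StubAlert(Base):  # stand-in for superset's Alert model
    __tablename__ = "alerts"
    id = Column(Integer, primary_key=True)
    label = Column(String(150))

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

def setup_module():
    # pytest runs this once before the module's tests. bulk_save_objects()
    # emits the INSERTs without per-object identity-map bookkeeping, which
    # is fine here because the tests re-query the rows afterwards.
    session.bulk_save_objects(
        [StubAlert(id=i, label=f"alert_{i}") for i in range(1, 6)]
    )
    session.commit()

def teardown_module():
    # pytest runs this once after the module's tests.
    session.query(StubAlert).delete()
    session.commit()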

@@ -25,7 +25,6 @@ import pandas as pd
 from flask import Response
 from flask_appbuilder.security.sqla import models as ab_models
 from flask_testing import TestCase
-from sqlalchemy.orm import Session

 from tests.test_app import app
 from superset.sql_parse import CtasMethod

@@ -103,24 +102,25 @@ class SupersetTestCase(TestCase):
         # create druid cluster and druid datasources

         with app.app_context():
-            session = db.session
             cluster = (
-                session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
+                db.session.query(DruidCluster)
+                .filter_by(cluster_name="druid_test")
+                .first()
             )
             if not cluster:
                 cluster = DruidCluster(cluster_name="druid_test")
-                session.add(cluster)
-                session.commit()
+                db.session.add(cluster)
+                db.session.commit()

                 druid_datasource1 = DruidDatasource(
                     datasource_name="druid_ds_1", cluster=cluster
                 )
-                session.add(druid_datasource1)
+                db.session.add(druid_datasource1)
                 druid_datasource2 = DruidDatasource(
                     datasource_name="druid_ds_2", cluster=cluster
                 )
-                session.add(druid_datasource2)
-                session.commit()
+                db.session.add(druid_datasource2)
+                db.session.commit()

     @staticmethod
     def get_table_by_id(table_id: int) -> SqlaTable:

@@ -134,25 +134,23 @@ class SupersetTestCase(TestCase):
         except ImportError:
             return False

-    def get_or_create(self, cls, criteria, session, **kwargs):
-        obj = session.query(cls).filter_by(**criteria).first()
+    def get_or_create(self, cls, criteria, **kwargs):
+        obj = db.session.query(cls).filter_by(**criteria).first()
         if not obj:
             obj = cls(**criteria)
             obj.__dict__.update(**kwargs)
-            session.add(obj)
-            session.commit()
+            db.session.add(obj)
+            db.session.commit()
         return obj

     def login(self, username="admin", password="general"):
         resp = self.get_resp("/login/", data=dict(username=username, password=password))
         self.assertNotIn("User confirmation needed", resp)

-    def get_slice(
-        self, slice_name: str, session: Session, expunge_from_session: bool = True
-    ) -> Slice:
-        slc = session.query(Slice).filter_by(slice_name=slice_name).one()
+    def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice:
+        slc = db.session.query(Slice).filter_by(slice_name=slice_name).one()
         if expunge_from_session:
-            session.expunge_all()
+            db.session.expunge_all()
         return slc

     @staticmethod

@@ -301,7 +299,6 @@ class SupersetTestCase(TestCase):
         return self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
-            session=db.session,
             sqlalchemy_uri="sqlite:///:memory:",
             id=db_id,
             extra=extra,

@@ -323,7 +320,6 @@ class SupersetTestCase(TestCase):
         return self.get_or_create(
             cls=models.Database,
             criteria={"database_name": database_name},
-            session=db.session,
             sqlalchemy_uri="presto://user@host:8080/hive",
             id=db_id,
         )
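
Two details of the refactored helpers are easy to miss. `get_or_create` applies its extra keyword arguments only on the create path, so a pre-existing row comes back unmodified. And `get_slice` still expunges by default: `expunge_all()` detaches every instance from the identity map, which means the returned object can be mutated in a test without a later commit writing the change back. An illustrative fragment (same names as the helper above):

slc = db.session.query(Slice).filter_by(slice_name="Girls").one()
db.session.expunge_all()           # slc is now detached
slc.slice_name = "scratch value"   # untracked; a later db.session.commit() won't persist it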

@@ -97,15 +97,13 @@ CTAS_SCHEMA_NAME = "sqllab_test_db"

 class TestCelery(SupersetTestCase):
     def get_query_by_name(self, sql):
-        session = db.session
-        query = session.query(Query).filter_by(sql=sql).first()
-        session.close()
+        query = db.session.query(Query).filter_by(sql=sql).first()
+        db.session.close()
         return query

     def get_query_by_id(self, id):
-        session = db.session
-        query = session.query(Query).filter_by(id=id).first()
-        session.close()
+        query = db.session.query(Query).filter_by(id=id).first()
+        db.session.close()
         return query

     @classmethod
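
The celery helpers keep their `close()` calls. On a scoped session this proxies to `Session.close()`, which expunges loaded instances and releases the connection; the registry still holds the same Session, ready to begin a new transaction on the next query. The returned row is therefore detached, but attributes loaded before the close stay readable, e.g. (same shape as `get_query_by_id` above):

query = db.session.query(Query).filter_by(id=query_id).first()  # query_id as in the helper
db.session.close()  # detach results, release the connection
print(query.id)     # fine: the attribute was loaded before close()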

@@ -57,9 +57,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
         for owner in owners:
             user = db.session.query(security_manager.user_model).get(owner)
             obj_owners.append(user)
-        datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
-        )
+        datasource = ConnectorRegistry.get_datasource(datasource_type, datasource_id)
         slice = Slice(
             slice_name=slice_name,
             datasource_id=datasource.id,

@@ -99,7 +99,7 @@ class TestCore(SupersetTestCase):

     def test_slice_endpoint(self):
         self.login(username="admin")
-        slc = self.get_slice("Girls", db.session)
+        slc = self.get_slice("Girls")
         resp = self.get_resp("/superset/slice/{}/".format(slc.id))
         assert "Time Column" in resp
         assert "List Roles" in resp

@@ -113,7 +113,7 @@ class TestCore(SupersetTestCase):

     def test_viz_cache_key(self):
         self.login(username="admin")
-        slc = self.get_slice("Girls", db.session)
+        slc = self.get_slice("Girls")

         viz = slc.viz
         qobj = viz.query_obj()

@@ -229,7 +229,7 @@ class TestCore(SupersetTestCase):
     def test_save_slice(self):
         self.login(username="admin")
         slice_name = f"Energy Sankey"
-        slice_id = self.get_slice(slice_name, db.session).id
+        slice_id = self.get_slice(slice_name).id
         copy_name_prefix = "Test Sankey"
         copy_name = f"{copy_name_prefix}[save]{random.random()}"
         tbl_id = self.table_ids.get("energy_usage")

@@ -295,7 +295,7 @@ class TestCore(SupersetTestCase):
     def test_filter_endpoint(self):
         self.login(username="admin")
         slice_name = "Energy Sankey"
-        slice_id = self.get_slice(slice_name, db.session).id
+        slice_id = self.get_slice(slice_name).id
         db.session.commit()
         tbl_id = self.table_ids.get("energy_usage")
         table = db.session.query(SqlaTable).filter(SqlaTable.id == tbl_id)

@@ -315,9 +315,7 @@ class TestCore(SupersetTestCase):
     def test_slice_data(self):
         # slice data should have some required attributes
         self.login(username="admin")
-        slc = self.get_slice(
-            slice_name="Girls", session=db.session, expunge_from_session=False
-        )
+        slc = self.get_slice(slice_name="Girls", expunge_from_session=False)
         slc_data_attributes = slc.data.keys()
         assert "changed_on" in slc_data_attributes
         assert "modified" in slc_data_attributes

@@ -368,9 +366,7 @@ class TestCore(SupersetTestCase):
         self.assertEqual(data, [])

         # make user owner of slice and verify that endpoint returns said slice
-        slc = self.get_slice(
-            slice_name=slice_name, session=db.session, expunge_from_session=False
-        )
+        slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
         slc.owners = [user]
         db.session.merge(slc)
         db.session.commit()

@@ -381,9 +377,7 @@ class TestCore(SupersetTestCase):
         self.assertEqual(data[0]["title"], slice_name)

         # remove ownership and ensure user no longer gets slice
-        slc = self.get_slice(
-            slice_name=slice_name, session=db.session, expunge_from_session=False
-        )
+        slc = self.get_slice(slice_name=slice_name, expunge_from_session=False)
         slc.owners = []
         db.session.merge(slc)
         db.session.commit()

@@ -561,7 +555,7 @@ class TestCore(SupersetTestCase):
         db.session.commit()

     def test_warm_up_cache(self):
-        slc = self.get_slice("Girls", db.session)
+        slc = self.get_slice("Girls")
         data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
         self.assertEqual(
             data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]

@@ -769,7 +763,7 @@ class TestCore(SupersetTestCase):

     def test_user_profile(self, username="admin"):
         self.login(username=username)
-        slc = self.get_slice("Girls", db.session)
+        slc = self.get_slice("Girls")

         # Setting some faves
         url = f"/superset/favstar/Slice/{slc.id}/select/"

@@ -178,12 +178,11 @@ class TestDatabaseApi(SupersetTestCase):
         """
         Database API: Test get select star with datasource access
         """
-        session = db.session
         table = SqlaTable(
             schema="main", table_name="ab_permission", database=get_example_database()
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()

         tmp_table_perm = security_manager.find_permission_view_menu(
             "datasource_access", table.get_perm()

@@ -156,7 +156,7 @@ class TestDatasetApi(SupersetTestCase):
             "template_params": None,
         }
         for key, value in expected_result.items():
-            self.assertEqual(response["result"][key], expected_result[key])
+            self.assertEqual(response["result"][key], value)
         self.assertEqual(len(response["result"]["columns"]), 8)
         self.assertEqual(len(response["result"]["metrics"]), 2)

@@ -717,10 +717,7 @@ class TestDatasetApi(SupersetTestCase):
         )

         cli_export = export_to_dict(
-            session=db.session,
-            recursive=True,
-            back_references=False,
-            include_defaults=False,
+            recursive=True, back_references=False, include_defaults=False,
         )
         cli_export_tables = cli_export["databases"][0]["tables"]
         expected_response = []
@ -47,14 +47,13 @@ class TestDictImportExport(SupersetTestCase):
|
|||
def delete_imports(cls):
|
||||
with app.app_context():
|
||||
# Imported data clean up
|
||||
session = db.session
|
||||
for table in session.query(SqlaTable):
|
||||
for table in db.session.query(SqlaTable):
|
||||
if DBREF in table.params_dict:
|
||||
session.delete(table)
|
||||
for datasource in session.query(DruidDatasource):
|
||||
db.session.delete(table)
|
||||
for datasource in db.session.query(DruidDatasource):
|
||||
if DBREF in datasource.params_dict:
|
||||
session.delete(datasource)
|
||||
session.commit()
|
||||
db.session.delete(datasource)
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
|
@ -90,9 +89,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
|
||||
def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
|
||||
cluster_name = "druid_test"
|
||||
cluster = self.get_or_create(
|
||||
DruidCluster, {"cluster_name": cluster_name}, db.session
|
||||
)
|
||||
cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})
|
||||
|
||||
name = "{0}{1}".format(NAME_PREFIX, name)
|
||||
params = {DBREF: id, "database_name": cluster_name}
|
||||
|
@ -159,7 +156,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
|
||||
def test_import_table_no_metadata(self):
|
||||
table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1)
|
||||
new_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
new_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
imported_id = new_table.id
|
||||
imported = self.get_table_by_id(imported_id)
|
||||
|
@ -173,7 +170,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["col1"],
|
||||
metric_names=["metric1"],
|
||||
)
|
||||
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
imported_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
imported = self.get_table_by_id(imported_table.id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
@ -189,7 +186,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["c1", "c2"],
|
||||
metric_names=["m1", "m2"],
|
||||
)
|
||||
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
imported_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
imported = self.get_table_by_id(imported_table.id)
|
||||
self.assert_table_equals(table, imported)
|
||||
|
@ -199,7 +196,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
table, dict_table = self.create_table(
|
||||
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
|
||||
)
|
||||
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
imported_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
table_over, dict_table_over = self.create_table(
|
||||
"table_override",
|
||||
|
@ -207,7 +204,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over)
|
||||
imported_over_table = SqlaTable.import_from_dict(dict_table_over)
|
||||
db.session.commit()
|
||||
|
||||
imported_over = self.get_table_by_id(imported_over_table.id)
|
||||
|
@ -227,7 +224,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
table, dict_table = self.create_table(
|
||||
"table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
|
||||
)
|
||||
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
imported_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
table_over, dict_table_over = self.create_table(
|
||||
"table_override",
|
||||
|
@ -236,7 +233,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_over_table = SqlaTable.import_from_dict(
|
||||
session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"]
|
||||
dict_rep=dict_table_over, sync=["metrics", "columns"]
|
||||
)
|
||||
db.session.commit()
|
||||
|
||||
|
@ -260,7 +257,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_table = SqlaTable.import_from_dict(db.session, dict_table)
|
||||
imported_table = SqlaTable.import_from_dict(dict_table)
|
||||
db.session.commit()
|
||||
copy_table, dict_copy_table = self.create_table(
|
||||
"copy_cat",
|
||||
|
@ -268,7 +265,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table)
|
||||
imported_copy_table = SqlaTable.import_from_dict(dict_copy_table)
|
||||
db.session.commit()
|
||||
self.assertEqual(imported_table.id, imported_copy_table.id)
|
||||
self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id))
|
||||
|
@ -281,10 +278,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
self.delete_fake_db()
|
||||
|
||||
cli_export = export_to_dict(
|
||||
session=db.session,
|
||||
recursive=True,
|
||||
back_references=False,
|
||||
include_defaults=False,
|
||||
recursive=True, back_references=False, include_defaults=False,
|
||||
)
|
||||
self.get_resp("/login/", data=dict(username="admin", password="general"))
|
||||
resp = self.get_resp(
|
||||
|
@ -303,7 +297,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
datasource, dict_datasource = self.create_druid_datasource(
|
||||
"pure_druid", id=ID_PREFIX + 1
|
||||
)
|
||||
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
|
||||
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
|
||||
db.session.commit()
|
||||
imported = self.get_datasource(imported_cluster.id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -315,7 +309,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["col1"],
|
||||
metric_names=["metric1"],
|
||||
)
|
||||
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
|
||||
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
|
||||
db.session.commit()
|
||||
imported = self.get_datasource(imported_cluster.id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -331,7 +325,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["c1", "c2"],
|
||||
metric_names=["m1", "m2"],
|
||||
)
|
||||
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
|
||||
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
|
||||
db.session.commit()
|
||||
imported = self.get_datasource(imported_cluster.id)
|
||||
self.assert_datasource_equals(datasource, imported)
|
||||
|
@ -340,7 +334,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
datasource, dict_datasource = self.create_druid_datasource(
|
||||
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
|
||||
)
|
||||
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
|
||||
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
|
||||
db.session.commit()
|
||||
table_over, table_over_dict = self.create_druid_datasource(
|
||||
"druid_override",
|
||||
|
@ -348,9 +342,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_over_cluster = DruidDatasource.import_from_dict(
|
||||
db.session, table_over_dict
|
||||
)
|
||||
imported_over_cluster = DruidDatasource.import_from_dict(table_over_dict)
|
||||
db.session.commit()
|
||||
imported_over = self.get_datasource(imported_over_cluster.id)
|
||||
self.assertEqual(imported_cluster.id, imported_over.id)
|
||||
|
@ -366,7 +358,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
datasource, dict_datasource = self.create_druid_datasource(
|
||||
"druid_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"]
|
||||
)
|
||||
imported_cluster = DruidDatasource.import_from_dict(db.session, dict_datasource)
|
||||
imported_cluster = DruidDatasource.import_from_dict(dict_datasource)
|
||||
db.session.commit()
|
||||
table_over, table_over_dict = self.create_druid_datasource(
|
||||
"druid_override",
|
||||
|
@ -375,7 +367,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_over_cluster = DruidDatasource.import_from_dict(
|
||||
session=db.session, dict_rep=table_over_dict, sync=["metrics", "columns"]
|
||||
dict_rep=table_over_dict, sync=["metrics", "columns"]
|
||||
) # syncing metrics and columns
|
||||
db.session.commit()
|
||||
imported_over = self.get_datasource(imported_over_cluster.id)
|
||||
|
@ -395,9 +387,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported = DruidDatasource.import_from_dict(
|
||||
session=db.session, dict_rep=dict_datasource
|
||||
)
|
||||
imported = DruidDatasource.import_from_dict(dict_rep=dict_datasource)
|
||||
db.session.commit()
|
||||
copy_datasource, dict_cp_datasource = self.create_druid_datasource(
|
||||
"copy_cat",
|
||||
|
@ -405,7 +395,7 @@ class TestDictImportExport(SupersetTestCase):
|
|||
cols_names=["new_col1", "col2", "col3"],
|
||||
metric_names=["new_metric1"],
|
||||
)
|
||||
imported_copy = DruidDatasource.import_from_dict(db.session, dict_cp_datasource)
|
||||
imported_copy = DruidDatasource.import_from_dict(dict_cp_datasource)
|
||||
db.session.commit()
|
||||
|
||||
self.assertEqual(imported.id, imported_copy.id)
|
||||
|
|
|
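
`import_from_dict` and `export_to_dict` follow the same cleanup: the session argument disappears and only the keywords that describe the import or export remain. The call-site shape, as exercised by the tests above:

# before: SqlaTable.import_from_dict(db.session, dict_table, sync=["metrics", "columns"])
# after:
imported = SqlaTable.import_from_dict(dict_rep=dict_table, sync=["metrics", "columns"])
db.session.commit()  # import_from_dict stages objects; the caller still commits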

@@ -212,9 +212,7 @@ class TestDruid(SupersetTestCase):
     def test_druid_sync_from_config(self):
         CLUSTER_NAME = "new_druid"
         self.login()
-        cluster = self.get_or_create(
-            DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
-        )
+        cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})

         db.session.merge(cluster)
         db.session.commit()

@@ -302,15 +300,12 @@ class TestDruid(SupersetTestCase):
     @unittest.skipUnless(app.config["DRUID_IS_ACTIVE"], "DRUID_IS_ACTIVE is false")
     def test_filter_druid_datasource(self):
         CLUSTER_NAME = "new_druid"
-        cluster = self.get_or_create(
-            DruidCluster, {"cluster_name": CLUSTER_NAME}, db.session
-        )
+        cluster = self.get_or_create(DruidCluster, {"cluster_name": CLUSTER_NAME})
         db.session.merge(cluster)

         gamma_ds = self.get_or_create(
             DruidDatasource,
             {"datasource_name": "datasource_for_gamma", "cluster": cluster},
-            db.session,
         )
         gamma_ds.cluster = cluster
         db.session.merge(gamma_ds)

@@ -318,7 +313,6 @@ class TestDruid(SupersetTestCase):
         no_gamma_ds = self.get_or_create(
             DruidDatasource,
             {"datasource_name": "datasource_not_for_gamma", "cluster": cluster},
-            db.session,
         )
         no_gamma_ds.cluster = cluster
         db.session.merge(no_gamma_ds)

@@ -46,20 +46,19 @@ class TestImportExport(SupersetTestCase):
     def delete_imports(cls):
         with app.app_context():
             # Imported data clean up
-            session = db.session
-            for slc in session.query(Slice):
+            for slc in db.session.query(Slice):
                 if "remote_id" in slc.params_dict:
-                    session.delete(slc)
-            for dash in session.query(Dashboard):
+                    db.session.delete(slc)
+            for dash in db.session.query(Dashboard):
                 if "remote_id" in dash.params_dict:
-                    session.delete(dash)
-            for table in session.query(SqlaTable):
+                    db.session.delete(dash)
+            for table in db.session.query(SqlaTable):
                 if "remote_id" in table.params_dict:
-                    session.delete(table)
-            for datasource in session.query(DruidDatasource):
+                    db.session.delete(table)
+            for datasource in db.session.query(DruidDatasource):
                 if "remote_id" in datasource.params_dict:
-                    session.delete(datasource)
-            session.commit()
+                    db.session.delete(datasource)
+            db.session.commit()

     @classmethod
     def setUpClass(cls):

@@ -126,9 +125,7 @@ class TestImportExport(SupersetTestCase):

     def create_druid_datasource(self, name, id=0, cols_names=[], metric_names=[]):
         cluster_name = "druid_test"
-        cluster = self.get_or_create(
-            DruidCluster, {"cluster_name": cluster_name}, db.session
-        )
+        cluster = self.get_or_create(DruidCluster, {"cluster_name": cluster_name})

         params = {"remote_id": id, "database_name": cluster_name}
         datasource = DruidDatasource(

@@ -83,7 +83,6 @@ class TestQueryContext(SupersetTestCase):
         datasource = ConnectorRegistry.get_datasource(
             datasource_type=payload["datasource"]["type"],
             datasource_id=payload["datasource"]["id"],
-            session=db.session,
         )
         description_original = datasource.description
         datasource.description = "temporary description"
@ -69,9 +69,8 @@ class TestRolePermission(SupersetTestCase):
|
|||
"""Testing export role permissions."""
|
||||
|
||||
def setUp(self):
|
||||
session = db.session
|
||||
security_manager.add_role(SCHEMA_ACCESS_ROLE)
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
ds = (
|
||||
db.session.query(SqlaTable)
|
||||
|
@ -82,7 +81,7 @@ class TestRolePermission(SupersetTestCase):
|
|||
ds.schema_perm = ds.get_schema_perm()
|
||||
|
||||
ds_slices = (
|
||||
session.query(Slice)
|
||||
db.session.query(Slice)
|
||||
.filter_by(datasource_type="table")
|
||||
.filter_by(datasource_id=ds.id)
|
||||
.all()
|
||||
|
@ -92,12 +91,11 @@ class TestRolePermission(SupersetTestCase):
|
|||
create_schema_perm("[examples].[temp_schema]")
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
gamma_user.roles.append(security_manager.find_role(SCHEMA_ACCESS_ROLE))
|
||||
session.commit()
|
||||
db.session.commit()
|
||||
|
||||
def tearDown(self):
|
||||
session = db.session
|
||||
ds = (
|
||||
session.query(SqlaTable)
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(table_name="wb_health_population")
|
||||
.first()
|
||||
)
|
||||
|
@@ -105,7 +103,7 @@ class TestRolePermission(SupersetTestCase):
         ds.schema = None
         ds.schema_perm = None
         ds_slices = (
-            session.query(Slice)
+            db.session.query(Slice)
             .filter_by(datasource_type="table")
             .filter_by(datasource_id=ds.id)
             .all()

@@ -114,21 +112,20 @@ class TestRolePermission(SupersetTestCase):
             s.schema_perm = None

         delete_schema_perm(schema_perm)
-        session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
-        session.commit()
+        db.session.delete(security_manager.find_role(SCHEMA_ACCESS_ROLE))
+        db.session.commit()

     def test_set_perm_sqla_table(self):
-        session = db.session
         table = SqlaTable(
             schema="tmp_schema",
             table_name="tmp_perm_table",
             database=get_example_database(),
         )
-        session.add(table)
-        session.commit()
+        db.session.add(table)
+        db.session.commit()

         stored_table = (
-            session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table](id:{stored_table.id})"

@@ -147,9 +144,9 @@ class TestRolePermission(SupersetTestCase):

         # table name change
         stored_table.table_name = "tmp_perm_table_v2"
-        session.commit()
+        db.session.commit()
         stored_table = (
-            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"

@@ -169,9 +166,9 @@ class TestRolePermission(SupersetTestCase):

         # schema name change
         stored_table.schema = "tmp_schema_v2"
-        session.commit()
+        db.session.commit()
         stored_table = (
-            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[examples].[tmp_perm_table_v2](id:{stored_table.id})"

@@ -191,13 +188,13 @@ class TestRolePermission(SupersetTestCase):

         # database change
         new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db")
-        session.add(new_db)
+        db.session.add(new_db)
         stored_table.database = (
-            session.query(Database).filter_by(database_name="tmp_db").one()
+            db.session.query(Database).filter_by(database_name="tmp_db").one()
         )
-        session.commit()
+        db.session.commit()
         stored_table = (
-            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"

@@ -217,9 +214,9 @@ class TestRolePermission(SupersetTestCase):

         # no schema
         stored_table.schema = None
-        session.commit()
+        db.session.commit()
         stored_table = (
-            session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
+            db.session.query(SqlaTable).filter_by(table_name="tmp_perm_table_v2").one()
         )
         self.assertEqual(
             stored_table.perm, f"[tmp_db].[tmp_perm_table_v2](id:{stored_table.id})"

@@ -231,26 +228,25 @@ class TestRolePermission(SupersetTestCase):
         )
         self.assertIsNone(stored_table.schema_perm)

-        session.delete(new_db)
-        session.delete(stored_table)
-        session.commit()
+        db.session.delete(new_db)
+        db.session.delete(stored_table)
+        db.session.commit()

     def test_set_perm_druid_datasource(self):
-        session = db.session
         druid_cluster = (
-            session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
+            db.session.query(DruidCluster).filter_by(cluster_name="druid_test").one()
         )
         datasource = DruidDatasource(
             datasource_name="tmp_datasource",
             cluster=druid_cluster,
             cluster_id=druid_cluster.id,
         )
-        session.add(datasource)
-        session.commit()
+        db.session.add(datasource)
+        db.session.commit()

         # store without a schema
         stored_datasource = (
-            session.query(DruidDatasource)
+            db.session.query(DruidDatasource)
             .filter_by(datasource_name="tmp_datasource")
             .one()
         )
@@ -267,7 +263,7 @@ class TestRolePermission(SupersetTestCase):

         # store with a schema
         stored_datasource.datasource_name = "tmp_schema.tmp_datasource"
-        session.commit()
+        db.session.commit()
         self.assertEqual(
             stored_datasource.perm,
             f"[druid_test].[tmp_schema.tmp_datasource](id:{stored_datasource.id})",

@@ -284,16 +280,15 @@ class TestRolePermission(SupersetTestCase):
             )
         )

-        session.delete(stored_datasource)
-        session.commit()
+        db.session.delete(stored_datasource)
+        db.session.commit()

     def test_set_perm_druid_cluster(self):
-        session = db.session
         cluster = DruidCluster(cluster_name="tmp_druid_cluster")
-        session.add(cluster)
+        db.session.add(cluster)

         stored_cluster = (
-            session.query(DruidCluster)
+            db.session.query(DruidCluster)
             .filter_by(cluster_name="tmp_druid_cluster")
             .one()
         )

@@ -307,7 +302,7 @@ class TestRolePermission(SupersetTestCase):
         )

         stored_cluster.cluster_name = "tmp_druid_cluster2"
-        session.commit()
+        db.session.commit()
         self.assertEqual(
             stored_cluster.perm, f"[tmp_druid_cluster2].(id:{stored_cluster.id})"
         )

@@ -317,18 +312,17 @@ class TestRolePermission(SupersetTestCase):
             )
         )

-        session.delete(stored_cluster)
-        session.commit()
+        db.session.delete(stored_cluster)
+        db.session.commit()

     def test_set_perm_database(self):
-        session = db.session
         database = Database(
             database_name="tmp_database", sqlalchemy_uri="sqlite://test"
         )
-        session.add(database)
+        db.session.add(database)

         stored_db = (
-            session.query(Database).filter_by(database_name="tmp_database").one()
+            db.session.query(Database).filter_by(database_name="tmp_database").one()
         )
         self.assertEqual(stored_db.perm, f"[tmp_database].(id:{stored_db.id})")
         self.assertIsNotNone(

@@ -338,9 +332,9 @@ class TestRolePermission(SupersetTestCase):
         )

         stored_db.database_name = "tmp_database2"
-        session.commit()
+        db.session.commit()
         stored_db = (
-            session.query(Database).filter_by(database_name="tmp_database2").one()
+            db.session.query(Database).filter_by(database_name="tmp_database2").one()
         )
         self.assertEqual(stored_db.perm, f"[tmp_database2].(id:{stored_db.id})")
         self.assertIsNotNone(

@@ -349,8 +343,8 @@ class TestRolePermission(SupersetTestCase):
             )
         )

-        session.delete(stored_db)
-        session.commit()
+        db.session.delete(stored_db)
+        db.session.commit()

     def test_hybrid_perm_druid_cluster(self):
         cluster = DruidCluster(cluster_name="tmp_druid_cluster3")
@@ -400,14 +394,13 @@ class TestRolePermission(SupersetTestCase):
         db.session.commit()

     def test_set_perm_slice(self):
-        session = db.session
         database = Database(
             database_name="tmp_database", sqlalchemy_uri="sqlite://test"
         )
         table = SqlaTable(table_name="tmp_perm_table", database=database)
-        session.add(database)
-        session.add(table)
-        session.commit()
+        db.session.add(database)
+        db.session.add(table)
+        db.session.commit()

         # no schema permission
         slice = Slice(

@@ -416,10 +409,10 @@ class TestRolePermission(SupersetTestCase):
             datasource_name="tmp_perm_table",
             slice_name="slice_name",
         )
-        session.add(slice)
-        session.commit()
+        db.session.add(slice)
+        db.session.commit()

-        slice = session.query(Slice).filter_by(slice_name="slice_name").one()
+        slice = db.session.query(Slice).filter_by(slice_name="slice_name").one()
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")
         self.assertEqual(slice.schema_perm, table.schema_perm)

@@ -427,7 +420,7 @@ class TestRolePermission(SupersetTestCase):

         table.schema = "tmp_perm_schema"
         table.table_name = "tmp_perm_table_v2"
-        session.commit()
+        db.session.commit()
         # TODO(bogdan): modify slice permissions on the table update.
         self.assertNotEquals(slice.perm, table.perm)
         self.assertEqual(slice.perm, f"[tmp_database].[tmp_perm_table](id:{table.id})")

@@ -440,7 +433,7 @@ class TestRolePermission(SupersetTestCase):

         # updating slice refreshes the permissions
         slice.slice_name = "slice_name_v2"
-        session.commit()
+        db.session.commit()
         self.assertEqual(slice.perm, table.perm)
         self.assertEqual(
             slice.perm, f"[tmp_database].[tmp_perm_table_v2](id:{table.id})"

@@ -448,11 +441,10 @@ class TestRolePermission(SupersetTestCase):
         self.assertEqual(slice.schema_perm, table.schema_perm)
         self.assertEqual(slice.schema_perm, "[tmp_database].[tmp_perm_schema]")

-        session.delete(slice)
-        session.delete(table)
-        session.delete(database)
-
-        session.commit()
+        db.session.delete(slice)
+        db.session.delete(table)
+        db.session.delete(database)
+        db.session.commit()

         # TODO test slice permission
@@ -532,11 +524,11 @@ class TestRolePermission(SupersetTestCase):
         self.assertNotIn("Girl Name Cloud", data)  # birth_names slice, no access

     def test_sqllab_gamma_user_schema_access_to_sqllab(self):
-        session = db.session
-
-        example_db = session.query(Database).filter_by(database_name="examples").one()
+        example_db = (
+            db.session.query(Database).filter_by(database_name="examples").one()
+        )
         example_db.expose_in_sqllab = True
-        session.commit()
+        db.session.commit()

         arguments = {
             "keys": ["none"],

@@ -959,12 +951,10 @@ class TestRowLevelSecurity(SupersetTestCase):
     rls_entry = None

     def setUp(self):
-        session = db.session
-
         # Create the RowLevelSecurityFilter
         self.rls_entry = RowLevelSecurityFilter()
         self.rls_entry.tables.extend(
-            session.query(SqlaTable)
+            db.session.query(SqlaTable)
             .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"]))
             .all()
         )

@@ -974,13 +964,11 @@ class TestRowLevelSecurity(SupersetTestCase):
         )  # db.session.query(Role).filter_by(name="Gamma").first())
         self.rls_entry.roles.append(security_manager.find_role("Alpha"))
         db.session.add(self.rls_entry)
-
         db.session.commit()

     def tearDown(self):
-        session = db.session
-        session.delete(self.rls_entry)
-        session.commit()
+        db.session.delete(self.rls_entry)
+        db.session.commit()

     # Do another test to make sure it doesn't alter another query
     def test_rls_filter_alters_query(self):
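Every hunk in the test file above makes the same mechanical change: the locally bound `session` alias is dropped in favor of calling Flask-SQLAlchemy's `db.session` directly. A minimal before/after sketch of the pattern (`obj` and `MyModel` are placeholder names, not taken from the diff):

    # before: the scoped-session proxy bound to a local name
    session = db.session
    session.add(obj)
    session.commit()
    obj = session.query(MyModel).first()

    # after: go through the proxy directly; it resolves to the same
    # thread-local Session either way
    db.session.add(obj)
    db.session.commit()
    obj = db.session.query(MyModel).first()

Because `db.session` is a `scoped_session` proxy, the alias adds nothing; removing it keeps every call site pointing at whatever session is current for the thread and leaves one fewer name to keep in sync during cleanup.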
@@ -55,7 +55,6 @@ class TestSqlLab(SupersetTestCase):
         self.logout()
         db.session.query(Query).delete()
         db.session.commit()
-        db.session.close()

     def test_sql_json(self):
         self.login("admin")

@@ -433,7 +432,6 @@ class TestSqlLab(SupersetTestCase):
         Test query api with can_access_all_queries perm added to
         gamma and make sure all queries show up.
         """
-        session = db.session

         # Add all_query_access perm to Gamma user
         all_queries_view = security_manager.find_permission_view_menu(

@@ -443,7 +441,7 @@ class TestSqlLab(SupersetTestCase):
         security_manager.add_permission_role(
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )
-        session.commit()
+        db.session.commit()

         # Test search_queries for Admin user
         self.run_some_queries()

@@ -460,7 +458,7 @@ class TestSqlLab(SupersetTestCase):
             security_manager.find_role("gamma_sqllab"), all_queries_view
         )

-        session.commit()
+        db.session.commit()

     def test_query_admin_can_access_all_queries(self) -> None:
         """
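The SQL Lab hunks also drop an explicit `db.session.close()` from teardown: with a scoped session, the registry owns the session's lifecycle. A self-contained sketch of that behavior in plain SQLAlchemy (generic library usage, not Superset code; the `engine` and `Session` names are local to this sketch):

    from sqlalchemy import create_engine, text
    from sqlalchemy.orm import scoped_session, sessionmaker

    engine = create_engine("sqlite://")
    Session = scoped_session(sessionmaker(bind=engine))

    assert Session() is Session()      # one Session per thread/scope
    Session.execute(text("SELECT 1"))  # method calls proxy to that Session
    Session.remove()                   # close and discard the scoped Session

Flask-SQLAlchemy performs the equivalent of `Session.remove()` when the application context tears down, so tests can generally lean on `db.session` without closing it by hand.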
@@ -194,7 +194,7 @@ class TestCacheWarmUp(SupersetTestCase):
         db.session.commit()

     def test_dashboard_tags(self):
-        tag1 = get_tag("tag1", db.session, TagTypes.custom)
+        tag1 = get_tag("tag1", TagTypes.custom)
         # delete first to make test idempotent
         self.reset_tag(tag1)

@@ -204,7 +204,7 @@ class TestCacheWarmUp(SupersetTestCase):
         self.assertEqual(result, expected)

         # tag dashboard 'births' with `tag1`
-        tag1 = get_tag("tag1", db.session, TagTypes.custom)
+        tag1 = get_tag("tag1", TagTypes.custom)
         dash = self.get_dash_by_slug("births")
         tag1_urls = sorted([f"{URL_PREFIX}{slc.url}" for slc in dash.slices])
         tagged_object = TaggedObject(

@@ -216,7 +216,7 @@ class TestCacheWarmUp(SupersetTestCase):
         self.assertEqual(sorted(strategy.get_urls()), tag1_urls)

         strategy = DashboardTagsStrategy(["tag2"])
-        tag2 = get_tag("tag2", db.session, TagTypes.custom)
+        tag2 = get_tag("tag2", TagTypes.custom)
         self.reset_tag(tag2)

         result = sorted(strategy.get_urls())
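These cache-warm-up hunks show the other half of the cleanup: helpers such as `get_tag` stop taking a session argument and reach for the global scoped session themselves. A hypothetical sketch of what the session-free helper could look like — the import path, signature details, and body are assumptions for illustration, not code from this commit:

    # Hypothetical re-implementation; only the two-argument call shape
    # is taken from the diff above.
    from superset import db
    from superset.models.tags import Tag, TagTypes

    def get_tag(name: str, tag_type: TagTypes) -> Tag:
        # Look the tag up on the global scoped session instead of a
        # caller-supplied one, creating it on first use.
        tag = (
            db.session.query(Tag)
            .filter_by(name=name, type=tag_type)
            .one_or_none()
        )
        if tag is None:
            tag = Tag(name=name, type=tag_type)
            db.session.add(tag)
            db.session.commit()
        return tag

Whatever the real helper does internally, dropping the session parameter means callers no longer have to thread a session through every call, which is the point of the cleanup.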