mirror of https://github.com/apache/superset.git
chore: switching out ConnectorRegistry references for DatasourceDAO (#20380)
* rename and move dao file * Update dao.py * add cachekey * Update __init__.py * change reference in query context test * add utils ref * more ref changes * add helpers * add todo in dashboard.py * add cachekey * circular import error in dar.py * push rest of refs * fix linting * fix more linting * update enum * remove references for connector registry * big reafctor * take value * fix * test to see if removing value works * delete connectregistry * address concerns * address comments * fix merge conflicts * address concern II * address concern II * fix test Co-authored-by: Phillip Kelley-Dotson <pkelleydotson@yahoo.com>
This commit is contained in:
parent
c79b0d62d0
commit
e3e37cb68f
|
@ -19,7 +19,6 @@ from flask import current_app, Flask
|
|||
from werkzeug.local import LocalProxy
|
||||
|
||||
from superset.app import create_app
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.extensions import (
|
||||
appbuilder,
|
||||
cache_manager,
|
||||
|
|
|
@ -25,7 +25,7 @@ from marshmallow.exceptions import ValidationError
|
|||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.cachekeys.schemas import CacheInvalidationRequestSchema
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.extensions import cache_manager, db, event_logger
|
||||
from superset.models.cache import CacheKey
|
||||
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
|
||||
|
@ -83,13 +83,13 @@ class CacheRestApi(BaseSupersetModelRestApi):
|
|||
return self.response_400(message=str(error))
|
||||
datasource_uids = set(datasources.get("datasource_uids", []))
|
||||
for ds in datasources.get("datasources", []):
|
||||
ds_obj = ConnectorRegistry.get_datasource_by_name(
|
||||
ds_obj = SqlaTable.get_datasource_by_name(
|
||||
session=db.session,
|
||||
datasource_type=ds.get("datasource_type"),
|
||||
datasource_name=ds.get("datasource_name"),
|
||||
schema=ds.get("schema"),
|
||||
database_name=ds.get("database_name"),
|
||||
)
|
||||
|
||||
if ds_obj:
|
||||
datasource_uids.add(ds_obj.uid)
|
||||
|
||||
|
|
|
@ -25,9 +25,10 @@ from superset.commands.exceptions import (
|
|||
OwnersNotFoundValidationError,
|
||||
RolesNotFoundValidationError,
|
||||
)
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.dao.exceptions import DatasourceNotFound
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.extensions import db, security_manager
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
@ -79,8 +80,8 @@ def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
|
|||
|
||||
def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
|
||||
try:
|
||||
return ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
return DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource_type), datasource_id
|
||||
)
|
||||
except DatasetNotFoundError as ex:
|
||||
except DatasourceNotFound as ex:
|
||||
raise DatasourceNotFoundValidationError() from ex
|
||||
|
|
|
@ -22,8 +22,8 @@ from superset import app, db
|
|||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.common.query_object_factory import QueryObjectFactory
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.utils.core import DatasourceDict
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.utils.core import DatasourceDict, DatasourceType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
@ -32,7 +32,7 @@ config = app.config
|
|||
|
||||
|
||||
def create_query_object_factory() -> QueryObjectFactory:
|
||||
return QueryObjectFactory(config, ConnectorRegistry(), db.session)
|
||||
return QueryObjectFactory(config, DatasourceDAO(), db.session)
|
||||
|
||||
|
||||
class QueryContextFactory: # pylint: disable=too-few-public-methods
|
||||
|
@ -82,6 +82,6 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
|||
|
||||
# pylint: disable=no-self-use
|
||||
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
|
||||
return ConnectorRegistry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), db.session
|
||||
return DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource["type"]), int(datasource["id"])
|
||||
)
|
||||
|
|
|
@ -21,29 +21,29 @@ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
|
|||
|
||||
from superset.common.chart_data import ChartDataResultType
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.utils.core import apply_max_row_limit, DatasourceDict
|
||||
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
|
||||
from superset.utils.date_parser import get_since_until
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from superset import ConnectorRegistry
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
|
||||
class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
||||
_config: Dict[str, Any]
|
||||
_connector_registry: ConnectorRegistry
|
||||
_datasource_dao: DatasourceDAO
|
||||
_session_maker: sessionmaker
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_configurations: Dict[str, Any],
|
||||
connector_registry: ConnectorRegistry,
|
||||
_datasource_dao: DatasourceDAO,
|
||||
session_maker: sessionmaker,
|
||||
):
|
||||
self._config = app_configurations
|
||||
self._connector_registry = connector_registry
|
||||
self._datasource_dao = _datasource_dao
|
||||
self._session_maker = session_maker
|
||||
|
||||
def create( # pylint: disable=too-many-arguments
|
||||
|
@ -75,8 +75,10 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
|
|||
)
|
||||
|
||||
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
|
||||
return self._connector_registry.get_datasource(
|
||||
str(datasource["type"]), int(datasource["id"]), self._session_maker()
|
||||
return self._datasource_dao.get_datasource(
|
||||
datasource_type=DatasourceType(datasource["type"]),
|
||||
datasource_id=int(datasource["id"]),
|
||||
session=self._session_maker(),
|
||||
)
|
||||
|
||||
def _process_extras( # pylint: disable=no-self-use
|
||||
|
|
|
@ -1,164 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
|
||||
|
||||
from flask_babel import _
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, subqueryload
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections import OrderedDict
|
||||
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
class ConnectorRegistry:
|
||||
"""Central Registry for all available datasource engines"""
|
||||
|
||||
sources: Dict[str, Type["BaseDatasource"]] = {}
|
||||
|
||||
@classmethod
|
||||
def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
|
||||
for module_name, class_names in datasource_config.items():
|
||||
class_names = [str(s) for s in class_names]
|
||||
module_obj = __import__(module_name, fromlist=class_names)
|
||||
for class_name in class_names:
|
||||
source_class = getattr(module_obj, class_name)
|
||||
cls.sources[source_class.type] = source_class
|
||||
|
||||
@classmethod
|
||||
def get_datasource(
|
||||
cls, datasource_type: str, datasource_id: int, session: Session
|
||||
) -> "BaseDatasource":
|
||||
"""Safely get a datasource instance, raises `DatasetNotFoundError` if
|
||||
`datasource_type` is not registered or `datasource_id` does not
|
||||
exist."""
|
||||
if datasource_type not in cls.sources:
|
||||
raise DatasetNotFoundError()
|
||||
|
||||
datasource = (
|
||||
session.query(cls.sources[datasource_type])
|
||||
.filter_by(id=datasource_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not datasource:
|
||||
raise DatasetNotFoundError()
|
||||
|
||||
return datasource
|
||||
|
||||
@classmethod
|
||||
def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
|
||||
datasources: List["BaseDatasource"] = []
|
||||
for source_class in ConnectorRegistry.sources.values():
|
||||
qry = session.query(source_class)
|
||||
qry = source_class.default_query(qry)
|
||||
datasources.extend(qry.all())
|
||||
return datasources
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_id(
|
||||
cls, session: Session, datasource_id: int
|
||||
) -> "BaseDatasource":
|
||||
"""
|
||||
Find a datasource instance based on the unique id.
|
||||
|
||||
:param session: Session to use
|
||||
:param datasource_id: unique id of datasource
|
||||
:return: Datasource corresponding to the id
|
||||
:raises NoResultFound: if no datasource is found corresponding to the id
|
||||
"""
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
try:
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.filter(datasource_class.id == datasource_id)
|
||||
.one()
|
||||
)
|
||||
except NoResultFound:
|
||||
# proceed to next datasource type
|
||||
pass
|
||||
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_name( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
session: Session,
|
||||
datasource_type: str,
|
||||
datasource_name: str,
|
||||
schema: str,
|
||||
database_name: str,
|
||||
) -> Optional["BaseDatasource"]:
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
return datasource_class.get_datasource_by_name(
|
||||
session, datasource_name, schema, database_name
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_permissions( # pylint: disable=invalid-name
|
||||
cls,
|
||||
session: Session,
|
||||
database: "Database",
|
||||
permissions: Set[str],
|
||||
schema_perms: Set[str],
|
||||
) -> List["BaseDatasource"]:
|
||||
# TODO(bogdan): add unit test
|
||||
datasource_class = ConnectorRegistry.sources[database.type]
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.filter_by(database_id=database.id)
|
||||
.filter(
|
||||
or_(
|
||||
datasource_class.perm.in_(permissions),
|
||||
datasource_class.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_eager_datasource(
|
||||
cls, session: Session, datasource_type: str, datasource_id: int
|
||||
) -> "BaseDatasource":
|
||||
"""Returns datasource with columns and metrics."""
|
||||
datasource_class = ConnectorRegistry.sources[datasource_type]
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.options(
|
||||
subqueryload(datasource_class.columns),
|
||||
subqueryload(datasource_class.metrics),
|
||||
)
|
||||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_name(
|
||||
cls,
|
||||
session: Session,
|
||||
database: "Database",
|
||||
datasource_name: str,
|
||||
schema: Optional[str] = None,
|
||||
) -> List["BaseDatasource"]:
|
||||
datasource_class = ConnectorRegistry.sources[database.type]
|
||||
return datasource_class.query_datasources_by_name(
|
||||
session, database, datasource_name, schema=schema
|
||||
)
|
|
@ -31,6 +31,7 @@ from typing import (
|
|||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
|
@ -1990,6 +1991,48 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
|
|||
query = query.filter_by(schema=schema)
|
||||
return query.all()
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_permissions( # pylint: disable=invalid-name
|
||||
cls,
|
||||
session: Session,
|
||||
database: Database,
|
||||
permissions: Set[str],
|
||||
schema_perms: Set[str],
|
||||
) -> List["SqlaTable"]:
|
||||
# TODO(hughhhh): add unit test
|
||||
return (
|
||||
session.query(cls)
|
||||
.filter_by(database_id=database.id)
|
||||
.filter(
|
||||
or_(
|
||||
SqlaTable.perm.in_(permissions),
|
||||
SqlaTable.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_eager_sqlatable_datasource(
|
||||
cls, session: Session, datasource_id: int
|
||||
) -> "SqlaTable":
|
||||
"""Returns SqlaTable with columns and metrics."""
|
||||
return (
|
||||
session.query(cls)
|
||||
.options(
|
||||
sa.orm.subqueryload(cls.columns),
|
||||
sa.orm.subqueryload(cls.metrics),
|
||||
)
|
||||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
|
||||
qry = session.query(cls)
|
||||
qry = cls.default_query(qry)
|
||||
return qry.all()
|
||||
|
||||
@staticmethod
|
||||
def default_query(qry: Query) -> Query:
|
||||
return qry.filter_by(is_sqllab_view=False)
|
||||
|
|
|
@ -1,147 +0,0 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Set, Type, Union
|
||||
|
||||
from flask_babel import _
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy.orm import Session, subqueryload
|
||||
from sqlalchemy.orm.exc import NoResultFound
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.datasets.models import Dataset
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query, SavedQuery
|
||||
from superset.tables.models import Table
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
|
||||
|
||||
|
||||
class DatasourceDAO(BaseDAO):
|
||||
|
||||
sources: Dict[DatasourceType, Type[Datasource]] = {
|
||||
DatasourceType.TABLE: SqlaTable,
|
||||
DatasourceType.QUERY: Query,
|
||||
DatasourceType.SAVEDQUERY: SavedQuery,
|
||||
DatasourceType.DATASET: Dataset,
|
||||
DatasourceType.SLTABLE: Table,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_datasource(
|
||||
cls, session: Session, datasource_type: DatasourceType, datasource_id: int
|
||||
) -> Datasource:
|
||||
if datasource_type not in cls.sources:
|
||||
raise DatasourceTypeNotSupportedError()
|
||||
|
||||
datasource = (
|
||||
session.query(cls.sources[datasource_type])
|
||||
.filter_by(id=datasource_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not datasource:
|
||||
raise DatasourceNotFound()
|
||||
|
||||
return datasource
|
||||
|
||||
@classmethod
|
||||
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
|
||||
source_class = DatasourceDAO.sources[DatasourceType.TABLE]
|
||||
qry = session.query(source_class)
|
||||
qry = source_class.default_query(qry)
|
||||
return qry.all()
|
||||
|
||||
@classmethod
|
||||
def get_datasource_by_name( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
session: Session,
|
||||
datasource_type: DatasourceType,
|
||||
datasource_name: str,
|
||||
database_name: str,
|
||||
schema: str,
|
||||
) -> Optional[Datasource]:
|
||||
datasource_class = DatasourceDAO.sources[datasource_type]
|
||||
if isinstance(datasource_class, SqlaTable):
|
||||
return datasource_class.get_datasource_by_name(
|
||||
session, datasource_name, schema, database_name
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_permissions( # pylint: disable=invalid-name
|
||||
cls,
|
||||
session: Session,
|
||||
database: Database,
|
||||
permissions: Set[str],
|
||||
schema_perms: Set[str],
|
||||
) -> List[Datasource]:
|
||||
# TODO(hughhhh): add unit test
|
||||
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
|
||||
if not isinstance(datasource_class, SqlaTable):
|
||||
return []
|
||||
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.filter_by(database_id=database.id)
|
||||
.filter(
|
||||
or_(
|
||||
datasource_class.perm.in_(permissions),
|
||||
datasource_class.schema_perm.in_(schema_perms),
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_eager_datasource(
|
||||
cls, session: Session, datasource_type: str, datasource_id: int
|
||||
) -> Optional[Datasource]:
|
||||
"""Returns datasource with columns and metrics."""
|
||||
datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]]
|
||||
if not isinstance(datasource_class, SqlaTable):
|
||||
return None
|
||||
return (
|
||||
session.query(datasource_class)
|
||||
.options(
|
||||
subqueryload(datasource_class.columns),
|
||||
subqueryload(datasource_class.metrics),
|
||||
)
|
||||
.filter_by(id=datasource_id)
|
||||
.one()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def query_datasources_by_name(
|
||||
cls,
|
||||
session: Session,
|
||||
database: Database,
|
||||
datasource_name: str,
|
||||
schema: Optional[str] = None,
|
||||
) -> List[Datasource]:
|
||||
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
|
||||
if not isinstance(datasource_class, SqlaTable):
|
||||
return []
|
||||
|
||||
return datasource_class.query_datasources_by_name(
|
||||
session, database, datasource_name, schema=schema
|
||||
)
|
|
@ -60,6 +60,7 @@ class DatasourceTypeNotSupportedError(DAOException):
|
|||
DAO datasource query source type is not supported
|
||||
"""
|
||||
|
||||
status = 422
|
||||
message = "DAO datasource query source type is not supported"
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ from typing import Any, Dict, Optional
|
|||
from flask_babel import lazy_gettext as _
|
||||
from sqlalchemy.orm import make_transient, Session
|
||||
|
||||
from superset import ConnectorRegistry, db
|
||||
from superset import db
|
||||
from superset.commands.base import BaseCommand
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.datasets.commands.importers.v0 import import_dataset
|
||||
|
@ -63,12 +63,11 @@ def import_chart(
|
|||
slc_to_import = slc_to_import.copy()
|
||||
slc_to_import.reset_ownership()
|
||||
params = slc_to_import.params_dict
|
||||
datasource = ConnectorRegistry.get_datasource_by_name(
|
||||
session,
|
||||
slc_to_import.datasource_type,
|
||||
params["datasource_name"],
|
||||
params["schema"],
|
||||
params["database_name"],
|
||||
datasource = SqlaTable.get_datasource_by_name(
|
||||
session=session,
|
||||
datasource_name=params["datasource_name"],
|
||||
database_name=params["database_name"],
|
||||
schema=params["schema"],
|
||||
)
|
||||
slc_to_import.datasource_id = datasource.id # type: ignore
|
||||
if slc_to_override:
|
||||
|
|
|
@ -0,0 +1,16 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
|
@ -0,0 +1,62 @@
|
|||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Dict, Type, Union
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.dao.base import BaseDAO
|
||||
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
|
||||
from superset.datasets.models import Dataset
|
||||
from superset.models.sql_lab import Query, SavedQuery
|
||||
from superset.tables.models import Table
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
|
||||
|
||||
|
||||
class DatasourceDAO(BaseDAO):
|
||||
|
||||
sources: Dict[Union[DatasourceType, str], Type[Datasource]] = {
|
||||
DatasourceType.TABLE: SqlaTable,
|
||||
DatasourceType.QUERY: Query,
|
||||
DatasourceType.SAVEDQUERY: SavedQuery,
|
||||
DatasourceType.DATASET: Dataset,
|
||||
DatasourceType.SLTABLE: Table,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_datasource(
|
||||
cls,
|
||||
session: Session,
|
||||
datasource_type: Union[DatasourceType, str],
|
||||
datasource_id: int,
|
||||
) -> Datasource:
|
||||
if datasource_type not in cls.sources:
|
||||
raise DatasourceTypeNotSupportedError()
|
||||
|
||||
datasource = (
|
||||
session.query(cls.sources[datasource_type])
|
||||
.filter_by(id=datasource_id)
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
if not datasource:
|
||||
raise DatasourceNotFound()
|
||||
|
||||
return datasource
|
|
@ -23,7 +23,7 @@ from typing import Any, Dict, List, Set
|
|||
from urllib import request
|
||||
|
||||
from superset import app, db
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.slice import Slice
|
||||
|
||||
BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
|
||||
|
@ -32,7 +32,7 @@ misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboa
|
|||
|
||||
|
||||
def get_table_connector_registry() -> Any:
|
||||
return ConnectorRegistry.sources["table"]
|
||||
return SqlaTable
|
||||
|
||||
|
||||
def get_examples_folder() -> str:
|
||||
|
|
|
@ -27,6 +27,7 @@ from superset.extensions import cache_manager
|
|||
from superset.key_value.utils import get_owner, random_key
|
||||
from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
|
||||
from superset.temporary_cache.utils import cache_key
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.utils.schema import validate_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -56,7 +57,7 @@ class CreateFormDataCommand(BaseCommand):
|
|||
state: TemporaryExploreState = {
|
||||
"owner": get_owner(actor),
|
||||
"datasource_id": datasource_id,
|
||||
"datasource_type": datasource_type,
|
||||
"datasource_type": DatasourceType(datasource_type),
|
||||
"chart_id": chart_id,
|
||||
"form_data": form_data,
|
||||
}
|
||||
|
|
|
@ -18,10 +18,12 @@ from typing import Optional
|
|||
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
|
||||
class TemporaryExploreState(TypedDict):
|
||||
owner: Optional[int]
|
||||
datasource_id: int
|
||||
datasource_type: str
|
||||
datasource_type: DatasourceType
|
||||
chart_id: Optional[int]
|
||||
form_data: str
|
||||
|
|
|
@ -32,6 +32,7 @@ from superset.temporary_cache.commands.exceptions import (
|
|||
TemporaryCacheUpdateFailedError,
|
||||
)
|
||||
from superset.temporary_cache.utils import cache_key
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.utils.schema import validate_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -75,7 +76,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
|
|||
new_state: TemporaryExploreState = {
|
||||
"owner": owner,
|
||||
"datasource_id": datasource_id,
|
||||
"datasource_type": datasource_type,
|
||||
"datasource_type": DatasourceType(datasource_type),
|
||||
"chart_id": chart_id,
|
||||
"form_data": form_data,
|
||||
}
|
||||
|
|
|
@ -28,7 +28,6 @@ from flask_babel import gettext as __, lazy_gettext as _
|
|||
from flask_compress import Compress
|
||||
from werkzeug.middleware.proxy_fix import ProxyFix
|
||||
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.constants import CHANGE_ME_SECRET_KEY
|
||||
from superset.extensions import (
|
||||
_event_logger,
|
||||
|
@ -473,7 +472,11 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
|||
# Registering sources
|
||||
module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
|
||||
module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
|
||||
ConnectorRegistry.register_sources(module_datasource_map)
|
||||
|
||||
# todo(hughhhh): fully remove the datasource config register
|
||||
for module_name, class_names in module_datasource_map.items():
|
||||
class_names = [str(s) for s in class_names]
|
||||
__import__(module_name, fromlist=class_names)
|
||||
|
||||
def configure_cache(self) -> None:
|
||||
cache_manager.init_app(self.superset_app)
|
||||
|
|
|
@ -46,10 +46,11 @@ from sqlalchemy.orm.session import object_session
|
|||
from sqlalchemy.sql import join, select
|
||||
from sqlalchemy.sql.elements import BinaryExpression
|
||||
|
||||
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
|
||||
from superset import app, db, is_feature_enabled, security_manager
|
||||
from superset.common.request_contexed_based import is_user_admin
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.sqla.models import SqlMetric, TableColumn
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.extensions import cache_manager
|
||||
from superset.models.filter_set import FilterSet
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||
|
@ -407,16 +408,18 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
|
|||
id_ = target.get("datasetId")
|
||||
if id_ is None:
|
||||
continue
|
||||
datasource = ConnectorRegistry.get_datasource_by_id(session, id_)
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session, utils.DatasourceType.TABLE, id_
|
||||
)
|
||||
datasource_ids.add((datasource.id, datasource.type))
|
||||
|
||||
copied_dashboard.alter_params(remote_id=dashboard_id)
|
||||
copied_dashboards.append(copied_dashboard)
|
||||
|
||||
eager_datasources = []
|
||||
for datasource_id, datasource_type in datasource_ids:
|
||||
eager_datasource = ConnectorRegistry.get_eager_datasource(
|
||||
db.session, datasource_type, datasource_id
|
||||
for datasource_id, _ in datasource_ids:
|
||||
eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
|
||||
db.session, datasource_id
|
||||
)
|
||||
copied_datasource = eager_datasource.copy()
|
||||
copied_datasource.alter_params(
|
||||
|
|
|
@ -21,7 +21,6 @@ from flask_appbuilder import Model
|
|||
from sqlalchemy import Column, Integer, String
|
||||
|
||||
from superset import app, db, security_manager
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.models.helpers import AuditMixinNullable
|
||||
from superset.utils.memoized import memoized
|
||||
|
||||
|
@ -44,7 +43,10 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
|
|||
|
||||
@property
|
||||
def cls_model(self) -> Type["BaseDatasource"]:
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
return DatasourceDAO.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def username(self) -> Markup:
|
||||
|
|
|
@ -39,7 +39,7 @@ from sqlalchemy.engine.base import Connection
|
|||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
|
||||
from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
|
||||
from superset import db, is_feature_enabled, security_manager
|
||||
from superset.legacy import update_time_range
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
|
||||
from superset.models.tags import ChartUpdater
|
||||
|
@ -126,7 +126,10 @@ class Slice( # pylint: disable=too-many-public-methods
|
|||
|
||||
@property
|
||||
def cls_model(self) -> Type["BaseDatasource"]:
|
||||
return ConnectorRegistry.sources[self.datasource_type]
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
return DatasourceDAO.sources[self.datasource_type]
|
||||
|
||||
@property
|
||||
def datasource(self) -> Optional["BaseDatasource"]:
|
||||
|
|
|
@ -61,7 +61,6 @@ from sqlalchemy.orm.mapper import Mapper
|
|||
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||
|
||||
from superset import sql_parse
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.constants import RouteMethod
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
|
@ -471,23 +470,25 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
user_perms = self.user_view_menu_names("datasource_access")
|
||||
schema_perms = self.user_view_menu_names("schema_access")
|
||||
user_datasources = set()
|
||||
for datasource_class in ConnectorRegistry.sources.values():
|
||||
user_datasources.update(
|
||||
self.get_session.query(datasource_class)
|
||||
.filter(
|
||||
or_(
|
||||
datasource_class.perm.in_(user_perms),
|
||||
datasource_class.schema_perm.in_(schema_perms),
|
||||
)
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
user_datasources.update(
|
||||
self.get_session.query(SqlaTable)
|
||||
.filter(
|
||||
or_(
|
||||
SqlaTable.perm.in_(user_perms),
|
||||
SqlaTable.schema_perm.in_(schema_perms),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# group all datasources by database
|
||||
all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
|
||||
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
|
||||
set
|
||||
)
|
||||
session = self.get_session
|
||||
all_datasources = SqlaTable.get_all_datasources(session)
|
||||
datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set)
|
||||
for datasource in all_datasources:
|
||||
datasources_by_database[datasource.database].add(datasource)
|
||||
|
||||
|
@ -599,6 +600,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
:param schema: The fallback SQL schema if not present in the table name
|
||||
:returns: The list of accessible SQL tables w/ schema
|
||||
"""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
if self.can_access_database(database):
|
||||
return datasource_names
|
||||
|
@ -610,7 +613,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
|
||||
user_perms = self.user_view_menu_names("datasource_access")
|
||||
schema_perms = self.user_view_menu_names("schema_access")
|
||||
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
|
||||
user_datasources = SqlaTable.query_datasources_by_permissions(
|
||||
self.get_session, database, user_perms, schema_perms
|
||||
)
|
||||
if schema:
|
||||
|
@ -660,6 +663,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models import core as models
|
||||
|
||||
logger.info("Fetching a set of all perms to lookup which ones are missing")
|
||||
|
@ -668,13 +672,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
if pv.permission and pv.view_menu:
|
||||
all_pvs.add((pv.permission.name, pv.view_menu.name))
|
||||
|
||||
def merge_pv(view_menu: str, perm: str) -> None:
|
||||
def merge_pv(view_menu: str, perm: Optional[str]) -> None:
|
||||
"""Create permission view menu only if it doesn't exist"""
|
||||
if view_menu and perm and (view_menu, perm) not in all_pvs:
|
||||
self.add_permission_view_menu(view_menu, perm)
|
||||
|
||||
logger.info("Creating missing datasource permissions.")
|
||||
datasources = ConnectorRegistry.get_all_datasources(self.get_session)
|
||||
datasources = SqlaTable.get_all_datasources(self.get_session)
|
||||
for datasource in datasources:
|
||||
merge_pv("datasource_access", datasource.get_perm())
|
||||
merge_pv("schema_access", datasource.get_schema_perm())
|
||||
|
|
|
@ -61,7 +61,6 @@ from superset.charts.dao import ChartDAO
|
|||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import (
|
||||
AnnotationDatasource,
|
||||
SqlaTable,
|
||||
|
@ -77,6 +76,7 @@ from superset.databases.dao import DatabaseDAO
|
|||
from superset.databases.filters import DatabaseFilter
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
CacheLoadError,
|
||||
|
@ -129,7 +129,11 @@ from superset.tasks.async_queries import load_explore_json_into_cache
|
|||
from superset.utils import core as utils, csv
|
||||
from superset.utils.async_query_manager import AsyncQueryTokenException
|
||||
from superset.utils.cache import etag_cache
|
||||
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
|
||||
from superset.utils.core import (
|
||||
apply_max_row_limit,
|
||||
DatasourceType,
|
||||
ReservedUrlParameters,
|
||||
)
|
||||
from superset.utils.dates import now_as_float
|
||||
from superset.utils.decorators import check_dashboard_access
|
||||
from superset.views.base import (
|
||||
|
@ -250,7 +254,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
)
|
||||
db_ds_names.add(fullname)
|
||||
|
||||
existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
|
||||
existing_datasources = SqlaTable.get_all_datasources(db.session)
|
||||
datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
|
||||
role = security_manager.find_role(role_name)
|
||||
# remove all permissions
|
||||
|
@ -282,7 +286,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
datasource_id = request.args.get("datasource_id")
|
||||
datasource_type = request.args.get("datasource_type")
|
||||
if datasource_id and datasource_type:
|
||||
ds_class = ConnectorRegistry.sources.get(datasource_type)
|
||||
ds_class = DatasourceDAO.sources.get(datasource_type)
|
||||
datasource = (
|
||||
db.session.query(ds_class).filter_by(id=int(datasource_id)).one()
|
||||
)
|
||||
|
@ -319,10 +323,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use
|
||||
def clean_fulfilled_requests(session: Session) -> None:
|
||||
for dar in session.query(DAR).all():
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
dar.datasource_type,
|
||||
dar.datasource_id,
|
||||
session,
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session, DatasourceType(dar.datasource_type), dar.datasource_id
|
||||
)
|
||||
if not datasource or security_manager.can_access_datasource(datasource):
|
||||
# Dataset does not exist anymore
|
||||
|
@ -336,8 +338,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
role_to_extend = request.args.get("role_to_extend")
|
||||
|
||||
session = db.session
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, session
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session, DatasourceType(datasource_type), int(datasource_id)
|
||||
)
|
||||
|
||||
if not datasource:
|
||||
|
@ -639,7 +641,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
datasource_id, datasource_type = get_datasource_info(
|
||||
datasource_id, datasource_type, form_data
|
||||
)
|
||||
|
||||
force = request.args.get("force") == "true"
|
||||
|
||||
# TODO: support CSV, SQL query and other non-JSON types
|
||||
|
@ -809,8 +810,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
datasource: Optional[BaseDatasource] = None
|
||||
if datasource_id is not None:
|
||||
try:
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
cast(str, datasource_type), datasource_id, db.session
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session,
|
||||
DatasourceType(cast(str, datasource_type)),
|
||||
datasource_id,
|
||||
)
|
||||
except DatasetNotFoundError:
|
||||
pass
|
||||
|
@ -948,10 +951,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
:raises SupersetSecurityException: If the user cannot access the resource
|
||||
"""
|
||||
# TODO: Cache endpoint by user, datasource and column
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type,
|
||||
datasource_id,
|
||||
db.session,
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource_type), datasource_id
|
||||
)
|
||||
if not datasource:
|
||||
return json_error_response(DATASOURCE_MISSING_ERR)
|
||||
|
@ -1920,8 +1921,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
|
||||
if config["ENABLE_ACCESS_REQUEST"]:
|
||||
for datasource in dashboard.datasources:
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type=datasource.type,
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
datasource_type=DatasourceType(datasource.type),
|
||||
datasource_id=datasource.id,
|
||||
session=db.session(),
|
||||
)
|
||||
|
@ -2537,10 +2538,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
|
||||
datasource_id, datasource_type = request.args["datasourceKey"].split("__")
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type,
|
||||
datasource_id,
|
||||
db.session,
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource_type), int(datasource_id)
|
||||
)
|
||||
# Check if datasource exists
|
||||
if not datasource:
|
||||
|
|
|
@ -29,16 +29,18 @@ from sqlalchemy.orm.exc import NoResultFound
|
|||
|
||||
from superset import db, event_logger
|
||||
from superset.commands.utils import populate_owners
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.connectors.sqla.utils import get_physical_table_metadata
|
||||
from superset.datasets.commands.exceptions import (
|
||||
DatasetForbiddenError,
|
||||
DatasetNotFoundError,
|
||||
)
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.exceptions import SupersetException, SupersetSecurityException
|
||||
from superset.extensions import security_manager
|
||||
from superset.models.core import Database
|
||||
from superset.superset_typing import FlaskResponse
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.views.base import (
|
||||
api,
|
||||
BaseSupersetView,
|
||||
|
@ -74,8 +76,8 @@ class Datasource(BaseSupersetView):
|
|||
datasource_id = datasource_dict.get("id")
|
||||
datasource_type = datasource_dict.get("type")
|
||||
database_id = datasource_dict["database"].get("id")
|
||||
orm_datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
orm_datasource = DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource_type), datasource_id
|
||||
)
|
||||
orm_datasource.database_id = database_id
|
||||
|
||||
|
@ -117,8 +119,8 @@ class Datasource(BaseSupersetView):
|
|||
@api
|
||||
@handle_api_exception
|
||||
def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session, DatasourceType(datasource_type), datasource_id
|
||||
)
|
||||
return self.json_response(sanitize_datasource_data(datasource.data))
|
||||
|
||||
|
@ -130,8 +132,10 @@ class Datasource(BaseSupersetView):
|
|||
self, datasource_type: str, datasource_id: int
|
||||
) -> FlaskResponse:
|
||||
"""Gets column info from the source system"""
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session,
|
||||
DatasourceType(datasource_type),
|
||||
datasource_id,
|
||||
)
|
||||
try:
|
||||
external_metadata = datasource.external_metadata()
|
||||
|
@ -153,9 +157,8 @@ class Datasource(BaseSupersetView):
|
|||
except ValidationError as err:
|
||||
return json_error_response(str(err), status=400)
|
||||
|
||||
datasource = ConnectorRegistry.get_datasource_by_name(
|
||||
datasource = SqlaTable.get_datasource_by_name(
|
||||
session=db.session,
|
||||
datasource_type=params["datasource_type"],
|
||||
database_name=params["database_name"],
|
||||
schema=params["schema_name"],
|
||||
datasource_name=params["table_name"],
|
||||
|
|
|
@ -32,7 +32,7 @@ from sqlalchemy.orm.exc import NoResultFound
|
|||
import superset.models.core as models
|
||||
from superset import app, dataframe, db, result_set, viz
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import (
|
||||
CacheLoadError,
|
||||
|
@ -47,6 +47,7 @@ from superset.models.dashboard import Dashboard
|
|||
from superset.models.slice import Slice
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.superset_typing import FormData
|
||||
from superset.utils.core import DatasourceType
|
||||
from superset.utils.decorators import stats_timing
|
||||
from superset.viz import BaseViz
|
||||
|
||||
|
@ -127,8 +128,10 @@ def get_viz(
|
|||
force_cached: bool = False,
|
||||
) -> BaseViz:
|
||||
viz_type = form_data.get("viz_type", "table")
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
db.session,
|
||||
DatasourceType(datasource_type),
|
||||
datasource_id,
|
||||
)
|
||||
viz_obj = viz.viz_types[viz_type](
|
||||
datasource, form_data=form_data, force=force, force_cached=force_cached
|
||||
|
|
|
@ -39,7 +39,6 @@ from tests.integration_tests.fixtures.energy_dashboard import (
|
|||
)
|
||||
from tests.integration_tests.test_app import app # isort:skip
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models import core as models
|
||||
from superset.models.datasource_access_request import DatasourceAccessRequest
|
||||
|
@ -90,12 +89,12 @@ SCHEMA_ACCESS_ROLE = "schema_access_role"
|
|||
|
||||
|
||||
def create_access_request(session, ds_type, ds_name, role_name, username):
|
||||
ds_class = ConnectorRegistry.sources[ds_type]
|
||||
# TODO: generalize datasource names
|
||||
if ds_type == "table":
|
||||
ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
|
||||
ds = session.query(SqlaTable).filter(SqlaTable.table_name == ds_name).first()
|
||||
else:
|
||||
ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
|
||||
# This function will only work for ds_type == "table"
|
||||
raise NotImplementedError()
|
||||
ds_perm_view = security_manager.find_permission_view_menu(
|
||||
"datasource_access", ds.perm
|
||||
)
|
||||
|
@ -449,49 +448,6 @@ class TestRequestAccess(SupersetTestCase):
|
|||
TEST_ROLE = security_manager.find_role(TEST_ROLE_NAME)
|
||||
self.assertIn(perm_view, TEST_ROLE.permissions)
|
||||
|
||||
# Case 3. Grant new role to the user to access the druid datasource.
|
||||
|
||||
security_manager.add_role("druid_role")
|
||||
access_request3 = create_access_request(
|
||||
session, "druid", "druid_ds_1", "druid_role", "gamma"
|
||||
)
|
||||
self.get_resp(
|
||||
GRANT_ROLE_REQUEST.format(
|
||||
"druid", access_request3.datasource_id, "gamma", "druid_role"
|
||||
)
|
||||
)
|
||||
|
||||
# user was granted table_role
|
||||
user_roles = [r.name for r in security_manager.find_user("gamma").roles]
|
||||
self.assertIn("druid_role", user_roles)
|
||||
|
||||
# Case 4. Extend the role to have access to the druid datasource
|
||||
|
||||
access_request4 = create_access_request(
|
||||
session, "druid", "druid_ds_2", "druid_role", "gamma"
|
||||
)
|
||||
druid_ds_2_perm = access_request4.datasource.perm
|
||||
|
||||
self.client.get(
|
||||
EXTEND_ROLE_REQUEST.format(
|
||||
"druid", access_request4.datasource_id, "gamma", "druid_role"
|
||||
)
|
||||
)
|
||||
# druid_role was extended to grant access to the druid_access_ds_2
|
||||
druid_role = security_manager.find_role("druid_role")
|
||||
perm_view = security_manager.find_permission_view_menu(
|
||||
"datasource_access", druid_ds_2_perm
|
||||
)
|
||||
self.assertIn(perm_view, druid_role.permissions)
|
||||
|
||||
# cleanup
|
||||
gamma_user = security_manager.find_user(username="gamma")
|
||||
gamma_user.roles.remove(security_manager.find_role("druid_role"))
|
||||
gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
|
||||
session.delete(security_manager.find_role("druid_role"))
|
||||
session.delete(security_manager.find_role(TEST_ROLE_NAME))
|
||||
session.commit()
|
||||
|
||||
def test_request_access(self):
|
||||
if app.config["ENABLE_ACCESS_REQUEST"]:
|
||||
session = db.session
|
||||
|
|
|
@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional
|
|||
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset import ConnectorRegistry, db
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
@ -35,9 +35,8 @@ def get_table(
|
|||
schema: Optional[str] = None,
|
||||
):
|
||||
schema = schema or get_example_default_schema()
|
||||
table_source = ConnectorRegistry.sources["table"]
|
||||
return (
|
||||
db.session.query(table_source)
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(database_id=database.id, schema=schema, table_name=table_name)
|
||||
.one_or_none()
|
||||
)
|
||||
|
@ -54,8 +53,7 @@ def create_table_metadata(
|
|||
|
||||
table = get_table(table_name, database, schema)
|
||||
if not table:
|
||||
table_source = ConnectorRegistry.sources["table"]
|
||||
table = table_source(schema=schema, table_name=table_name)
|
||||
table = SqlaTable(schema=schema, table_name=table_name)
|
||||
if fetch_values_predicate:
|
||||
table.fetch_values_predicate = fetch_values_predicate
|
||||
table.database = database
|
||||
|
|
|
@ -22,12 +22,13 @@ from unittest import mock
|
|||
import prison
|
||||
import pytest
|
||||
|
||||
from superset import app, ConnectorRegistry, db
|
||||
from superset import app, db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
|
||||
from superset.datasets.commands.exceptions import DatasetNotFoundError
|
||||
from superset.exceptions import SupersetGenericDBErrorException
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_example_default_schema
|
||||
from superset.utils.core import DatasourceType, get_example_default_schema
|
||||
from superset.utils.database import get_example_database
|
||||
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
|
@ -256,9 +257,10 @@ class TestDatasource(SupersetTestCase):
|
|||
|
||||
pytest.raises(
|
||||
SupersetGenericDBErrorException,
|
||||
lambda: ConnectorRegistry.get_datasource(
|
||||
"table", tbl.id, db.session
|
||||
).external_metadata(),
|
||||
lambda: db.session.query(SqlaTable)
|
||||
.filter_by(id=tbl.id)
|
||||
.one_or_none()
|
||||
.external_metadata(),
|
||||
)
|
||||
|
||||
resp = self.client.get(url)
|
||||
|
@ -385,21 +387,30 @@ class TestDatasource(SupersetTestCase):
|
|||
app.config["DATASET_HEALTH_CHECK"] = my_check
|
||||
self.login(username="admin")
|
||||
tbl = self.get_table(name="birth_names")
|
||||
datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
|
||||
datasource = db.session.query(SqlaTable).filter_by(id=tbl.id).one_or_none()
|
||||
assert datasource.health_check_message == "Warning message!"
|
||||
app.config["DATASET_HEALTH_CHECK"] = None
|
||||
|
||||
def test_get_datasource_failed(self):
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
pytest.raises(
|
||||
DatasetNotFoundError,
|
||||
lambda: ConnectorRegistry.get_datasource("table", 9999999, db.session),
|
||||
DatasourceNotFound,
|
||||
lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999),
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "Datasource does not exist")
|
||||
|
||||
def test_get_datasource_invalid_datasource_failed(self):
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
pytest.raises(
|
||||
DatasourceTypeNotSupportedError,
|
||||
lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999),
|
||||
)
|
||||
|
||||
self.login(username="admin")
|
||||
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
|
||||
self.assertEqual(resp.get("error"), "Dataset does not exist")
|
||||
|
||||
resp = self.get_json_resp(
|
||||
"/datasource/get/invalid-datasource-type/500000/", raise_on_error=False
|
||||
)
|
||||
self.assertEqual(resp.get("error"), "Dataset does not exist")
|
||||
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
|
||||
|
|
|
@ -26,6 +26,7 @@ from superset.datasets.commands.exceptions import DatasetAccessDeniedError
|
|||
from superset.explore.form_data.commands.state import TemporaryExploreState
|
||||
from superset.extensions import cache_manager
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils.core import DatasourceType
|
||||
from tests.integration_tests.base_tests import login
|
||||
from tests.integration_tests.fixtures.client import client
|
||||
from tests.integration_tests.fixtures.world_bank_dashboard import (
|
||||
|
@ -392,7 +393,7 @@ def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id
|
|||
entry: TemporaryExploreState = {
|
||||
"owner": another_owner,
|
||||
"datasource_id": datasource.id,
|
||||
"datasource_type": datasource.type,
|
||||
"datasource_type": DatasourceType(datasource.type),
|
||||
"chart_id": chart_id,
|
||||
"form_data": INITIAL_FORM_DATA,
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ from typing import Callable, List, Optional
|
|||
|
||||
import pytest
|
||||
|
||||
from superset import ConnectorRegistry, db
|
||||
from superset import db
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.core import Database
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
@ -95,14 +95,11 @@ def _create_table(
|
|||
|
||||
def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
|
||||
schema = get_example_default_schema()
|
||||
|
||||
table_id = (
|
||||
datasource = (
|
||||
db.session.query(SqlaTable)
|
||||
.filter_by(table_name="birth_names", schema=schema)
|
||||
.one()
|
||||
.id
|
||||
)
|
||||
datasource = ConnectorRegistry.get_datasource("table", table_id, db.session)
|
||||
columns = [column for column in datasource.columns]
|
||||
metrics = [metric for metric in datasource.metrics]
|
||||
|
||||
|
|
|
@ -82,7 +82,6 @@ def _create_energy_table():
|
|||
table.metrics.append(
|
||||
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
|
||||
)
|
||||
|
||||
db.session.merge(table)
|
||||
db.session.commit()
|
||||
table.fetch_metadata()
|
||||
|
|
|
@ -16,7 +16,8 @@
|
|||
# under the License.
|
||||
from typing import List, Optional
|
||||
|
||||
from superset import ConnectorRegistry, db, security_manager
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.models.slice import Slice
|
||||
|
||||
|
||||
|
@ -43,8 +44,8 @@ class InsertChartMixin:
|
|||
for owner in owners:
|
||||
user = db.session.query(security_manager.user_model).get(owner)
|
||||
obj_owners.append(user)
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type, datasource_id, db.session
|
||||
datasource = (
|
||||
db.session.query(SqlaTable).filter_by(id=datasource_id).one_or_none()
|
||||
)
|
||||
slice = Slice(
|
||||
cache_timeout=cache_timeout,
|
||||
|
|
|
@ -26,10 +26,15 @@ from superset.charts.schemas import ChartDataQueryContextSchema
|
|||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.common.query_object import QueryObject
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.extensions import cache_manager
|
||||
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
|
||||
from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
backend,
|
||||
DatasourceType,
|
||||
QueryStatus,
|
||||
)
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
|
@ -132,10 +137,10 @@ class TestQueryContext(SupersetTestCase):
|
|||
cache_key_original = query_context.query_cache_key(query_object)
|
||||
|
||||
# make temporary change and revert it to refresh the changed_on property
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type=payload["datasource"]["type"],
|
||||
datasource_id=payload["datasource"]["id"],
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session=db.session,
|
||||
datasource_type=DatasourceType(payload["datasource"]["type"]),
|
||||
datasource_id=payload["datasource"]["id"],
|
||||
)
|
||||
description_original = datasource.description
|
||||
datasource.description = "temporary description"
|
||||
|
@ -156,10 +161,10 @@ class TestQueryContext(SupersetTestCase):
|
|||
payload = get_query_context("birth_names")
|
||||
|
||||
# make temporary change and revert it to refresh the changed_on property
|
||||
datasource = ConnectorRegistry.get_datasource(
|
||||
datasource_type=payload["datasource"]["type"],
|
||||
datasource_id=payload["datasource"]["id"],
|
||||
datasource = DatasourceDAO.get_datasource(
|
||||
session=db.session,
|
||||
datasource_type=DatasourceType(payload["datasource"]["type"]),
|
||||
datasource_id=payload["datasource"]["id"],
|
||||
)
|
||||
|
||||
datasource.metrics.append(SqlMetric(metric_name="foo", expression="select 1;"))
|
||||
|
|
|
@ -28,10 +28,11 @@ import prison
|
|||
import pytest
|
||||
|
||||
from flask import current_app
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
||||
from superset import app, appbuilder, db, security_manager, viz, ConnectorRegistry
|
||||
from superset import app, appbuilder, db, security_manager, viz
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
@ -990,7 +991,7 @@ class TestDatasources(SupersetTestCase):
|
|||
mock_get_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
SqlaTable, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
|
@ -1018,7 +1019,7 @@ class TestDatasources(SupersetTestCase):
|
|||
mock_get_session.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
SqlaTable, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
|
@ -1046,7 +1047,7 @@ class TestDatasources(SupersetTestCase):
|
|||
]
|
||||
|
||||
with mock.patch.object(
|
||||
ConnectorRegistry, "get_all_datasources"
|
||||
SqlaTable, "get_all_datasources"
|
||||
) as mock_get_all_datasources:
|
||||
mock_get_all_datasources.return_value = [
|
||||
Datasource("database1", "schema1", "table1"),
|
||||
|
|
|
@ -103,7 +103,7 @@ def test_get_datasource_sqlatable(
|
|||
app_context: None, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
result = DatasourceDAO.get_datasource(
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
|
@ -117,7 +117,7 @@ def test_get_datasource_sqlatable(
|
|||
|
||||
|
||||
def test_get_datasource_query(app_context: None, session_with_data: Session) -> None:
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.models.sql_lab import Query
|
||||
|
||||
result = DatasourceDAO.get_datasource(
|
||||
|
@ -131,7 +131,7 @@ def test_get_datasource_query(app_context: None, session_with_data: Session) ->
|
|||
def test_get_datasource_saved_query(
|
||||
app_context: None, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.models.sql_lab import SavedQuery
|
||||
|
||||
result = DatasourceDAO.get_datasource(
|
||||
|
@ -145,7 +145,7 @@ def test_get_datasource_saved_query(
|
|||
|
||||
|
||||
def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None:
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.tables.models import Table
|
||||
|
||||
# todo(hugh): This will break once we remove the dual write
|
||||
|
@ -163,8 +163,8 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session)
|
|||
def test_get_datasource_sl_dataset(
|
||||
app_context: None, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.datasets.models import Dataset
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
|
||||
# todo(hugh): This will break once we remove the dual write
|
||||
# update the datsource_id=1 and this will pass again
|
||||
|
@ -178,10 +178,35 @@ def test_get_datasource_sl_dataset(
|
|||
assert isinstance(result, Dataset)
|
||||
|
||||
|
||||
def test_get_all_sqlatables_datasources(
|
||||
def test_get_datasource_w_str_param(
|
||||
app_context: None, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.dao.datasource.dao import DatasourceDAO
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.datasets.models import Dataset
|
||||
from superset.datasource.dao import DatasourceDAO
|
||||
from superset.tables.models import Table
|
||||
|
||||
result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data)
|
||||
assert isinstance(
|
||||
DatasourceDAO.get_datasource(
|
||||
datasource_type="table",
|
||||
datasource_id=1,
|
||||
session=session_with_data,
|
||||
),
|
||||
SqlaTable,
|
||||
)
|
||||
|
||||
assert isinstance(
|
||||
DatasourceDAO.get_datasource(
|
||||
datasource_type="sl_table",
|
||||
datasource_id=1,
|
||||
session=session_with_data,
|
||||
),
|
||||
Table,
|
||||
)
|
||||
|
||||
|
||||
def test_get_all_datasources(app_context: None, session_with_data: Session) -> None:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
result = SqlaTable.get_all_datasources(session=session_with_data)
|
||||
assert len(result) == 1
|
Loading…
Reference in New Issue