chore: switching out ConnectorRegistry references for DatasourceDAO (#20380)

* rename and move dao file

* Update dao.py

* add cachekey

* Update __init__.py

* change reference in query context test

* add utils ref

* more ref changes

* add helpers

* add todo in dashboard.py

* add cachekey

* circular import error in dar.py

* push rest of refs

* fix linting

* fix more linting

* update enum

* remove references for connector registry

* big refactor

* take value

* fix

* test to see if removing value works

* delete connectregistry

* address concerns

* address comments

* fix merge conflicts

* address concern II

* address concern II

* fix test

Co-authored-by: Phillip Kelley-Dotson <pkelleydotson@yahoo.com>
Hugh A. Miles II 2022-06-21 13:22:39 +02:00 committed by GitHub
parent c79b0d62d0
commit e3e37cb68f
34 changed files with 334 additions and 504 deletions
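In short, every call site moves from the string-keyed ConnectorRegistry to the new DatasourceDAO, with the argument order normalized to (session, type, id) and the type coerced through the DatasourceType enum. A minimal before/after sketch of the pattern repeated in the hunks below (session and id are placeholders):

    # Before: registry lookup keyed by a raw type string, session passed last
    from superset.connectors.connector_registry import ConnectorRegistry  # deleted by this commit
    datasource = ConnectorRegistry.get_datasource("table", 1, session)

    # After: DAO lookup, session passed first, type wrapped in the enum
    from superset.datasource.dao import DatasourceDAO
    from superset.utils.core import DatasourceType
    datasource = DatasourceDAO.get_datasource(session, DatasourceType.TABLE, 1)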

View File

@ -19,7 +19,6 @@ from flask import current_app, Flask
from werkzeug.local import LocalProxy
from superset.app import create_app
from superset.connectors.connector_registry import ConnectorRegistry
from superset.extensions import (
appbuilder,
cache_manager,

View File

@ -25,7 +25,7 @@ from marshmallow.exceptions import ValidationError
from sqlalchemy.exc import SQLAlchemyError
from superset.cachekeys.schemas import CacheInvalidationRequestSchema
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.extensions import cache_manager, db, event_logger
from superset.models.cache import CacheKey
from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics
@ -83,13 +83,13 @@ class CacheRestApi(BaseSupersetModelRestApi):
return self.response_400(message=str(error))
datasource_uids = set(datasources.get("datasource_uids", []))
for ds in datasources.get("datasources", []):
ds_obj = ConnectorRegistry.get_datasource_by_name(
ds_obj = SqlaTable.get_datasource_by_name(
session=db.session,
datasource_type=ds.get("datasource_type"),
datasource_name=ds.get("datasource_name"),
schema=ds.get("schema"),
database_name=ds.get("database_name"),
)
if ds_obj:
datasource_uids.add(ds_obj.uid)
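Name-based lookups in the cache-invalidation endpoint now go straight to SqlaTable, dropping the datasource_type argument the registry needed. A sketch of the new call (all names are placeholders):

    from superset import db
    from superset.connectors.sqla.models import SqlaTable

    ds_obj = SqlaTable.get_datasource_by_name(
        session=db.session,
        datasource_name="birth_names",  # placeholder
        schema=None,                    # placeholder
        database_name="examples",       # placeholder
    )
    if ds_obj:
        datasource_uids = {ds_obj.uid}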

View File

@ -25,9 +25,10 @@ from superset.commands.exceptions import (
OwnersNotFoundValidationError,
RolesNotFoundValidationError,
)
from superset.connectors.connector_registry import ConnectorRegistry
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.dao.exceptions import DatasourceNotFound
from superset.datasource.dao import DatasourceDAO
from superset.extensions import db, security_manager
from superset.utils.core import DatasourceType
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -79,8 +80,8 @@ def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource:
try:
return ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
return DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
except DatasetNotFoundError as ex:
except DatasourceNotFound as ex:
raise DatasourceNotFoundValidationError() from ex
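Two behavioral details in this helper: the caught exception becomes the DAO's DatasourceNotFound (re-raised as DatasourceNotFoundValidationError), and the caller's raw type string is coerced with DatasourceType(datasource_type), which raises ValueError for unknown strings. A minimal sketch of the new error handling (the id is a placeholder):

    from superset import db
    from superset.dao.exceptions import DatasourceNotFound
    from superset.datasource.dao import DatasourceDAO
    from superset.utils.core import DatasourceType

    try:
        ds = DatasourceDAO.get_datasource(db.session, DatasourceType("table"), 42)
    except DatasourceNotFound:
        ds = None  # mapped to a validation error at the command layer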

View File

@ -22,8 +22,8 @@ from superset import app, db
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object_factory import QueryObjectFactory
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import DatasourceDict
from superset.datasource.dao import DatasourceDAO
from superset.utils.core import DatasourceDict, DatasourceType
if TYPE_CHECKING:
from superset.connectors.base.models import BaseDatasource
@ -32,7 +32,7 @@ config = app.config
def create_query_object_factory() -> QueryObjectFactory:
return QueryObjectFactory(config, ConnectorRegistry(), db.session)
return QueryObjectFactory(config, DatasourceDAO(), db.session)
class QueryContextFactory: # pylint: disable=too-few-public-methods
@ -82,6 +82,6 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
# pylint: disable=no-self-use
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
return DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource["type"]), int(datasource["id"])
)

View File

@ -21,29 +21,29 @@ from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
from superset.utils.core import apply_max_row_limit, DatasourceDict
from superset.utils.core import apply_max_row_limit, DatasourceDict, DatasourceType
from superset.utils.date_parser import get_since_until
if TYPE_CHECKING:
from sqlalchemy.orm import sessionmaker
from superset import ConnectorRegistry
from superset.connectors.base.models import BaseDatasource
from superset.datasource.dao import DatasourceDAO
class QueryObjectFactory: # pylint: disable=too-few-public-methods
_config: Dict[str, Any]
_connector_registry: ConnectorRegistry
_datasource_dao: DatasourceDAO
_session_maker: sessionmaker
def __init__(
self,
app_configurations: Dict[str, Any],
connector_registry: ConnectorRegistry,
_datasource_dao: DatasourceDAO,
session_maker: sessionmaker,
):
self._config = app_configurations
self._connector_registry = connector_registry
self._datasource_dao = _datasource_dao
self._session_maker = session_maker
def create( # pylint: disable=too-many-arguments
@ -75,8 +75,10 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
)
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return self._connector_registry.get_datasource(
str(datasource["type"]), int(datasource["id"]), self._session_maker()
return self._datasource_dao.get_datasource(
datasource_type=DatasourceType(datasource["type"]),
datasource_id=int(datasource["id"]),
session=self._session_maker(),
)
def _process_extras( # pylint: disable=no-self-use
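The factory keeps the DAO as an injected constructor dependency even though its methods are classmethods, preserving the old constructor shape for tests. A sketch of wiring it up, assuming the config and session come from the app as in create_query_object_factory above:

    from superset import app, db
    from superset.common.query_object_factory import QueryObjectFactory
    from superset.datasource.dao import DatasourceDAO

    factory = QueryObjectFactory(app.config, DatasourceDAO(), db.session)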

View File

@ -1,164 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional, Set, Type, TYPE_CHECKING
from flask_babel import _
from sqlalchemy import or_
from sqlalchemy.orm import Session, subqueryload
from sqlalchemy.orm.exc import NoResultFound
from superset.datasets.commands.exceptions import DatasetNotFoundError
if TYPE_CHECKING:
from collections import OrderedDict
from superset.connectors.base.models import BaseDatasource
from superset.models.core import Database
class ConnectorRegistry:
"""Central Registry for all available datasource engines"""
sources: Dict[str, Type["BaseDatasource"]] = {}
@classmethod
def register_sources(cls, datasource_config: "OrderedDict[str, List[str]]") -> None:
for module_name, class_names in datasource_config.items():
class_names = [str(s) for s in class_names]
module_obj = __import__(module_name, fromlist=class_names)
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class
@classmethod
def get_datasource(
cls, datasource_type: str, datasource_id: int, session: Session
) -> "BaseDatasource":
"""Safely get a datasource instance, raises `DatasetNotFoundError` if
`datasource_type` is not registered or `datasource_id` does not
exist."""
if datasource_type not in cls.sources:
raise DatasetNotFoundError()
datasource = (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one_or_none()
)
if not datasource:
raise DatasetNotFoundError()
return datasource
@classmethod
def get_all_datasources(cls, session: Session) -> List["BaseDatasource"]:
datasources: List["BaseDatasource"] = []
for source_class in ConnectorRegistry.sources.values():
qry = session.query(source_class)
qry = source_class.default_query(qry)
datasources.extend(qry.all())
return datasources
@classmethod
def get_datasource_by_id(
cls, session: Session, datasource_id: int
) -> "BaseDatasource":
"""
Find a datasource instance based on the unique id.
:param session: Session to use
:param datasource_id: unique id of datasource
:return: Datasource corresponding to the id
:raises NoResultFound: if no datasource is found corresponding to the id
"""
for datasource_class in ConnectorRegistry.sources.values():
try:
return (
session.query(datasource_class)
.filter(datasource_class.id == datasource_id)
.one()
)
except NoResultFound:
# proceed to next datasource type
pass
raise NoResultFound(_("Datasource id not found: %(id)s", id=datasource_id))
@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: str,
datasource_name: str,
schema: str,
database_name: str,
) -> Optional["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[datasource_type]
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: "Database",
permissions: Set[str],
schema_perms: Set[str],
) -> List["BaseDatasource"]:
# TODO(bogdan): add unit test
datasource_class = ConnectorRegistry.sources[database.type]
return (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
datasource_class.perm.in_(permissions),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)
@classmethod
def get_eager_datasource(
cls, session: Session, datasource_type: str, datasource_id: int
) -> "BaseDatasource":
"""Returns datasource with columns and metrics."""
datasource_class = ConnectorRegistry.sources[datasource_type]
return (
session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics),
)
.filter_by(id=datasource_id)
.one()
)
@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: "Database",
datasource_name: str,
schema: Optional[str] = None,
) -> List["BaseDatasource"]:
datasource_class = ConnectorRegistry.sources[database.type]
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
)

View File

@ -31,6 +31,7 @@ from typing import (
List,
NamedTuple,
Optional,
Set,
Tuple,
Type,
Union,
@ -1990,6 +1991,48 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
query = query.filter_by(schema=schema)
return query.all()
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List["SqlaTable"]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
.filter_by(database_id=database.id)
.filter(
or_(
SqlaTable.perm.in_(permissions),
SqlaTable.schema_perm.in_(schema_perms),
)
)
.all()
)
@classmethod
def get_eager_sqlatable_datasource(
cls, session: Session, datasource_id: int
) -> "SqlaTable":
"""Returns SqlaTable with columns and metrics."""
return (
session.query(cls)
.options(
sa.orm.subqueryload(cls.columns),
sa.orm.subqueryload(cls.metrics),
)
.filter_by(id=datasource_id)
.one()
)
@classmethod
def get_all_datasources(cls, session: Session) -> List["SqlaTable"]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@staticmethod
def default_query(qry: Query) -> Query:
return qry.filter_by(is_sqllab_view=False)
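The registry's SQLA-specific queries land here as SqlaTable classmethods. A quick sketch exercising them (the database lookup and permission-view name are placeholders):

    from superset import db
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database

    database = db.session.query(Database).first()  # placeholder database

    # Single table with columns and metrics eagerly loaded
    table = SqlaTable.get_eager_sqlatable_datasource(db.session, datasource_id=1)

    # All tables, filtered through default_query() (no SQL Lab views)
    tables = SqlaTable.get_all_datasources(db.session)

    # Tables in `database` matching datasource- or schema-level perms
    visible = SqlaTable.query_datasources_by_permissions(
        db.session,
        database,
        permissions={"[examples].[birth_names](id:1)"},  # placeholder perm name
        schema_perms=set(),
    )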

View File

@ -1,147 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Type, Union
from flask_babel import _
from sqlalchemy import or_
from sqlalchemy.orm import Session, subqueryload
from sqlalchemy.orm.exc import NoResultFound
from superset.connectors.sqla.models import SqlaTable
from superset.dao.base import BaseDAO
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasets.models import Dataset
from superset.models.core import Database
from superset.models.sql_lab import Query, SavedQuery
from superset.tables.models import Table
from superset.utils.core import DatasourceType
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
sources: Dict[DatasourceType, Type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
DatasourceType.DATASET: Dataset,
DatasourceType.SLTABLE: Table,
}
@classmethod
def get_datasource(
cls, session: Session, datasource_type: DatasourceType, datasource_id: int
) -> Datasource:
if datasource_type not in cls.sources:
raise DatasourceTypeNotSupportedError()
datasource = (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one_or_none()
)
if not datasource:
raise DatasourceNotFound()
return datasource
@classmethod
def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]:
source_class = DatasourceDAO.sources[DatasourceType.TABLE]
qry = session.query(source_class)
qry = source_class.default_query(qry)
return qry.all()
@classmethod
def get_datasource_by_name( # pylint: disable=too-many-arguments
cls,
session: Session,
datasource_type: DatasourceType,
datasource_name: str,
database_name: str,
schema: str,
) -> Optional[Datasource]:
datasource_class = DatasourceDAO.sources[datasource_type]
if isinstance(datasource_class, SqlaTable):
return datasource_class.get_datasource_by_name(
session, datasource_name, schema, database_name
)
return None
@classmethod
def query_datasources_by_permissions( # pylint: disable=invalid-name
cls,
session: Session,
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List[Datasource]:
# TODO(hughhhh): add unit test
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
if not isinstance(datasource_class, SqlaTable):
return []
return (
session.query(datasource_class)
.filter_by(database_id=database.id)
.filter(
or_(
datasource_class.perm.in_(permissions),
datasource_class.schema_perm.in_(schema_perms),
)
)
.all()
)
@classmethod
def get_eager_datasource(
cls, session: Session, datasource_type: str, datasource_id: int
) -> Optional[Datasource]:
"""Returns datasource with columns and metrics."""
datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]]
if not isinstance(datasource_class, SqlaTable):
return None
return (
session.query(datasource_class)
.options(
subqueryload(datasource_class.columns),
subqueryload(datasource_class.metrics),
)
.filter_by(id=datasource_id)
.one()
)
@classmethod
def query_datasources_by_name(
cls,
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List[Datasource]:
datasource_class = DatasourceDAO.sources[DatasourceType[database.type]]
if not isinstance(datasource_class, SqlaTable):
return []
return datasource_class.query_datasources_by_name(
session, database, datasource_name, schema=schema
)

View File

@ -60,6 +60,7 @@ class DatasourceTypeNotSupportedError(DAOException):
DAO datasource query source type is not supported
"""
status = 422
message = "DAO datasource query source type is not supported"

View File

@ -24,7 +24,7 @@ from typing import Any, Dict, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
from superset import ConnectorRegistry, db
from superset import db
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.datasets.commands.importers.v0 import import_dataset
@ -63,12 +63,11 @@ def import_chart(
slc_to_import = slc_to_import.copy()
slc_to_import.reset_ownership()
params = slc_to_import.params_dict
datasource = ConnectorRegistry.get_datasource_by_name(
session,
slc_to_import.datasource_type,
params["datasource_name"],
params["schema"],
params["database_name"],
datasource = SqlaTable.get_datasource_by_name(
session=session,
datasource_name=params["datasource_name"],
database_name=params["database_name"],
schema=params["schema"],
)
slc_to_import.datasource_id = datasource.id # type: ignore
if slc_to_override:

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,62 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, Type, Union
from sqlalchemy.orm import Session
from superset.connectors.sqla.models import SqlaTable
from superset.dao.base import BaseDAO
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.models import Dataset
from superset.models.sql_lab import Query, SavedQuery
from superset.tables.models import Table
from superset.utils.core import DatasourceType
Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
sources: Dict[Union[DatasourceType, str], Type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
DatasourceType.DATASET: Dataset,
DatasourceType.SLTABLE: Table,
}
@classmethod
def get_datasource(
cls,
session: Session,
datasource_type: Union[DatasourceType, str],
datasource_id: int,
) -> Datasource:
if datasource_type not in cls.sources:
raise DatasourceTypeNotSupportedError()
datasource = (
session.query(cls.sources[datasource_type])
.filter_by(id=datasource_id)
.one_or_none()
)
if not datasource:
raise DatasourceNotFound()
return datasource
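Compared with the deleted superset/dao/datasource/dao.py above, this version keeps only get_datasource and widens the key type to Union[DatasourceType, str]; the tests at the end of this commit pass plain strings like "table", which resolves against the enum-keyed sources map (assuming, as elsewhere in Superset, that DatasourceType is a str-valued enum so equal strings hash alike). Sketch:

    from superset import db
    from superset.datasource.dao import DatasourceDAO
    from superset.utils.core import DatasourceType

    ds = DatasourceDAO.get_datasource(db.session, DatasourceType.TABLE, 1)
    same = DatasourceDAO.get_datasource(db.session, "table", 1)  # string key also works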

View File

@ -23,7 +23,7 @@ from typing import Any, Dict, List, Set
from urllib import request
from superset import app, db
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.models.slice import Slice
BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
@ -32,7 +32,7 @@ misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboa
def get_table_connector_registry() -> Any:
return ConnectorRegistry.sources["table"]
return SqlaTable
def get_examples_folder() -> str:

View File

@ -27,6 +27,7 @@ from superset.extensions import cache_manager
from superset.key_value.utils import get_owner, random_key
from superset.temporary_cache.commands.exceptions import TemporaryCacheCreateFailedError
from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__)
@ -56,7 +57,7 @@ class CreateFormDataCommand(BaseCommand):
state: TemporaryExploreState = {
"owner": get_owner(actor),
"datasource_id": datasource_id,
"datasource_type": datasource_type,
"datasource_type": DatasourceType(datasource_type),
"chart_id": chart_id,
"form_data": form_data,
}

View File

@ -18,10 +18,12 @@ from typing import Optional
from typing_extensions import TypedDict
from superset.utils.core import DatasourceType
class TemporaryExploreState(TypedDict):
owner: Optional[int]
datasource_id: int
datasource_type: str
datasource_type: DatasourceType
chart_id: Optional[int]
form_data: str
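With the TypedDict field now typed as DatasourceType, the create/update commands coerce the incoming string before storing state, as the surrounding hunks show. A minimal sketch of a well-typed entry (ids are placeholders):

    from superset.explore.form_data.commands.state import TemporaryExploreState
    from superset.utils.core import DatasourceType

    state: TemporaryExploreState = {
        "owner": 1,                                  # placeholder
        "datasource_id": 42,                         # placeholder
        "datasource_type": DatasourceType("table"),  # coerced from the request
        "chart_id": None,
        "form_data": "{}",
    }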

View File

@ -32,6 +32,7 @@ from superset.temporary_cache.commands.exceptions import (
TemporaryCacheUpdateFailedError,
)
from superset.temporary_cache.utils import cache_key
from superset.utils.core import DatasourceType
from superset.utils.schema import validate_json
logger = logging.getLogger(__name__)
@ -75,7 +76,7 @@ class UpdateFormDataCommand(BaseCommand, ABC):
new_state: TemporaryExploreState = {
"owner": owner,
"datasource_id": datasource_id,
"datasource_type": datasource_type,
"datasource_type": DatasourceType(datasource_type),
"chart_id": chart_id,
"form_data": form_data,
}

View File

@ -28,7 +28,6 @@ from flask_babel import gettext as __, lazy_gettext as _
from flask_compress import Compress
from werkzeug.middleware.proxy_fix import ProxyFix
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import CHANGE_ME_SECRET_KEY
from superset.extensions import (
_event_logger,
@ -473,7 +472,11 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
# Registering sources
module_datasource_map = self.config["DEFAULT_MODULE_DS_MAP"]
module_datasource_map.update(self.config["ADDITIONAL_MODULE_DS_MAP"])
ConnectorRegistry.register_sources(module_datasource_map)
# todo(hughhhh): fully remove the datasource config register
for module_name, class_names in module_datasource_map.items():
class_names = [str(s) for s in class_names]
__import__(module_name, fromlist=class_names)
def configure_cache(self) -> None:
cache_manager.init_app(self.superset_app)
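App startup no longer populates a registry; importing each configured module is enough, since datasource classes are now referenced directly where needed. Note that fromlist forces __import__ to return (and actually import) the leaf module rather than the top-level package. Roughly, with an illustrative config entry:

    module_datasource_map = {"superset.connectors.sqla.models": ["SqlaTable"]}
    for module_name, class_names in module_datasource_map.items():
        __import__(module_name, fromlist=[str(s) for s in class_names])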

View File

@ -46,10 +46,11 @@ from sqlalchemy.orm.session import object_session
from sqlalchemy.sql import join, select
from sqlalchemy.sql.elements import BinaryExpression
from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
from superset import app, db, is_feature_enabled, security_manager
from superset.common.request_contexed_based import is_user_admin
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlMetric, TableColumn
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.datasource.dao import DatasourceDAO
from superset.extensions import cache_manager
from superset.models.filter_set import FilterSet
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
@ -407,16 +408,18 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
id_ = target.get("datasetId")
if id_ is None:
continue
datasource = ConnectorRegistry.get_datasource_by_id(session, id_)
datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_
)
datasource_ids.add((datasource.id, datasource.type))
copied_dashboard.alter_params(remote_id=dashboard_id)
copied_dashboards.append(copied_dashboard)
eager_datasources = []
for datasource_id, datasource_type in datasource_ids:
eager_datasource = ConnectorRegistry.get_eager_datasource(
db.session, datasource_type, datasource_id
for datasource_id, _ in datasource_ids:
eager_datasource = SqlaTable.get_eager_sqlatable_datasource(
db.session, datasource_id
)
copied_datasource = eager_datasource.copy()
copied_datasource.alter_params(

View File

@ -21,7 +21,6 @@ from flask_appbuilder import Model
from sqlalchemy import Column, Integer, String
from superset import app, db, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset.models.helpers import AuditMixinNullable
from superset.utils.memoized import memoized
@ -44,7 +43,10 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
@property
def cls_model(self) -> Type["BaseDatasource"]:
return ConnectorRegistry.sources[self.datasource_type]
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO
return DatasourceDAO.sources[self.datasource_type]
@property
def username(self) -> Markup:

View File

@ -39,7 +39,7 @@ from sqlalchemy.engine.base import Connection
from sqlalchemy.orm import relationship
from sqlalchemy.orm.mapper import Mapper
from superset import ConnectorRegistry, db, is_feature_enabled, security_manager
from superset import db, is_feature_enabled, security_manager
from superset.legacy import update_time_range
from superset.models.helpers import AuditMixinNullable, ImportExportMixin
from superset.models.tags import ChartUpdater
@ -126,7 +126,10 @@ class Slice( # pylint: disable=too-many-public-methods
@property
def cls_model(self) -> Type["BaseDatasource"]:
return ConnectorRegistry.sources[self.datasource_type]
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO
return DatasourceDAO.sources[self.datasource_type]
@property
def datasource(self) -> Optional["BaseDatasource"]:
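Both cls_model properties (here and in DatasourceAccessRequest above) defer the DAO import to call time, which is what resolved the "circular import error" noted in the commit message: superset.datasource.dao imports these model modules at import time. The pattern, in a self-contained sketch with an illustrative stand-in class:

    from typing import Type

    from superset.connectors.base.models import BaseDatasource

    class SliceLike:  # stand-in for the real model
        datasource_type = "table"

        @property
        def cls_model(self) -> Type["BaseDatasource"]:
            # Imported inside the property to break the module-level import cycle
            from superset.datasource.dao import DatasourceDAO

            return DatasourceDAO.sources[self.datasource_type]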

View File

@ -61,7 +61,6 @@ from sqlalchemy.orm.mapper import Mapper
from sqlalchemy.orm.query import Query as SqlaQuery
from superset import sql_parse
from superset.connectors.connector_registry import ConnectorRegistry
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
@ -471,23 +470,25 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = set()
for datasource_class in ConnectorRegistry.sources.values():
user_datasources.update(
self.get_session.query(datasource_class)
.filter(
or_(
datasource_class.perm.in_(user_perms),
datasource_class.schema_perm.in_(schema_perms),
)
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
user_datasources.update(
self.get_session.query(SqlaTable)
.filter(
or_(
SqlaTable.perm.in_(user_perms),
SqlaTable.schema_perm.in_(schema_perms),
)
.all()
)
.all()
)
# group all datasources by database
all_datasources = ConnectorRegistry.get_all_datasources(self.get_session)
datasources_by_database: Dict["Database", Set["BaseDatasource"]] = defaultdict(
set
)
session = self.get_session
all_datasources = SqlaTable.get_all_datasources(session)
datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)
@ -599,6 +600,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param schema: The fallback SQL schema if not present in the table name
:returns: The list of accessible SQL tables w/ schema
"""
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
if self.can_access_database(database):
return datasource_names
@ -610,7 +613,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
user_perms = self.user_view_menu_names("datasource_access")
schema_perms = self.user_view_menu_names("schema_access")
user_datasources = ConnectorRegistry.query_datasources_by_permissions(
user_datasources = SqlaTable.query_datasources_by_permissions(
self.get_session, database, user_perms, schema_perms
)
if schema:
@ -660,6 +663,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
# pylint: disable=import-outside-toplevel
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
logger.info("Fetching a set of all perms to lookup which ones are missing")
@ -668,13 +672,13 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
if pv.permission and pv.view_menu:
all_pvs.add((pv.permission.name, pv.view_menu.name))
def merge_pv(view_menu: str, perm: str) -> None:
def merge_pv(view_menu: str, perm: Optional[str]) -> None:
"""Create permission view menu only if it doesn't exist"""
if view_menu and perm and (view_menu, perm) not in all_pvs:
self.add_permission_view_menu(view_menu, perm)
logger.info("Creating missing datasource permissions.")
datasources = ConnectorRegistry.get_all_datasources(self.get_session)
datasources = SqlaTable.get_all_datasources(self.get_session)
for datasource in datasources:
merge_pv("datasource_access", datasource.get_perm())
merge_pv("schema_access", datasource.get_schema_perm())

View File

@ -61,7 +61,6 @@ from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import (
AnnotationDatasource,
SqlaTable,
@ -77,6 +76,7 @@ from superset.databases.dao import DatabaseDAO
from superset.databases.filters import DatabaseFilter
from superset.databases.utils import make_url_safe
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.datasource.dao import DatasourceDAO
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
CacheLoadError,
@ -129,7 +129,11 @@ from superset.tasks.async_queries import load_explore_json_into_cache
from superset.utils import core as utils, csv
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.core import apply_max_row_limit, ReservedUrlParameters
from superset.utils.core import (
apply_max_row_limit,
DatasourceType,
ReservedUrlParameters,
)
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.views.base import (
@ -250,7 +254,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
db_ds_names.add(fullname)
existing_datasources = ConnectorRegistry.get_all_datasources(db.session)
existing_datasources = SqlaTable.get_all_datasources(db.session)
datasources = [d for d in existing_datasources if d.full_name in db_ds_names]
role = security_manager.find_role(role_name)
# remove all permissions
@ -282,7 +286,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource_id = request.args.get("datasource_id")
datasource_type = request.args.get("datasource_type")
if datasource_id and datasource_type:
ds_class = ConnectorRegistry.sources.get(datasource_type)
ds_class = DatasourceDAO.sources.get(datasource_type)
datasource = (
db.session.query(ds_class).filter_by(id=int(datasource_id)).one()
)
@ -319,10 +323,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
def approve(self) -> FlaskResponse: # pylint: disable=too-many-locals,no-self-use
def clean_fulfilled_requests(session: Session) -> None:
for dar in session.query(DAR).all():
datasource = ConnectorRegistry.get_datasource(
dar.datasource_type,
dar.datasource_id,
session,
datasource = DatasourceDAO.get_datasource(
session, DatasourceType(dar.datasource_type), dar.datasource_id
)
if not datasource or security_manager.can_access_datasource(datasource):
# Dataset does not exist anymore
@ -336,8 +338,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
role_to_extend = request.args.get("role_to_extend")
session = db.session
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, session
datasource = DatasourceDAO.get_datasource(
session, DatasourceType(datasource_type), int(datasource_id)
)
if not datasource:
@ -639,7 +641,6 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource_id, datasource_type = get_datasource_info(
datasource_id, datasource_type, form_data
)
force = request.args.get("force") == "true"
# TODO: support CSV, SQL query and other non-JSON types
@ -809,8 +810,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasource: Optional[BaseDatasource] = None
if datasource_id is not None:
try:
datasource = ConnectorRegistry.get_datasource(
cast(str, datasource_type), datasource_id, db.session
datasource = DatasourceDAO.get_datasource(
db.session,
DatasourceType(cast(str, datasource_type)),
datasource_id,
)
except DatasetNotFoundError:
pass
@ -948,10 +951,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
:raises SupersetSecurityException: If the user cannot access the resource
"""
# TODO: Cache endpoint by user, datasource and column
datasource = ConnectorRegistry.get_datasource(
datasource_type,
datasource_id,
db.session,
datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
if not datasource:
return json_error_response(DATASOURCE_MISSING_ERR)
@ -1920,8 +1921,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
if config["ENABLE_ACCESS_REQUEST"]:
for datasource in dashboard.datasources:
datasource = ConnectorRegistry.get_datasource(
datasource_type=datasource.type,
datasource = DatasourceDAO.get_datasource(
datasource_type=DatasourceType(datasource.type),
datasource_id=datasource.id,
session=db.session(),
)
@ -2537,10 +2538,8 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
datasource_id, datasource_type = request.args["datasourceKey"].split("__")
datasource = ConnectorRegistry.get_datasource(
datasource_type,
datasource_id,
db.session,
datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), int(datasource_id)
)
# Check if datasource exists
if not datasource:
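The explore endpoints keep their existing `<id>__<type>` datasourceKey convention and simply re-route the lookup, coercing each half. For illustration (the key is a placeholder):

    from superset import db
    from superset.datasource.dao import DatasourceDAO
    from superset.utils.core import DatasourceType

    datasource_id, datasource_type = "123__table".split("__")
    datasource = DatasourceDAO.get_datasource(
        db.session, DatasourceType(datasource_type), int(datasource_id)
    )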

View File

@ -29,16 +29,18 @@ from sqlalchemy.orm.exc import NoResultFound
from superset import db, event_logger
from superset.commands.utils import populate_owners
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.utils import get_physical_table_metadata
from superset.datasets.commands.exceptions import (
DatasetForbiddenError,
DatasetNotFoundError,
)
from superset.datasource.dao import DatasourceDAO
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.extensions import security_manager
from superset.models.core import Database
from superset.superset_typing import FlaskResponse
from superset.utils.core import DatasourceType
from superset.views.base import (
api,
BaseSupersetView,
@ -74,8 +76,8 @@ class Datasource(BaseSupersetView):
datasource_id = datasource_dict.get("id")
datasource_type = datasource_dict.get("type")
database_id = datasource_dict["database"].get("id")
orm_datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
orm_datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
orm_datasource.database_id = database_id
@ -117,8 +119,8 @@ class Datasource(BaseSupersetView):
@api
@handle_api_exception
def get(self, datasource_type: str, datasource_id: int) -> FlaskResponse:
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource = DatasourceDAO.get_datasource(
db.session, DatasourceType(datasource_type), datasource_id
)
return self.json_response(sanitize_datasource_data(datasource.data))
@ -130,8 +132,10 @@ class Datasource(BaseSupersetView):
self, datasource_type: str, datasource_id: int
) -> FlaskResponse:
"""Gets column info from the source system"""
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource = DatasourceDAO.get_datasource(
db.session,
DatasourceType(datasource_type),
datasource_id,
)
try:
external_metadata = datasource.external_metadata()
@ -153,9 +157,8 @@ class Datasource(BaseSupersetView):
except ValidationError as err:
return json_error_response(str(err), status=400)
datasource = ConnectorRegistry.get_datasource_by_name(
datasource = SqlaTable.get_datasource_by_name(
session=db.session,
datasource_type=params["datasource_type"],
database_name=params["database_name"],
schema=params["schema_name"],
datasource_name=params["table_name"],

View File

@ -32,7 +32,7 @@ from sqlalchemy.orm.exc import NoResultFound
import superset.models.core as models
from superset import app, dataframe, db, result_set, viz
from superset.common.db_query_status import QueryStatus
from superset.connectors.connector_registry import ConnectorRegistry
from superset.datasource.dao import DatasourceDAO
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import (
CacheLoadError,
@ -47,6 +47,7 @@ from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.sql_lab import Query
from superset.superset_typing import FormData
from superset.utils.core import DatasourceType
from superset.utils.decorators import stats_timing
from superset.viz import BaseViz
@ -127,8 +128,10 @@ def get_viz(
force_cached: bool = False,
) -> BaseViz:
viz_type = form_data.get("viz_type", "table")
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource = DatasourceDAO.get_datasource(
db.session,
DatasourceType(datasource_type),
datasource_id,
)
viz_obj = viz.viz_types[viz_type](
datasource, form_data=form_data, force=force, force_cached=force_cached

View File

@ -39,7 +39,6 @@ from tests.integration_tests.fixtures.energy_dashboard import (
)
from tests.integration_tests.test_app import app # isort:skip
from superset import db, security_manager
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
from superset.models.datasource_access_request import DatasourceAccessRequest
@ -90,12 +89,12 @@ SCHEMA_ACCESS_ROLE = "schema_access_role"
def create_access_request(session, ds_type, ds_name, role_name, username):
ds_class = ConnectorRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == "table":
ds = session.query(ds_class).filter(ds_class.table_name == ds_name).first()
ds = session.query(SqlaTable).filter(SqlaTable.table_name == ds_name).first()
else:
ds = session.query(ds_class).filter(ds_class.datasource_name == ds_name).first()
# This function will only work for ds_type == "table"
raise NotImplementedError()
ds_perm_view = security_manager.find_permission_view_menu(
"datasource_access", ds.perm
)
@ -449,49 +448,6 @@ class TestRequestAccess(SupersetTestCase):
TEST_ROLE = security_manager.find_role(TEST_ROLE_NAME)
self.assertIn(perm_view, TEST_ROLE.permissions)
# Case 3. Grant new role to the user to access the druid datasource.
security_manager.add_role("druid_role")
access_request3 = create_access_request(
session, "druid", "druid_ds_1", "druid_role", "gamma"
)
self.get_resp(
GRANT_ROLE_REQUEST.format(
"druid", access_request3.datasource_id, "gamma", "druid_role"
)
)
# user was granted table_role
user_roles = [r.name for r in security_manager.find_user("gamma").roles]
self.assertIn("druid_role", user_roles)
# Case 4. Extend the role to have access to the druid datasource
access_request4 = create_access_request(
session, "druid", "druid_ds_2", "druid_role", "gamma"
)
druid_ds_2_perm = access_request4.datasource.perm
self.client.get(
EXTEND_ROLE_REQUEST.format(
"druid", access_request4.datasource_id, "gamma", "druid_role"
)
)
# druid_role was extended to grant access to the druid_access_ds_2
druid_role = security_manager.find_role("druid_role")
perm_view = security_manager.find_permission_view_menu(
"datasource_access", druid_ds_2_perm
)
self.assertIn(perm_view, druid_role.permissions)
# cleanup
gamma_user = security_manager.find_user(username="gamma")
gamma_user.roles.remove(security_manager.find_role("druid_role"))
gamma_user.roles.remove(security_manager.find_role(TEST_ROLE_NAME))
session.delete(security_manager.find_role("druid_role"))
session.delete(security_manager.find_role(TEST_ROLE_NAME))
session.commit()
def test_request_access(self):
if app.config["ENABLE_ACCESS_REQUEST"]:
session = db.session

View File

@ -21,7 +21,7 @@ from typing import Any, Dict, List, Optional
from pandas import DataFrame
from superset import ConnectorRegistry, db
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@ -35,9 +35,8 @@ def get_table(
schema: Optional[str] = None,
):
schema = schema or get_example_default_schema()
table_source = ConnectorRegistry.sources["table"]
return (
db.session.query(table_source)
db.session.query(SqlaTable)
.filter_by(database_id=database.id, schema=schema, table_name=table_name)
.one_or_none()
)
@ -54,8 +53,7 @@ def create_table_metadata(
table = get_table(table_name, database, schema)
if not table:
table_source = ConnectorRegistry.sources["table"]
table = table_source(schema=schema, table_name=table_name)
table = SqlaTable(schema=schema, table_name=table_name)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database

View File

@ -22,12 +22,13 @@ from unittest import mock
import prison
import pytest
from superset import app, ConnectorRegistry, db
from superset import app, db
from superset.connectors.sqla.models import SqlaTable
from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError
from superset.datasets.commands.exceptions import DatasetNotFoundError
from superset.exceptions import SupersetGenericDBErrorException
from superset.models.core import Database
from superset.utils.core import get_example_default_schema
from superset.utils.core import DatasourceType, get_example_default_schema
from superset.utils.database import get_example_database
from tests.integration_tests.base_tests import db_insert_temp_object, SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
@ -256,9 +257,10 @@ class TestDatasource(SupersetTestCase):
pytest.raises(
SupersetGenericDBErrorException,
lambda: ConnectorRegistry.get_datasource(
"table", tbl.id, db.session
).external_metadata(),
lambda: db.session.query(SqlaTable)
.filter_by(id=tbl.id)
.one_or_none()
.external_metadata(),
)
resp = self.client.get(url)
@ -385,21 +387,30 @@ class TestDatasource(SupersetTestCase):
app.config["DATASET_HEALTH_CHECK"] = my_check
self.login(username="admin")
tbl = self.get_table(name="birth_names")
datasource = ConnectorRegistry.get_datasource("table", tbl.id, db.session)
datasource = db.session.query(SqlaTable).filter_by(id=tbl.id).one_or_none()
assert datasource.health_check_message == "Warning message!"
app.config["DATASET_HEALTH_CHECK"] = None
def test_get_datasource_failed(self):
from superset.datasource.dao import DatasourceDAO
pytest.raises(
DatasetNotFoundError,
lambda: ConnectorRegistry.get_datasource("table", 9999999, db.session),
DatasourceNotFound,
lambda: DatasourceDAO.get_datasource(db.session, "table", 9999999),
)
self.login(username="admin")
resp = self.get_json_resp("/datasource/get/table/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "Datasource does not exist")
def test_get_datasource_invalid_datasource_failed(self):
from superset.datasource.dao import DatasourceDAO
pytest.raises(
DatasourceTypeNotSupportedError,
lambda: DatasourceDAO.get_datasource(db.session, "druid", 9999999),
)
self.login(username="admin")
resp = self.get_json_resp("/datasource/get/druid/500000/", raise_on_error=False)
self.assertEqual(resp.get("error"), "Dataset does not exist")
resp = self.get_json_resp(
"/datasource/get/invalid-datasource-type/500000/", raise_on_error=False
)
self.assertEqual(resp.get("error"), "Dataset does not exist")
self.assertEqual(resp.get("error"), "'druid' is not a valid DatasourceType")
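Taken together, the updated tests pin down three failure modes; a pytest-style sketch (ids are placeholders):

    import pytest

    from superset import db
    from superset.dao.exceptions import (
        DatasourceNotFound,
        DatasourceTypeNotSupportedError,
    )
    from superset.datasource.dao import DatasourceDAO
    from superset.utils.core import DatasourceType

    with pytest.raises(DatasourceNotFound):  # known type, no such row
        DatasourceDAO.get_datasource(db.session, DatasourceType.TABLE, 9999999)

    with pytest.raises(DatasourceTypeNotSupportedError):  # type outside sources
        DatasourceDAO.get_datasource(db.session, "druid", 9999999)

    with pytest.raises(ValueError):  # "'druid' is not a valid DatasourceType"
        DatasourceType("druid")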

View File

@ -26,6 +26,7 @@ from superset.datasets.commands.exceptions import DatasetAccessDeniedError
from superset.explore.form_data.commands.state import TemporaryExploreState
from superset.extensions import cache_manager
from superset.models.slice import Slice
from superset.utils.core import DatasourceType
from tests.integration_tests.base_tests import login
from tests.integration_tests.fixtures.client import client
from tests.integration_tests.fixtures.world_bank_dashboard import (
@ -392,7 +393,7 @@ def test_delete_not_owner(client, chart_id: int, datasource: SqlaTable, admin_id
entry: TemporaryExploreState = {
"owner": another_owner,
"datasource_id": datasource.id,
"datasource_type": datasource.type,
"datasource_type": DatasourceType(datasource.type),
"chart_id": chart_id,
"form_data": INITIAL_FORM_DATA,
}

View File

@ -18,7 +18,7 @@ from typing import Callable, List, Optional
import pytest
from superset import ConnectorRegistry, db
from superset import db
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
@ -95,14 +95,11 @@ def _create_table(
def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
schema = get_example_default_schema()
table_id = (
datasource = (
db.session.query(SqlaTable)
.filter_by(table_name="birth_names", schema=schema)
.one()
.id
)
datasource = ConnectorRegistry.get_datasource("table", table_id, db.session)
columns = [column for column in datasource.columns]
metrics = [metric for metric in datasource.metrics]

View File

@ -82,7 +82,6 @@ def _create_energy_table():
table.metrics.append(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)
db.session.merge(table)
db.session.commit()
table.fetch_metadata()

View File

@ -16,7 +16,8 @@
# under the License.
from typing import List, Optional
from superset import ConnectorRegistry, db, security_manager
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.models.slice import Slice
@ -43,8 +44,8 @@ class InsertChartMixin:
for owner in owners:
user = db.session.query(security_manager.user_model).get(owner)
obj_owners.append(user)
datasource = ConnectorRegistry.get_datasource(
datasource_type, datasource_id, db.session
datasource = (
db.session.query(SqlaTable).filter_by(id=datasource_id).one_or_none()
)
slice = Slice(
cache_timeout=cache_timeout,

View File

@ -26,10 +26,15 @@ from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.datasource.dao import DatasourceDAO
from superset.extensions import cache_manager
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
from superset.utils.core import (
AdhocMetricExpressionType,
backend,
DatasourceType,
QueryStatus,
)
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@ -132,10 +137,10 @@ class TestQueryContext(SupersetTestCase):
cache_key_original = query_context.query_cache_key(query_object)
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
datasource_type=payload["datasource"]["type"],
datasource_id=payload["datasource"]["id"],
datasource = DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(payload["datasource"]["type"]),
datasource_id=payload["datasource"]["id"],
)
description_original = datasource.description
datasource.description = "temporary description"
@ -156,10 +161,10 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context("birth_names")
# make temporary change and revert it to refresh the changed_on property
datasource = ConnectorRegistry.get_datasource(
datasource_type=payload["datasource"]["type"],
datasource_id=payload["datasource"]["id"],
datasource = DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(payload["datasource"]["type"]),
datasource_id=payload["datasource"]["id"],
)
datasource.metrics.append(SqlMetric(metric_name="foo", expression="select 1;"))

View File

@ -28,10 +28,11 @@ import prison
import pytest
from flask import current_app
from superset.datasource.dao import DatasourceDAO
from superset.models.dashboard import Dashboard
from superset import app, appbuilder, db, security_manager, viz, ConnectorRegistry
from superset import app, appbuilder, db, security_manager, viz
from superset.connectors.sqla.models import SqlaTable
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
@ -990,7 +991,7 @@ class TestDatasources(SupersetTestCase):
mock_get_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
SqlaTable, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
@ -1018,7 +1019,7 @@ class TestDatasources(SupersetTestCase):
mock_get_session.query.return_value.filter.return_value.all.return_value = []
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
SqlaTable, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),
@ -1046,7 +1047,7 @@ class TestDatasources(SupersetTestCase):
]
with mock.patch.object(
ConnectorRegistry, "get_all_datasources"
SqlaTable, "get_all_datasources"
) as mock_get_all_datasources:
mock_get_all_datasources.return_value = [
Datasource("database1", "schema1", "table1"),

View File

@ -103,7 +103,7 @@ def test_get_datasource_sqlatable(
app_context: None, session_with_data: Session
) -> None:
from superset.connectors.sqla.models import SqlaTable
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasource.dao import DatasourceDAO
result = DatasourceDAO.get_datasource(
datasource_type=DatasourceType.TABLE,
@ -117,7 +117,7 @@ def test_get_datasource_sqlatable(
def test_get_datasource_query(app_context: None, session_with_data: Session) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasource.dao import DatasourceDAO
from superset.models.sql_lab import Query
result = DatasourceDAO.get_datasource(
@ -131,7 +131,7 @@ def test_get_datasource_query(app_context: None, session_with_data: Session) ->
def test_get_datasource_saved_query(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasource.dao import DatasourceDAO
from superset.models.sql_lab import SavedQuery
result = DatasourceDAO.get_datasource(
@ -145,7 +145,7 @@ def test_get_datasource_saved_query(
def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasource.dao import DatasourceDAO
from superset.tables.models import Table
# todo(hugh): This will break once we remove the dual write
@ -163,8 +163,8 @@ def test_get_datasource_sl_table(app_context: None, session_with_data: Session)
def test_get_datasource_sl_dataset(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.datasets.models import Dataset
from superset.datasource.dao import DatasourceDAO
# todo(hugh): This will break once we remove the dual write
# update the datsource_id=1 and this will pass again
@ -178,10 +178,35 @@ def test_get_datasource_sl_dataset(
assert isinstance(result, Dataset)
def test_get_all_sqlatables_datasources(
def test_get_datasource_w_str_param(
app_context: None, session_with_data: Session
) -> None:
from superset.dao.datasource.dao import DatasourceDAO
from superset.connectors.sqla.models import SqlaTable
from superset.datasets.models import Dataset
from superset.datasource.dao import DatasourceDAO
from superset.tables.models import Table
result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data)
assert isinstance(
DatasourceDAO.get_datasource(
datasource_type="table",
datasource_id=1,
session=session_with_data,
),
SqlaTable,
)
assert isinstance(
DatasourceDAO.get_datasource(
datasource_type="sl_table",
datasource_id=1,
session=session_with_data,
),
Table,
)
def test_get_all_datasources(app_context: None, session_with_data: Session) -> None:
from superset.connectors.sqla.models import SqlaTable
result = SqlaTable.get_all_datasources(session=session_with_data)
assert len(result) == 1