From 21c5b26fc819aa6531b17d6fc83cc3cc849389a8 Mon Sep 17 00:00:00 2001 From: "Hugh A. Miles II" Date: Fri, 13 May 2022 12:28:57 -0400 Subject: [PATCH] feat(sip-68): Add DatasourceDAO class to manage querying different datasources easier (#20030) * restart * update with enums * address concerns * remove any --- superset/dao/datasource/dao.py | 147 +++++++++++++++++++ superset/dao/exceptions.py | 12 ++ tests/unit_tests/dao/datasource_test.py | 185 ++++++++++++++++++++++++ 3 files changed, 344 insertions(+) create mode 100644 superset/dao/datasource/dao.py create mode 100644 tests/unit_tests/dao/datasource_test.py diff --git a/superset/dao/datasource/dao.py b/superset/dao/datasource/dao.py new file mode 100644 index 0000000000..8b4845db3c --- /dev/null +++ b/superset/dao/datasource/dao.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from enum import Enum +from typing import Any, Dict, List, Optional, Set, Type, Union + +from flask_babel import _ +from sqlalchemy import or_ +from sqlalchemy.orm import Session, subqueryload +from sqlalchemy.orm.exc import NoResultFound + +from superset.connectors.sqla.models import SqlaTable +from superset.dao.base import BaseDAO +from superset.dao.exceptions import DatasourceNotFound, DatasourceTypeNotSupportedError +from superset.datasets.commands.exceptions import DatasetNotFoundError +from superset.datasets.models import Dataset +from superset.models.core import Database +from superset.models.sql_lab import Query, SavedQuery +from superset.tables.models import Table +from superset.utils.core import DatasourceType + +Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] + + +class DatasourceDAO(BaseDAO): + + sources: Dict[DatasourceType, Type[Datasource]] = { + DatasourceType.SQLATABLE: SqlaTable, + DatasourceType.QUERY: Query, + DatasourceType.SAVEDQUERY: SavedQuery, + DatasourceType.DATASET: Dataset, + DatasourceType.TABLE: Table, + } + + @classmethod + def get_datasource( + cls, session: Session, datasource_type: DatasourceType, datasource_id: int + ) -> Datasource: + if datasource_type not in cls.sources: + raise DatasourceTypeNotSupportedError() + + datasource = ( + session.query(cls.sources[datasource_type]) + .filter_by(id=datasource_id) + .one_or_none() + ) + + if not datasource: + raise DatasourceNotFound() + + return datasource + + @classmethod + def get_all_sqlatables_datasources(cls, session: Session) -> List[Datasource]: + source_class = DatasourceDAO.sources[DatasourceType.SQLATABLE] + qry = session.query(source_class) + qry = source_class.default_query(qry) + return qry.all() + + @classmethod + def get_datasource_by_name( # pylint: disable=too-many-arguments + cls, + session: Session, + datasource_type: DatasourceType, + datasource_name: str, + database_name: str, + schema: str, + ) -> Optional[Datasource]: + datasource_class = DatasourceDAO.sources[datasource_type] + if isinstance(datasource_class, SqlaTable): + return datasource_class.get_datasource_by_name( + session, datasource_name, schema, database_name + ) + return None + + @classmethod + def query_datasources_by_permissions( # pylint: disable=invalid-name + cls, + session: Session, + database: Database, + permissions: Set[str], + schema_perms: Set[str], + ) -> List[Datasource]: + # TODO(hughhhh): add unit test + datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] + if not isinstance(datasource_class, SqlaTable): + return [] + + return ( + session.query(datasource_class) + .filter_by(database_id=database.id) + .filter( + or_( + datasource_class.perm.in_(permissions), + datasource_class.schema_perm.in_(schema_perms), + ) + ) + .all() + ) + + @classmethod + def get_eager_datasource( + cls, session: Session, datasource_type: str, datasource_id: int + ) -> Optional[Datasource]: + """Returns datasource with columns and metrics.""" + datasource_class = DatasourceDAO.sources[DatasourceType[datasource_type]] + if not isinstance(datasource_class, SqlaTable): + return None + return ( + session.query(datasource_class) + .options( + subqueryload(datasource_class.columns), + subqueryload(datasource_class.metrics), + ) + .filter_by(id=datasource_id) + .one() + ) + + @classmethod + def query_datasources_by_name( + cls, + session: Session, + database: Database, + datasource_name: str, + schema: Optional[str] = None, + ) -> List[Datasource]: + datasource_class = DatasourceDAO.sources[DatasourceType[database.type]] + if not isinstance(datasource_class, SqlaTable): + return [] + + return datasource_class.query_datasources_by_name( + session, database, datasource_name, schema=schema + ) diff --git a/superset/dao/exceptions.py b/superset/dao/exceptions.py index 822b23982e..9b5624bd5d 100644 --- a/superset/dao/exceptions.py +++ b/superset/dao/exceptions.py @@ -53,3 +53,15 @@ class DAOConfigError(DAOException): """ message = "DAO is not configured correctly missing model definition" + + +class DatasourceTypeNotSupportedError(DAOException): + """ + DAO datasource query source type is not supported + """ + + message = "DAO datasource query source type is not supported" + + +class DatasourceNotFound(DAOException): + message = "Datasource does not exist" diff --git a/tests/unit_tests/dao/datasource_test.py b/tests/unit_tests/dao/datasource_test.py new file mode 100644 index 0000000000..dd0db265e7 --- /dev/null +++ b/tests/unit_tests/dao/datasource_test.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Iterator + +import pytest +from sqlalchemy.orm.session import Session + +from superset.utils.core import DatasourceType + + +@pytest.fixture +def session_with_data(session: Session) -> Iterator[Session]: + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.models.sql_lab import Query, SavedQuery + from superset.tables.models import Table + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="my_sqla_table", + columns=columns, + metrics=[], + database=db, + ) + + query_obj = Query( + client_id="foo", + database=db, + tab_name="test_tab", + sql_editor_id="test_editor_id", + sql="select * from bar", + select_sql="select * from bar", + executed_sql="select * from bar", + limit=100, + select_as_cta=False, + rows=100, + error_message="none", + results_key="abc", + ) + + saved_query = SavedQuery(database=db, sql="select * from foo") + + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=db, + columns=[], + ) + + dataset = Dataset( + database=table.database, + name="positions", + expression=""" +SELECT array_agg(array[longitude,latitude]) AS position +FROM my_catalog.my_schema.my_table +""", + tables=[table], + columns=[ + Column( + name="position", + expression="array_agg(array[longitude,latitude])", + ), + ], + ) + + session.add(dataset) + session.add(table) + session.add(saved_query) + session.add(query_obj) + session.add(db) + session.add(sqla_table) + session.flush() + yield session + + +def test_get_datasource_sqlatable( + app_context: None, session_with_data: Session +) -> None: + from superset.connectors.sqla.models import SqlaTable + from superset.dao.datasource.dao import DatasourceDAO + + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.SQLATABLE, + datasource_id=1, + session=session_with_data, + ) + + assert 1 == result.id + assert "my_sqla_table" == result.table_name + assert isinstance(result, SqlaTable) + + +def test_get_datasource_query(app_context: None, session_with_data: Session) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.models.sql_lab import Query + + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.QUERY, datasource_id=1, session=session_with_data + ) + + assert result.id == 1 + assert isinstance(result, Query) + + +def test_get_datasource_saved_query( + app_context: None, session_with_data: Session +) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.models.sql_lab import SavedQuery + + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.SAVEDQUERY, + datasource_id=1, + session=session_with_data, + ) + + assert result.id == 1 + assert isinstance(result, SavedQuery) + + +def test_get_datasource_sl_table(app_context: None, session_with_data: Session) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.tables.models import Table + + # todo(hugh): This will break once we remove the dual write + # update the datsource_id=1 and this will pass again + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.TABLE, datasource_id=2, session=session_with_data + ) + + assert result.id == 2 + assert isinstance(result, Table) + + +def test_get_datasource_sl_dataset( + app_context: None, session_with_data: Session +) -> None: + from superset.dao.datasource.dao import DatasourceDAO + from superset.datasets.models import Dataset + + # todo(hugh): This will break once we remove the dual write + # update the datsource_id=1 and this will pass again + result = DatasourceDAO.get_datasource( + datasource_type=DatasourceType.DATASET, + datasource_id=2, + session=session_with_data, + ) + + assert result.id == 2 + assert isinstance(result, Dataset) + + +def test_get_all_sqlatables_datasources( + app_context: None, session_with_data: Session +) -> None: + from superset.dao.datasource.dao import DatasourceDAO + + result = DatasourceDAO.get_all_sqlatables_datasources(session=session_with_data) + assert len(result) == 1