fix: add back database lookup from sip 68 revert (#22129)

This commit is contained in:
Elizabeth Thompson 2022-11-15 11:57:03 -08:00 committed by GitHub
parent e23efefc46
commit 6f6cb1839e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 24 deletions

View File

@ -80,11 +80,13 @@ from superset.common.db_query_status import QueryStatus
from superset.common.utils.time_range_utils import get_since_until_from_time_range from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.sqla.utils import ( from superset.connectors.sqla.utils import (
find_cached_objects_in_session,
get_columns_description, get_columns_description,
get_physical_table_metadata, get_physical_table_metadata,
get_virtual_table_metadata, get_virtual_table_metadata,
validate_adhoc_subquery, validate_adhoc_subquery,
) )
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import ( from superset.exceptions import (
AdvancedDataTypeResponseError, AdvancedDataTypeResponseError,
@ -2088,6 +2090,21 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
# table is updated. This busts the cache key for all charts that use the table. # table is updated. This busts the cache key for all charts that use the table.
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))
# TODO: This shadow writing is deprecated
# if table itself has changed, shadow-writing will happen in `after_update` anyway
if target.table not in session.dirty:
dataset: NewDataset = (
session.query(NewDataset)
.filter_by(uuid=target.table.uuid)
.one_or_none()
)
# Update shadow dataset and columns
# did we find the dataset?
if not dataset:
# if dataset is not found create a new copy
target.table.write_shadow_dataset()
return
@staticmethod @staticmethod
def after_insert( def after_insert(
mapper: Mapper, mapper: Mapper,
@ -2099,6 +2116,9 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
""" """
security_manager.dataset_after_insert(mapper, connection, sqla_table) security_manager.dataset_after_insert(mapper, connection, sqla_table)
# TODO: deprecated
sqla_table.write_shadow_dataset()
@staticmethod @staticmethod
def after_delete( def after_delete(
mapper: Mapper, mapper: Mapper,
@ -2117,11 +2137,53 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
sqla_table: "SqlaTable", sqla_table: "SqlaTable",
) -> None: ) -> None:
""" """
Update dataset permissions after update Update dataset permissions
""" """
# set permissions # set permissions
security_manager.dataset_after_update(mapper, connection, sqla_table) security_manager.dataset_after_update(mapper, connection, sqla_table)
# TODO: the shadow writing is deprecated
inspector = inspect(sqla_table)
session = inspector.session
# double-check that ``UPDATE``s are actually pending (this method is called even
# for instances that have no net changes to their column-based attributes)
if not session.is_modified(sqla_table, include_collections=True):
return
# find the dataset from the known instance list first
# (it could be either from a previous query or newly created)
dataset = next(
find_cached_objects_in_session(
session, NewDataset, uuids=[sqla_table.uuid]
),
None,
)
# if not found, pull from database
if not dataset:
dataset = (
session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
)
if not dataset:
sqla_table.write_shadow_dataset()
return
def write_shadow_dataset(
self: "SqlaTable",
) -> None:
"""
This method is deprecated
"""
session = inspect(self).session
# most of the write_shadow_dataset functionality has been removed
# but leaving this portion in
# to remove later because it is adding a Database relationship to the session
# and there is some functionality that depends on this
if self.database_id and (
not self.database or self.database.id != self.database_id
):
self.database = session.query(Database).filter_by(id=self.database_id).one()
sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update) sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)

View File

@ -15,11 +15,9 @@
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
import logging import logging
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.base import BaseDAO from superset.dao.base import BaseDAO
@ -37,26 +35,6 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable model_cls = SqlaTable
base_filter = DatasourceFilter base_filter = DatasourceFilter
@classmethod
def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[SqlaTable]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
id_col = getattr(SqlaTable, cls.id_column_name, None)
if id_col is None:
return []
# the joinedload option ensures that the database is
# available in the session later and not lazy loaded
query = (
db.session.query(SqlaTable)
.options(joinedload(SqlaTable.database))
.filter(id_col.in_(model_ids))
)
data_model = SQLAInterface(SqlaTable, db.session)
query = DatasourceFilter(cls.id_column_name, data_model).apply(query, None)
return query.all()
@staticmethod @staticmethod
def get_database_by_id(database_id: int) -> Optional[Database]: def get_database_by_id(database_id: int) -> Optional[Database]:
try: try: