perf: improve perf in SIP-68 migration (#19416)

* chore: improve perf in SIP-68 migration

* Small fixes

* Create tables referenced in SQL

* Update logic in SqlaTable as well

* Fix unit tests
Beto Dealmeida 2022-03-29 22:33:15 -07:00 committed by GitHub
parent 0968f86584
commit 63b5e2e4fa
9 changed files with 364 additions and 41 deletions

View File

@@ -102,7 +102,10 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
     while tables:
         table = tables.pop()
         seen.add(table)
-        model = getattr(Base.classes, table)
+        try:
+            model = getattr(Base.classes, table)
+        except AttributeError:
+            continue
         model.__tablename__ = table
         models.append(model)
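
A note on the guard above: Base.classes comes from SQLAlchemy's automap, which only generates classes for tables it can actually map; a table without a primary key, for instance, is skipped, so looking it up raises AttributeError. A minimal sketch of the failure mode (the in-memory engine and table names are made up for illustration):

from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine
from sqlalchemy.ext.automap import automap_base

engine = create_engine("sqlite://")
metadata = MetaData()
# "users" has a primary key and will be automapped; "log_lines" has none and will not be
Table("users", metadata, Column("id", Integer, primary_key=True))
Table("log_lines", metadata, Column("message", String))
metadata.create_all(engine)

Base = automap_base()
Base.prepare(engine, reflect=True)

getattr(Base.classes, "users")      # mapped class
getattr(Base.classes, "log_lines")  # raises AttributeError, which the loop above now skips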

View File

@@ -111,6 +111,7 @@ setup(
         "slackclient==2.5.0",  # PINNED! slack changes file upload api in the future versions
         "sqlalchemy>=1.3.16, <1.4, !=1.3.21",
         "sqlalchemy-utils>=0.37.8, <0.38",
+        "sqloxide==0.1.15",
         "sqlparse==0.3.0",  # PINNED! see https://github.com/andialbrecht/sqlparse/issues/562
         "tabulate==0.8.9",
         # needed to support Literal (3.8) and TypeGuard (3.10)
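
The new sqloxide pin backs the migration helpers added below: it exposes the Rust sqlparser-rs parser to Python and returns the parse tree as plain dicts and lists, a considerably faster path than sqlparse for pulling table references out of large queries. A hedged sketch of the call (output shape abbreviated to the parts the migration relies on):

from sqloxide import parse_sql

tree = parse_sql("SELECT id FROM some_schema.some_table", dialect="ansi")
# tree is a list of statement dicts; each table reference appears as a node keyed
# "Table" whose "name" holds identifier parts such as
# [{"value": "some_schema", ...}, {"value": "some_table", ...}]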

View File

@@ -78,6 +78,7 @@ from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
 from superset.connectors.sqla.utils import (
     get_physical_table_metadata,
     get_virtual_table_metadata,
+    load_or_create_tables,
     validate_adhoc_subquery,
 )
 from superset.datasets.models import Dataset as NewDataset
@@ -2242,7 +2243,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             if column.is_active is False:
                 continue
-            extra_json = json.loads(column.extra or "{}")
+            try:
+                extra_json = json.loads(column.extra or "{}")
+            except json.decoder.JSONDecodeError:
+                extra_json = {}
             for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
                 value = getattr(column, attr)
                 if value:
@@ -2269,7 +2273,10 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         # create metrics
         for metric in dataset.metrics:
-            extra_json = json.loads(metric.extra or "{}")
+            try:
+                extra_json = json.loads(metric.extra or "{}")
+            except json.decoder.JSONDecodeError:
+                extra_json = {}
             for attr in {"verbose_name", "metric_type", "d3format"}:
                 value = getattr(metric, attr)
                 if value:
@@ -2300,8 +2307,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
         )

         # physical dataset
-        tables = []
-        if dataset.sql is None:
+        if not dataset.sql:
             physical_columns = [column for column in columns if column.is_physical]

             # create table
@@ -2314,7 +2320,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
                 is_managed_externally=dataset.is_managed_externally,
                 external_url=dataset.external_url,
             )
-            tables.append(table)
+            tables = [table]

         # virtual dataset
         else:
@@ -2325,18 +2331,14 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             # find referenced tables
             parsed = ParsedQuery(dataset.sql)
             referenced_tables = parsed.tables
-
-            # predicate for finding the referenced tables
-            predicate = or_(
-                *[
-                    and_(
-                        NewTable.schema == (table.schema or dataset.schema),
-                        NewTable.name == table.table,
-                    )
-                    for table in referenced_tables
-                ]
+            tables = load_or_create_tables(
+                session,
+                dataset.database_id,
+                dataset.schema,
+                referenced_tables,
+                conditional_quote,
+                engine,
             )
-            tables = session.query(NewTable).filter(predicate).all()

         # create the new dataset
         new_dataset = NewDataset(
@@ -2345,7 +2347,7 @@ class SqlaTable(Model, BaseDatasource):  # pylint: disable=too-many-public-methods
             expression=dataset.sql or conditional_quote(dataset.table_name),
             tables=tables,
             columns=columns,
-            is_physical=dataset.sql is None,
+            is_physical=not dataset.sql,
             is_managed_externally=dataset.is_managed_externally,
             external_url=dataset.external_url,
         )
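
Two of the changes above are defensive rather than performance-related: column.extra and metric.extra can hold malformed JSON in old metadata databases, and dataset.sql can be an empty string rather than NULL for physical datasets, which the previous "is None" check misclassified as virtual. A small sketch of both checks (the values are hypothetical):

import json

def safe_extra(raw):
    # fall back to an empty dict instead of aborting the whole shadow-write
    try:
        return json.loads(raw or "{}")
    except json.decoder.JSONDecodeError:
        return {}

assert safe_extra(None) == {}
assert safe_extra('{"verbose_name": "Total"}') == {"verbose_name": "Total"}
assert safe_extra("{not json") == {}

# physical datasets may store sql as "" rather than NULL
sql = ""
assert (sql is None) is False  # old check: dataset would be treated as virtual
assert not sql                 # new check: dataset is correctly treated as physical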

View File

@@ -15,13 +15,17 @@
 # specific language governing permissions and limitations
 # under the License.
 from contextlib import closing
-from typing import Dict, List, Optional, TYPE_CHECKING
+from typing import Callable, Dict, List, Optional, Set, TYPE_CHECKING

 import sqlparse
 from flask_babel import lazy_gettext as _
+from sqlalchemy import and_, inspect, or_
+from sqlalchemy.engine import Engine
 from sqlalchemy.exc import NoSuchTableError
+from sqlalchemy.orm import Session
 from sqlalchemy.sql.type_api import TypeEngine

+from superset.columns.models import Column as NewColumn
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     SupersetGenericDBErrorException,
@@ -29,12 +33,16 @@ from superset.exceptions import (
 )
 from superset.models.core import Database
 from superset.result_set import SupersetResultSet
-from superset.sql_parse import has_table_query, ParsedQuery
+from superset.sql_parse import has_table_query, ParsedQuery, Table
+from superset.tables.models import Table as NewTable

 if TYPE_CHECKING:
     from superset.connectors.sqla.models import SqlaTable

+TEMPORAL_TYPES = {"DATETIME", "DATE", "TIME", "TIMEDELTA"}
+

 def get_physical_table_metadata(
     database: Database,
     table_name: str,
@@ -151,3 +159,78 @@ def validate_adhoc_subquery(raw_sql: str) -> None:
                 )
             )
     return
+
+
+def load_or_create_tables(  # pylint: disable=too-many-arguments
+    session: Session,
+    database_id: int,
+    default_schema: Optional[str],
+    tables: Set[Table],
+    conditional_quote: Callable[[str], str],
+    engine: Engine,
+) -> List[NewTable]:
+    """
+    Load or create new table model instances.
+    """
+    if not tables:
+        return []
+
+    # set the default schema in tables that don't have it
+    if default_schema:
+        fixed_tables = list(tables)
+        for i, table in enumerate(fixed_tables):
+            if table.schema is None:
+                fixed_tables[i] = Table(table.table, default_schema, table.catalog)
+        tables = set(fixed_tables)
+
+    # load existing tables
+    predicate = or_(
+        *[
+            and_(
+                NewTable.database_id == database_id,
+                NewTable.schema == table.schema,
+                NewTable.name == table.table,
+            )
+            for table in tables
+        ]
+    )
+    new_tables = session.query(NewTable).filter(predicate).all()
+
+    # add missing tables
+    existing = {(table.schema, table.name) for table in new_tables}
+    for table in tables:
+        if (table.schema, table.table) not in existing:
+            try:
+                inspector = inspect(engine)
+                column_metadata = inspector.get_columns(
+                    table.table, schema=table.schema
+                )
+            except Exception:  # pylint: disable=broad-except
+                continue
+            columns = [
+                NewColumn(
+                    name=column["name"],
+                    type=str(column["type"]),
+                    expression=conditional_quote(column["name"]),
+                    is_temporal=column["type"].python_type.__name__.upper()
+                    in TEMPORAL_TYPES,
+                    is_aggregation=False,
+                    is_physical=True,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                )
+                for column in column_metadata
+            ]
+            new_tables.append(
+                NewTable(
+                    name=table.table,
+                    schema=table.schema,
+                    catalog=None,
+                    database_id=database_id,
+                    columns=columns,
+                )
+            )
+            existing.add((table.schema, table.table))
+
+    return new_tables
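
A hedged usage sketch of the helper above; the session, database id, engine, and quoting callable are placeholders for the objects the shadow-writing code in SqlaTable already has in scope:

from superset.sql_parse import Table

referenced = {Table("orders"), Table("users", "public")}
new_tables = load_or_create_tables(
    session=session,                      # active SQLAlchemy ORM session
    database_id=1,                        # id of the dataset's database (placeholder)
    default_schema="public",              # applied to referenced tables without a schema
    tables=referenced,
    conditional_quote=lambda name: name,  # identity quoting, for the sketch only
    engine=engine,                        # used to reflect columns of missing tables
)
# new_tables holds one NewTable per referenced table: existing rows are loaded
# with a single OR-of-ANDs query, missing ones are reflected and created.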

View File

@@ -14,10 +14,39 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import logging
+from typing import Any, Iterator, Optional, Set
+
 from alembic import op
 from sqlalchemy import engine_from_config
 from sqlalchemy.engine import reflection
 from sqlalchemy.exc import NoSuchTableError
+from sqloxide import parse_sql
+
+from superset.sql_parse import ParsedQuery, Table
+
+logger = logging.getLogger("alembic")
+
+# mapping between sqloxide and SQLAlchemy dialects
+sqloxide_dialects = {
+    "ansi": {"trino", "trinonative", "presto"},
+    "hive": {"hive", "databricks"},
+    "ms": {"mssql"},
+    "mysql": {"mysql"},
+    "postgres": {
+        "cockroachdb",
+        "hana",
+        "netezza",
+        "postgres",
+        "postgresql",
+        "redshift",
+        "vertica",
+    },
+    "snowflake": {"snowflake"},
+    "sqlite": {"sqlite", "gsheets", "shillelagh"},
+    "clickhouse": {"clickhouse"},
+}
+

 def table_has_column(table: str, column: str) -> bool:
@@ -38,3 +67,40 @@ def table_has_column(table: str, column: str) -> bool:
         return any(col["name"] == column for col in insp.get_columns(table))
     except NoSuchTableError:
         return False
+
+
+def find_nodes_by_key(element: Any, target: str) -> Iterator[Any]:
+    """
+    Find all nodes in a SQL tree matching a given key.
+    """
+    if isinstance(element, list):
+        for child in element:
+            yield from find_nodes_by_key(child, target)
+    elif isinstance(element, dict):
+        for key, value in element.items():
+            if key == target:
+                yield value
+            else:
+                yield from find_nodes_by_key(value, target)
+
+
+def extract_table_references(sql_text: str, sqla_dialect: str) -> Set[Table]:
+    """
+    Return all the dependencies from a SQL sql_text.
+    """
+    dialect = "generic"
+    for dialect, sqla_dialects in sqloxide_dialects.items():
+        if sqla_dialect in sqla_dialects:
+            break
+    try:
+        tree = parse_sql(sql_text, dialect=dialect)
+    except Exception:  # pylint: disable=broad-except
+        logger.warning("Unable to parse query with sqloxide: %s", sql_text)
+        # fallback to sqlparse
+        parsed = ParsedQuery(sql_text)
+        return parsed.tables
+
+    return {
+        Table(*[part["value"] for part in table["name"][::-1]])
+        for table in find_nodes_by_key(tree, "Table")
+    }
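
Because sqloxide returns the sqlparser-rs AST as nested dicts and lists, find_nodes_by_key only needs a recursive key search to reach every "Table" node, and extract_table_references turns each node's identifier parts into a superset.sql_parse.Table (the dialect mapping above sends, for example, "postgresql" to sqloxide's "postgres"). A rough illustration, with the tree shape abbreviated:

from sqloxide import parse_sql

from superset.migrations.shared.utils import extract_table_references, find_nodes_by_key
from superset.sql_parse import Table

tree = parse_sql("SELECT a FROM s.t JOIN u ON t.id = u.id", dialect="postgres")
for node in find_nodes_by_key(tree, "Table"):
    # each node carries a "name" list of identifier parts, e.g.
    # [{"value": "s", ...}, {"value": "t", ...}]
    print([part["value"] for part in node["name"]])

assert extract_table_references(
    "SELECT a FROM s.t JOIN u ON t.id = u.id", "postgresql"
) == {
    Table(table="t", schema="s", catalog=None),
    Table(table="u", schema=None, catalog=None),
}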

View File

@@ -25,7 +25,7 @@ Create Date: 2021-11-11 16:41:53.266965
 """

 import json
-from typing import List
+from typing import Callable, List, Optional, Set
 from uuid import uuid4

 import sqlalchemy as sa
@@ -40,7 +40,9 @@ from sqlalchemy_utils import UUIDType
 from superset import app, db
 from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES
 from superset.extensions import encrypted_field_factory
-from superset.sql_parse import ParsedQuery
+from superset.migrations.shared.utils import extract_table_references
+from superset.models.core import Database as OriginalDatabase
+from superset.sql_parse import Table

 # revision identifiers, used by Alembic.
 revision = "b8d3a24d9131"
@@ -228,6 +230,85 @@ class NewDataset(Base):
     external_url = sa.Column(sa.Text, nullable=True)


+TEMPORAL_TYPES = {"DATETIME", "DATE", "TIME", "TIMEDELTA"}
+
+
+def load_or_create_tables(
+    session: Session,
+    database_id: int,
+    default_schema: Optional[str],
+    tables: Set[Table],
+    conditional_quote: Callable[[str], str],
+) -> List[NewTable]:
+    """
+    Load or create new table model instances.
+    """
+    if not tables:
+        return []
+
+    # set the default schema in tables that don't have it
+    if default_schema:
+        tables = list(tables)
+        for i, table in enumerate(tables):
+            if table.schema is None:
+                tables[i] = Table(table.table, default_schema, table.catalog)
+
+    # load existing tables
+    predicate = or_(
+        *[
+            and_(
+                NewTable.database_id == database_id,
+                NewTable.schema == table.schema,
+                NewTable.name == table.table,
+            )
+            for table in tables
+        ]
+    )
+    new_tables = session.query(NewTable).filter(predicate).all()
+
+    # use original database model to get the engine
+    engine = (
+        session.query(OriginalDatabase)
+        .filter_by(id=database_id)
+        .one()
+        .get_sqla_engine(default_schema)
+    )
+    inspector = inspect(engine)
+
+    # add missing tables
+    existing = {(table.schema, table.name) for table in new_tables}
+    for table in tables:
+        if (table.schema, table.table) not in existing:
+            column_metadata = inspector.get_columns(table.table, schema=table.schema)
+            columns = [
+                NewColumn(
+                    name=column["name"],
+                    type=str(column["type"]),
+                    expression=conditional_quote(column["name"]),
+                    is_temporal=column["type"].python_type.__name__.upper()
+                    in TEMPORAL_TYPES,
+                    is_aggregation=False,
+                    is_physical=True,
+                    is_spatial=False,
+                    is_partition=False,
+                    is_increase_desired=True,
+                )
+                for column in column_metadata
+            ]
+            new_tables.append(
+                NewTable(
+                    name=table.table,
+                    schema=table.schema,
+                    catalog=None,
+                    database_id=database_id,
+                    columns=columns,
+                )
+            )
+            existing.add((table.schema, table.table))
+
+    return new_tables
+
+
 def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
     """
     Copy old datasets to the new models.
@@ -253,7 +334,10 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
         if column.is_active is False:
             continue
-        extra_json = json.loads(column.extra or "{}")
+        try:
+            extra_json = json.loads(column.extra or "{}")
+        except json.decoder.JSONDecodeError:
+            extra_json = {}
         for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}:
             value = getattr(column, attr)
             if value:
@@ -279,7 +363,10 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
     # create metrics
     for metric in target.metrics:
-        extra_json = json.loads(metric.extra or "{}")
+        try:
+            extra_json = json.loads(metric.extra or "{}")
+        except json.decoder.JSONDecodeError:
+            extra_json = {}
         for attr in {"verbose_name", "metric_type", "d3format"}:
             value = getattr(metric, attr)
             if value:
@@ -309,8 +396,7 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
     )

     # physical dataset
-    tables = []
-    if target.sql is None:
+    if not target.sql:
         physical_columns = [column for column in columns if column.is_physical]

         # create table
@@ -323,7 +409,7 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
             is_managed_externally=target.is_managed_externally,
             external_url=target.external_url,
         )
-        tables.append(table)
+        tables = [table]

     # virtual dataset
     else:
@@ -332,20 +418,14 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
            column.is_physical = False

        # find referenced tables
-        parsed = ParsedQuery(target.sql)
-        referenced_tables = parsed.tables
-
-        # predicate for finding the referenced tables
-        predicate = or_(
-            *[
-                and_(
-                    NewTable.schema == (table.schema or target.schema),
-                    NewTable.name == table.table,
-                )
-                for table in referenced_tables
-            ]
+        referenced_tables = extract_table_references(target.sql, dialect_class.name)
+        tables = load_or_create_tables(
+            session,
+            target.database_id,
+            target.schema,
+            referenced_tables,
+            conditional_quote,
         )
-        tables = session.query(NewTable).filter(predicate).all()

    # create the new dataset
    dataset = NewDataset(
@@ -354,7 +434,7 @@ def after_insert(target: SqlaTable) -> None:  # pylint: disable=too-many-locals
        expression=target.sql or conditional_quote(target.table_name),
        tables=tables,
        columns=columns,
-        is_physical=target.sql is None,
+        is_physical=not target.sql,
        is_managed_externally=target.is_managed_externally,
        external_url=target.external_url,
    )
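
Besides swapping the parser, the table lookup in after_insert changes in two ways: it is now scoped to the dataset's database, and referenced tables that have no NewTable row yet are created instead of being silently dropped. Roughly, the old and new behaviour compare as follows (names as in the hunk above):

# old: only pre-existing rows, matched on schema and name across every database
tables = session.query(NewTable).filter(
    or_(
        *[
            and_(
                NewTable.schema == (table.schema or target.schema),
                NewTable.name == table.table,
            )
            for table in referenced_tables
        ]
    )
).all()

# new: rows filtered by target.database_id in a single query; any referenced table
# without a row is reflected from the source database and created on the fly
tables = load_or_create_tables(
    session, target.database_id, target.schema, referenced_tables, conditional_quote
)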

View File

@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.

View File

@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.

View File

@@ -0,0 +1,56 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=import-outside-toplevel, unused-argument
+
+"""
+Test the SIP-68 migration.
+"""
+
+from pytest_mock import MockerFixture
+
+from superset.sql_parse import Table
+
+
+def test_extract_table_references(mocker: MockerFixture, app_context: None) -> None:
+    """
+    Test the ``extract_table_references`` helper function.
+    """
+    from superset.migrations.shared.utils import extract_table_references
+
+    assert extract_table_references("SELECT 1", "trino") == set()
+    assert extract_table_references("SELECT 1 FROM some_table", "trino") == {
+        Table(table="some_table", schema=None, catalog=None)
+    }
+    assert extract_table_references(
+        "SELECT 1 FROM some_catalog.some_schema.some_table", "trino"
+    ) == {Table(table="some_table", schema="some_schema", catalog="some_catalog")}
+    assert extract_table_references(
+        "SELECT * FROM some_table JOIN other_table ON some_table.id = other_table.id",
+        "trino",
+    ) == {
+        Table(table="some_table", schema=None, catalog=None),
+        Table(table="other_table", schema=None, catalog=None),
+    }
+
+    # test falling back to sqlparse
+    logger = mocker.patch("superset.migrations.shared.utils.logger")
+    sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
+    assert extract_table_references(
+        sql,
+        "trino",
+    ) == {Table(table="other_table", schema=None, catalog=None)}
+    logger.warning.assert_called_with("Unable to parse query with sqloxide: %s", sql)