[mypy] Enforcing typing for db_engine_specs (#9138)

This commit is contained in:
John Bodley 2020-02-17 23:08:11 -08:00 committed by GitHub
parent 3149d8ebc0
commit 9f5f8e5d92
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 173 additions and 104 deletions

View File

@ -52,3 +52,8 @@ order_by_type = false
[mypy]
ignore_missing_imports = true
no_implicit_optional = true
[mypy-superset.db_engine_specs.*]
check_untyped_defs = true
disallow_untyped_calls = true
disallow_untyped_defs = true

View File

@ -30,16 +30,23 @@ from sqlalchemy import column, DateTime, select
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import URL
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine
from wtforms.form import Form
from superset import app, sql_parse
from superset.models.sql_lab import Query
from superset.utils import core as utils
if TYPE_CHECKING:
# prevent circular imports
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
TableColumn,
)
from superset.models.core import Database # pylint: disable=unused-import
@ -77,7 +84,7 @@ builtin_time_grains: Dict[Optional[str], str] = {
class TimestampExpression(
ColumnClause
): # pylint: disable=abstract-method,too-many-ancestors,too-few-public-methods
def __init__(self, expr: str, col: ColumnClause, **kwargs):
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
"""Sqlalchemy class that can be used to render native column elements
respecting engine-specific quoting rules as part of a string-based expression.
@ -89,7 +96,7 @@ class TimestampExpression(
self.col = col
@property
def _constructor(self):
def _constructor(self) -> ColumnClause:
# Needed to ensure that the column label is rendered correctly when
# proxied to the outer query.
# See https://github.com/sqlalchemy/sqlalchemy/issues/4730
@ -98,9 +105,9 @@ class TimestampExpression(
@compiles(TimestampExpression)
def compile_timegrain_expression(
element: TimestampExpression, compiler: Compiled, **kw
element: TimestampExpression, compiler: Compiled, **kwargs: Any
) -> str:
return element.name.replace("{col}", compiler.process(element.col, **kw))
return element.name.replace("{col}", compiler.process(element.col, **kwargs))
class LimitMethod: # pylint: disable=too-few-public-methods
@ -132,7 +139,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return False
@classmethod
def get_engine(cls, database, schema=None, source=None):
def get_engine(
cls,
database: "Database",
schema: Optional[str] = None,
source: Optional[str] = None,
) -> Engine:
user_name = utils.get_username()
return database.get_sqla_engine(
schema=schema, nullpool=True, user_name=user_name, source=source
@ -217,7 +229,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return select_exprs
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
"""
:param cursor: Cursor instance
@ -246,7 +258,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return columns, data, []
@classmethod
def alter_new_orm_column(cls, orm_col):
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
"""Allow altering default column attributes when first detected/added
For instance special column like `__time` for Druid can be
@ -290,7 +302,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def extra_table_metadata(
cls, database, table_name: str, schema_name: str
cls, database: "Database", table_name: str, schema_name: str
) -> Dict[str, Any]:
"""
Returns engine-specific table metadata
@ -304,7 +316,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return {}
@classmethod
def apply_limit_to_sql(cls, sql: str, limit: int, database) -> str:
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
"""
Alters the SQL statement to apply a LIMIT clause
@ -351,7 +363,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return parsed_query.get_query_with_new_limit(limit)
@staticmethod
def csv_to_df(**kwargs) -> pd.DataFrame:
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
""" Read csv into Pandas DataFrame
:param kwargs: params to be passed to DataFrame.read_csv
:return: Pandas DataFrame containing data from csv
@ -363,7 +375,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return df
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs): # pylint: disable=invalid-name
def df_to_sql( # pylint: disable=invalid-name
cls, df: pd.DataFrame, **kwargs: Any
) -> None:
""" Upload data from a Pandas DataFrame to a database. For
regular engines this calls the DataFrame.to_sql() method. Can be
overridden for engines that don't work well with to_sql(), e.g.
@ -374,7 +388,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
df.to_sql(**kwargs)
@classmethod
def create_table_from_csv(cls, form, database) -> None:
def create_table_from_csv(cls, form: Form, database: "Database") -> None:
"""
Create table from contents of a csv. Note: this method does not create
metadata for the table.
@ -437,7 +451,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_all_datasource_names(
cls, database, datasource_type: str
cls, database: "Database", datasource_type: str
) -> List[utils.DatasourceName]:
"""Returns a list of all tables or views in database.
@ -472,7 +486,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return all_datasources
@classmethod
def handle_cursor(cls, cursor, query, session):
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Handle a live cursor between the execute and fetchall calls
The flow works without this method doing anything, but it allows
@ -486,13 +500,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return f"{cls.engine} error: {cls._extract_error_message(e)}"
@classmethod
def _extract_error_message(cls, e: Exception) -> str:
def _extract_error_message(cls, e: Exception) -> Optional[str]:
"""Extract error message for queries"""
return utils.error_msg_from_exception(e)
@classmethod
def adjust_database_uri(cls, uri, selected_schema: Optional[str]):
"""Based on a URI and selected schema, return a new URI
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
"""
Mutate the database component of the SQLAlchemy URI.
The URI here represents the URI as entered when saving the database,
``selected_schema`` is the schema currently active presumably in
@ -509,11 +524,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
Some database drivers like presto accept '{catalog}/{schema}' in
the database component of the URL, that can be handled here.
"""
# TODO: All overrides mutate input uri; should be renamed or refactored
return uri
pass
@classmethod
def patch(cls):
def patch(cls) -> None:
"""
TODO: Improve docstring and refactor implementation in Hive
"""
@ -580,7 +594,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
table_name: str,
schema: Optional[str],
database,
database: "Database",
query: Select,
columns: Optional[List] = None,
) -> Optional[Select]:
@ -599,13 +613,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
def _get_fields(cls, cols):
return [column(c.get("name")) for c in cols]
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
return [column(c["name"]) for c in cols]
@classmethod
def select_star( # pylint: disable=too-many-arguments,too-many-locals
cls,
database,
database: "Database",
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@ -629,7 +643,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param cols: Columns to include in query
:return: SQL query
"""
fields = "*"
fields: Union[str, List[Any]] = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
cols = database.get_columns(table_name, schema)
@ -659,7 +673,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def estimate_statement_cost(
cls, statement: str, database, cursor, user_name: str
cls, statement: str, database: "Database", cursor: Any, user_name: str
) -> Dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
@ -686,7 +700,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def estimate_query_cost(
cls, database, schema: str, sql: str, source: Optional[str] = None
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
) -> List[Dict[str, str]]:
"""
Estimate the cost of a multiple statement SQL query.
@ -718,8 +732,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def modify_url_for_impersonation(
cls, url, impersonate_user: bool, username: Optional[str]
):
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
:param url: SQLAlchemy URL object
@ -745,7 +759,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return {}
@classmethod
def execute(cls, cursor, query: str, **kwargs):
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
"""
Execute a SQL query

View File

@ -17,13 +17,17 @@
import hashlib
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
import pandas as pd
from sqlalchemy import literal_column
from sqlalchemy.sql.expression import ColumnClause
from superset.db_engine_specs.base import BaseEngineSpec
if TYPE_CHECKING:
from superset.models.core import Database # pylint: disable=unused-import
class BigQueryEngineSpec(BaseEngineSpec):
"""Engine spec for Google's BigQuery
@ -69,7 +73,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return None
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
if data and type(data[0]).__name__ == "Row":
data = [r.values() for r in data] # type: ignore
@ -112,7 +116,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
@classmethod
def extra_table_metadata(
cls, database, table_name: str, schema_name: str
cls, database: "Database", table_name: str, schema_name: str
) -> Dict[str, Any]:
indexes = database.get_indexes(table_name, schema_name)
if not indexes:
@ -133,7 +137,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
}
@classmethod
def _get_fields(cls, cols):
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
"""
BigQuery dialect requires us to not use backtick in the fieldname which are
nested.
@ -143,8 +147,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
column names in the result.
"""
return [
literal_column(c.get("name")).label(c.get("name").replace(".", "__"))
for c in cols
literal_column(c["name"]).label(c["name"].replace(".", "__")) for c in cols
]
@classmethod
@ -156,7 +159,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
return "TIMESTAMP_MILLIS({col})"
@classmethod
def df_to_sql(cls, df: pd.DataFrame, **kwargs):
def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
"""
Upload data from a Pandas DataFrame to BigQuery. Calls
`DataFrame.to_gbq()` which requires `pandas_gbq` to be installed.

View File

@ -18,6 +18,8 @@ from datetime import datetime
from typing import Optional
from urllib import parse
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import BaseEngineSpec
@ -59,7 +61,6 @@ class DrillEngineSpec(BaseEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri, selected_schema):
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
return uri

View File

@ -14,8 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING
from superset.db_engine_specs.base import BaseEngineSpec
if TYPE_CHECKING:
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
TableColumn,
)
class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"""Engine spec for Druid.io"""
@ -37,6 +44,6 @@ class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
}
@classmethod
def alter_new_orm_column(cls, orm_col):
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
if orm_col.column_name == "__time":
orm_col.is_dttm = True

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import List, Tuple
from typing import Any, List, Tuple
from superset.db_engine_specs.base import BaseEngineSpec
@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
}
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

View File

@ -22,15 +22,19 @@ from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse
import pandas as pd
from sqlalchemy import Column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from wtforms.form import Form
from superset import app, cache, conf
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.models.sql_lab import Query
from superset.utils import core as utils
if TYPE_CHECKING:
@ -67,7 +71,7 @@ class HiveEngineSpec(PrestoEngineSpec):
)
@classmethod
def patch(cls):
def patch(cls) -> None:
from pyhive import hive # pylint: disable=no-name-in-module
from superset.db_engines import hive as patched_hive
from TCLIService import (
@ -83,12 +87,12 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_all_datasource_names(
cls, database, datasource_type: str
cls, database: "Database", datasource_type: str
) -> List[utils.DatasourceName]:
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
import pyhive
from TCLIService import ttypes
@ -102,11 +106,11 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def create_table_from_csv( # pylint: disable=too-many-locals
cls, form, database
cls, form: Form, database: "Database"
) -> None:
"""Uploads a csv file and creates a superset datasource in Hive."""
def convert_to_hive_type(col_type):
def convert_to_hive_type(col_type: str) -> str:
"""maps tableschema's types to hive types"""
tableschema_to_hive_types = {
"boolean": "BOOLEAN",
@ -192,13 +196,14 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
return uri
@classmethod
def _extract_error_message(cls, e):
def _extract_error_message(cls, e: Exception) -> str:
msg = str(e)
match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
if match:
@ -206,10 +211,10 @@ class HiveEngineSpec(PrestoEngineSpec):
return msg
@classmethod
def progress(cls, log_lines):
def progress(cls, log_lines: List[str]) -> int:
total_jobs = 1 # assuming there's at least 1 job
current_job = 1
stages = {}
stages: Dict[int, float] = {}
for line in log_lines:
match = cls.jobs_stats_r.match(line)
if match:
@ -237,15 +242,17 @@ class HiveEngineSpec(PrestoEngineSpec):
return int(progress)
@classmethod
def get_tracking_url(cls, log_lines):
def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
lkp = "Tracking URL = "
for line in log_lines:
if lkp in line:
return line.split(lkp)[1]
return None
return None
@classmethod
def handle_cursor(cls, cursor, query, session): # pylint: disable=too-many-locals
def handle_cursor( # pylint: disable=too-many-locals
cls, cursor: Any, query: Query, session: Session
) -> None:
"""Updates progress information"""
from pyhive import hive # pylint: disable=no-name-in-module
@ -310,7 +317,7 @@ class HiveEngineSpec(PrestoEngineSpec):
cls,
table_name: str,
schema: Optional[str],
database,
database: "Database",
query: Select,
columns: Optional[List] = None,
) -> Optional[Select]:
@ -335,12 +342,14 @@ class HiveEngineSpec(PrestoEngineSpec):
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
def latest_sub_partition(
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
) -> str:
# TODO(bogdan): implement
pass
@classmethod
def _latest_partition_from_df(cls, df) -> Optional[List[str]]:
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
"""Hive partitions look like ds={partition name}"""
if not df.empty:
return [df.ix[:, 0].max().split("=")[1]]
@ -348,14 +357,19 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def _partition_query( # pylint: disable=too-many-arguments
cls, table_name, database, limit=0, order_by=None, filters=None
):
cls,
table_name: str,
database: "Database",
limit: int = 0,
order_by: Optional[List[Tuple[str, bool]]] = None,
filters: Optional[Dict[Any, Any]] = None,
) -> str:
return f"SHOW PARTITIONS {table_name}"
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
database,
database: "Database",
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@ -381,8 +395,8 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def modify_url_for_impersonation(
cls, url, impersonate_user: bool, username: Optional[str]
):
cls, url: URL, impersonate_user: bool, username: Optional[str]
) -> None:
"""
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
:param url: SQLAlchemy URL object

View File

@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.types import String, TypeEngine, UnicodeText
@ -46,7 +46,7 @@ class MssqlEngineSpec(BaseEngineSpec):
}
@classmethod
def epoch_to_dttm(cls):
def epoch_to_dttm(cls) -> str:
return "dateadd(S, {col}, '1970-01-01')"
@classmethod
@ -61,7 +61,7 @@ class MssqlEngineSpec(BaseEngineSpec):
return None
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)

View File

@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
from urllib import parse
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.engine.url import URL
from sqlalchemy.types import TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec
@ -59,10 +60,11 @@ class MySQLEngineSpec(BaseEngineSpec):
return None
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
if selected_schema:
uri.database = parse.quote(selected_schema, safe="")
return uri
@classmethod
def get_datatype(cls, type_code: Any) -> Optional[str]:
@ -86,7 +88,7 @@ class MySQLEngineSpec(BaseEngineSpec):
return "from_unixtime({col})"
@classmethod
def _extract_error_message(cls, e):
def _extract_error_message(cls, e: Exception) -> str:
"""Extract error message for queries"""
message = str(e)
try:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
from typing import List, Optional, Tuple, TYPE_CHECKING
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from pytz import _FixedOffset # type: ignore
from sqlalchemy.dialects.postgresql.base import PGInspector
@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
}
@classmethod
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
cursor.tzinfo_factory = FixedOffsetTimezone
if not cursor.description:
return []

View File

@ -25,16 +25,20 @@ from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
from urllib import parse
import pandas as pd
import simplejson as json
from sqlalchemy import Column, literal_column
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from superset import app, cache, is_feature_enabled, security_manager
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetTemplateException
from superset.models.sql_lab import Query
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
@ -392,7 +396,7 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
database,
database: "Database",
table_name: str,
engine: Engine,
schema: Optional[str] = None,
@ -428,7 +432,7 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def estimate_statement_cost( # pylint: disable=too-many-locals
cls, statement: str, database, cursor, user_name: str
cls, statement: str, database: "Database", cursor: Any, user_name: str
) -> Dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
@ -510,7 +514,9 @@ class PrestoEngineSpec(BaseEngineSpec):
return cost
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
database = uri.database
if selected_schema and database:
selected_schema = parse.quote(selected_schema, safe="")
@ -519,7 +525,6 @@ class PrestoEngineSpec(BaseEngineSpec):
else:
database += "/" + selected_schema
uri.database = database
return uri
@classmethod
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
@ -536,7 +541,7 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def get_all_datasource_names(
cls, database, datasource_type: str
cls, database: "Database", datasource_type: str
) -> List[utils.DatasourceName]:
datasource_df = database.get_df(
"SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S "
@ -656,7 +661,7 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def extra_table_metadata(
cls, database, table_name: str, schema_name: str
cls, database: "Database", table_name: str, schema_name: str
) -> Dict[str, Any]:
metadata = {}
@ -670,10 +675,12 @@ class PrestoEngineSpec(BaseEngineSpec):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
latest_parts = latest_parts or tuple([None] * len(col_names))
if not latest_parts:
latest_parts = tuple([None] * len(col_names)) # type: ignore
metadata["partitions"] = {
"cols": cols,
"latest": dict(zip(col_names, latest_parts)),
"latest": dict(zip(col_names, latest_parts)), # type: ignore
"partitionQuery": pql,
}
@ -685,7 +692,9 @@ class PrestoEngineSpec(BaseEngineSpec):
return metadata
@classmethod
def get_create_view(cls, database, schema: str, table: str) -> Optional[str]:
def get_create_view(
cls, database: "Database", schema: str, table: str
) -> Optional[str]:
"""
Return a CREATE VIEW statement, or `None` if not a view.
@ -712,7 +721,7 @@ class PrestoEngineSpec(BaseEngineSpec):
return rows[0][0]
@classmethod
def handle_cursor(cls, cursor, query, session):
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
"""Updates progress information"""
query_id = query.id
logger.info(f"Query {query_id}: Polling the cursor for progress")
@ -753,13 +762,13 @@ class PrestoEngineSpec(BaseEngineSpec):
polled = cursor.poll()
@classmethod
def _extract_error_message(cls, e):
def _extract_error_message(cls, e: Exception) -> Optional[str]:
if (
hasattr(e, "orig")
and type(e.orig).__name__ == "DatabaseError"
and isinstance(e.orig[0], dict)
and type(e.orig).__name__ == "DatabaseError" # type: ignore
and isinstance(e.orig[0], dict) # type: ignore
):
error_dict = e.orig[0]
error_dict = e.orig[0] # type: ignore
return "{} at {}: {}".format(
error_dict.get("errorName"),
error_dict.get("errorLocation"),
@ -772,8 +781,13 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals
cls, table_name, database, limit=0, order_by=None, filters=None
):
cls,
table_name: str,
database: "Database",
limit: int = 0,
order_by: Optional[List[Tuple[str, bool]]] = None,
filters: Optional[Dict[Any, Any]] = None,
) -> str:
"""Returns a partition query
:param table_name: the name of the table to get partitions from
@ -827,7 +841,7 @@ class PrestoEngineSpec(BaseEngineSpec):
cls,
table_name: str,
schema: Optional[str],
database,
database: "Database",
query: Select,
columns: Optional[List] = None,
) -> Optional[Select]:
@ -850,7 +864,7 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def _latest_partition_from_df( # pylint: disable=invalid-name
cls, df
cls, df: pd.DataFrame
) -> Optional[List[str]]:
if not df.empty:
return df.to_records(index=False)[0].item()
@ -858,8 +872,12 @@ class PrestoEngineSpec(BaseEngineSpec):
@classmethod
def latest_partition(
cls, table_name: str, schema: Optional[str], database, show_first: bool = False
):
cls,
table_name: str,
schema: Optional[str],
database: "Database",
show_first: bool = False,
) -> Tuple[List[str], Optional[List[str]]]:
"""Returns col name and the latest (max) partition value for a table
:param table_name: the name of the table
@ -897,7 +915,9 @@ class PrestoEngineSpec(BaseEngineSpec):
return column_names, cls._latest_partition_from_df(df)
@classmethod
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
def latest_sub_partition(
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
) -> Any:
"""Returns the latest (max) partition value for a table
A filtering criteria should be passed for all fields that are

View File

@ -18,6 +18,8 @@ from datetime import datetime
from typing import Optional
from urllib import parse
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
@ -47,14 +49,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
}
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
def adjust_database_uri(
cls, uri: URL, selected_schema: Optional[str] = None
) -> None:
database = uri.database
if "/" in uri.database:
database = uri.database.split("/")[0]
if selected_schema:
selected_schema = parse.quote(selected_schema, safe="")
uri.database = database + "/" + selected_schema
return uri
@classmethod
def epoch_to_dttm(cls) -> str:

View File

@ -49,7 +49,7 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def get_all_datasource_names(
cls, database, datasource_type: str
cls, database: "Database", datasource_type: str
) -> List[utils.DatasourceName]:
schemas = database.get_all_schema_names(
cache=database.schema_cache_enabled,

View File

@ -282,7 +282,7 @@ class Database(
) -> Engine:
extra = self.get_extra()
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
effective_username = self.get_effective_user(sqlalchemy_url, user_name)
# If using MySQL or Presto for example, will set url.username
# If using Hive, will not do anything yet since that relies on a

View File

@ -51,7 +51,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
class ParsedQuery:
def __init__(self, sql_statement):
def __init__(self, sql_statement: str):
self.sql: str = sql_statement
self._table_names: Set[str] = set()
self._alias_names: Set[str] = set()

View File

@ -414,7 +414,7 @@ def json_dumps_w_dates(payload):
return json.dumps(payload, default=json_int_dttm_ser)
def error_msg_from_exception(e):
def error_msg_from_exception(e: Exception) -> str:
"""Translate exception into error message
Database have different ways to handle exception. This function attempts
@ -430,10 +430,10 @@ def error_msg_from_exception(e):
"""
msg = ""
if hasattr(e, "message"):
if isinstance(e.message, dict):
msg = e.message.get("message")
elif e.message:
msg = e.message
if isinstance(e.message, dict): # type: ignore
msg = e.message.get("message") # type: ignore
elif e.message: # type: ignore
msg = e.message # type: ignore
return msg or str(e)

View File

@ -34,6 +34,6 @@ class FeatureFlagManager:
return self._feature_flags
def is_feature_enabled(self, feature):
def is_feature_enabled(self, feature) -> bool:
"""Utility function for checking whether a feature is turned on"""
return self.get_feature_flags().get(feature)