mirror of https://github.com/apache/superset.git
[mypy] Enforcing typing for db_engine_specs (#9138)
This commit is contained in:
parent
3149d8ebc0
commit
9f5f8e5d92
|
@ -52,3 +52,8 @@ order_by_type = false
|
|||
[mypy]
|
||||
ignore_missing_imports = true
|
||||
no_implicit_optional = true
|
||||
|
||||
[mypy-superset.db_engine_specs.*]
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_calls = true
|
||||
disallow_untyped_defs = true
|
||||
|
|
|
@ -30,16 +30,23 @@ from sqlalchemy import column, DateTime, select
|
|||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from wtforms.form import Form
|
||||
|
||||
from superset import app, sql_parse
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
|
||||
TableColumn,
|
||||
)
|
||||
from superset.models.core import Database # pylint: disable=unused-import
|
||||
|
||||
|
||||
|
@ -77,7 +84,7 @@ builtin_time_grains: Dict[Optional[str], str] = {
|
|||
class TimestampExpression(
|
||||
ColumnClause
|
||||
): # pylint: disable=abstract-method,too-many-ancestors,too-few-public-methods
|
||||
def __init__(self, expr: str, col: ColumnClause, **kwargs):
|
||||
def __init__(self, expr: str, col: ColumnClause, **kwargs: Any) -> None:
|
||||
"""Sqlalchemy class that can be can be used to render native column elements
|
||||
respeting engine-specific quoting rules as part of a string-based expression.
|
||||
|
||||
|
@ -89,7 +96,7 @@ class TimestampExpression(
|
|||
self.col = col
|
||||
|
||||
@property
|
||||
def _constructor(self):
|
||||
def _constructor(self) -> ColumnClause:
|
||||
# Needed to ensure that the column label is rendered correctly when
|
||||
# proxied to the outer query.
|
||||
# See https://github.com/sqlalchemy/sqlalchemy/issues/4730
|
||||
|
@ -98,9 +105,9 @@ class TimestampExpression(
|
|||
|
||||
@compiles(TimestampExpression)
|
||||
def compile_timegrain_expression(
|
||||
element: TimestampExpression, compiler: Compiled, **kw
|
||||
element: TimestampExpression, compiler: Compiled, **kwargs: Any
|
||||
) -> str:
|
||||
return element.name.replace("{col}", compiler.process(element.col, **kw))
|
||||
return element.name.replace("{col}", compiler.process(element.col, **kwargs))
|
||||
|
||||
|
||||
class LimitMethod: # pylint: disable=too-few-public-methods
|
||||
|
@ -132,7 +139,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return False
|
||||
|
||||
@classmethod
|
||||
def get_engine(cls, database, schema=None, source=None):
|
||||
def get_engine(
|
||||
cls,
|
||||
database: "Database",
|
||||
schema: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
) -> Engine:
|
||||
user_name = utils.get_username()
|
||||
return database.get_sqla_engine(
|
||||
schema=schema, nullpool=True, user_name=user_name, source=source
|
||||
|
@ -217,7 +229,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return select_exprs
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
"""
|
||||
|
||||
:param cursor: Cursor instance
|
||||
|
@ -246,7 +258,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return columns, data, []
|
||||
|
||||
@classmethod
|
||||
def alter_new_orm_column(cls, orm_col):
|
||||
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
|
||||
"""Allow altering default column attributes when first detected/added
|
||||
|
||||
For instance special column like `__time` for Druid can be
|
||||
|
@ -290,7 +302,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def extra_table_metadata(
|
||||
cls, database, table_name: str, schema_name: str
|
||||
cls, database: "Database", table_name: str, schema_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Returns engine-specific table metadata
|
||||
|
@ -304,7 +316,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return {}
|
||||
|
||||
@classmethod
|
||||
def apply_limit_to_sql(cls, sql: str, limit: int, database) -> str:
|
||||
def apply_limit_to_sql(cls, sql: str, limit: int, database: "Database") -> str:
|
||||
"""
|
||||
Alters the SQL statement to apply a LIMIT clause
|
||||
|
||||
|
@ -351,7 +363,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return parsed_query.get_query_with_new_limit(limit)
|
||||
|
||||
@staticmethod
|
||||
def csv_to_df(**kwargs) -> pd.DataFrame:
|
||||
def csv_to_df(**kwargs: Any) -> pd.DataFrame:
|
||||
""" Read csv into Pandas DataFrame
|
||||
:param kwargs: params to be passed to DataFrame.read_csv
|
||||
:return: Pandas DataFrame containing data from csv
|
||||
|
@ -363,7 +375,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return df
|
||||
|
||||
@classmethod
|
||||
def df_to_sql(cls, df: pd.DataFrame, **kwargs): # pylint: disable=invalid-name
|
||||
def df_to_sql( # pylint: disable=invalid-name
|
||||
cls, df: pd.DataFrame, **kwargs: Any
|
||||
) -> None:
|
||||
""" Upload data from a Pandas DataFrame to a database. For
|
||||
regular engines this calls the DataFrame.to_sql() method. Can be
|
||||
overridden for engines that don't work well with to_sql(), e.g.
|
||||
|
@ -374,7 +388,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
df.to_sql(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def create_table_from_csv(cls, form, database) -> None:
|
||||
def create_table_from_csv(cls, form: Form, database: "Database") -> None:
|
||||
"""
|
||||
Create table from contents of a csv. Note: this method does not create
|
||||
metadata for the table.
|
||||
|
@ -437,7 +451,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def get_all_datasource_names(
|
||||
cls, database, datasource_type: str
|
||||
cls, database: "Database", datasource_type: str
|
||||
) -> List[utils.DatasourceName]:
|
||||
"""Returns a list of all tables or views in database.
|
||||
|
||||
|
@ -472,7 +486,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return all_datasources
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor, query, session):
|
||||
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
|
||||
"""Handle a live cursor between the execute and fetchall calls
|
||||
|
||||
The flow works without this method doing anything, but it allows
|
||||
|
@ -486,13 +500,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return f"{cls.engine} error: {cls._extract_error_message(e)}"
|
||||
|
||||
@classmethod
|
||||
def _extract_error_message(cls, e: Exception) -> str:
|
||||
def _extract_error_message(cls, e: Exception) -> Optional[str]:
|
||||
"""Extract error message for queries"""
|
||||
return utils.error_msg_from_exception(e)
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema: Optional[str]):
|
||||
"""Based on a URI and selected schema, return a new URI
|
||||
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
|
||||
"""
|
||||
Mutate the database component of the SQLAlchemy URI.
|
||||
|
||||
The URI here represents the URI as entered when saving the database,
|
||||
``selected_schema`` is the schema currently active presumably in
|
||||
|
@ -509,11 +524,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
Some database drivers like presto accept '{catalog}/{schema}' in
|
||||
the database component of the URL, that can be handled here.
|
||||
"""
|
||||
# TODO: All overrides mutate input uri; should be renamed or refactored
|
||||
return uri
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def patch(cls):
|
||||
def patch(cls) -> None:
|
||||
"""
|
||||
TODO: Improve docstring and refactor implementation in Hive
|
||||
"""
|
||||
|
@ -580,7 +594,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
database,
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
) -> Optional[Select]:
|
||||
|
@ -599,13 +613,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols):
|
||||
return [column(c.get("name")) for c in cols]
|
||||
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
|
||||
return [column(c["name"]) for c in cols]
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments,too-many-locals
|
||||
cls,
|
||||
database,
|
||||
database: "Database",
|
||||
table_name: str,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
|
@ -629,7 +643,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param cols: Columns to include in query
|
||||
:return: SQL query
|
||||
"""
|
||||
fields = "*"
|
||||
fields: Union[str, List[Any]] = "*"
|
||||
cols = cols or []
|
||||
if (show_cols or latest_partition) and not cols:
|
||||
cols = database.get_columns(table_name, schema)
|
||||
|
@ -659,7 +673,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def estimate_statement_cost(
|
||||
cls, statement: str, database, cursor, user_name: str
|
||||
cls, statement: str, database: "Database", cursor: Any, user_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a SQL query that estimates the cost of a given statement.
|
||||
|
@ -686,7 +700,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def estimate_query_cost(
|
||||
cls, database, schema: str, sql: str, source: Optional[str] = None
|
||||
cls, database: "Database", schema: str, sql: str, source: Optional[str] = None
|
||||
) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Estimate the cost of a multiple statement SQL query.
|
||||
|
@ -718,8 +732,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
|
||||
@classmethod
|
||||
def modify_url_for_impersonation(
|
||||
cls, url, impersonate_user: bool, username: Optional[str]
|
||||
):
|
||||
cls, url: URL, impersonate_user: bool, username: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
|
||||
:param url: SQLAlchemy URL object
|
||||
|
@ -745,7 +759,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
return {}
|
||||
|
||||
@classmethod
|
||||
def execute(cls, cursor, query: str, **kwargs):
|
||||
def execute(cls, cursor: Any, query: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Execute a SQL query
|
||||
|
||||
|
|
|
@ -17,13 +17,17 @@
|
|||
import hashlib
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import literal_column
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database # pylint: disable=unused-import
|
||||
|
||||
|
||||
class BigQueryEngineSpec(BaseEngineSpec):
|
||||
"""Engine spec for Google's BigQuery
|
||||
|
@ -69,7 +73,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
data = super(BigQueryEngineSpec, cls).fetch_data(cursor, limit)
|
||||
if data and type(data[0]).__name__ == "Row":
|
||||
data = [r.values() for r in data] # type: ignore
|
||||
|
@ -112,7 +116,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def extra_table_metadata(
|
||||
cls, database, table_name: str, schema_name: str
|
||||
cls, database: "Database", table_name: str, schema_name: str
|
||||
) -> Dict[str, Any]:
|
||||
indexes = database.get_indexes(table_name, schema_name)
|
||||
if not indexes:
|
||||
|
@ -133,7 +137,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def _get_fields(cls, cols):
|
||||
def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
|
||||
"""
|
||||
BigQuery dialect requires us to not use backtick in the fieldname which are
|
||||
nested.
|
||||
|
@ -143,8 +147,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
column names in the result.
|
||||
"""
|
||||
return [
|
||||
literal_column(c.get("name")).label(c.get("name").replace(".", "__"))
|
||||
for c in cols
|
||||
literal_column(c["name"]).label(c["name"].replace(".", "__")) for c in cols
|
||||
]
|
||||
|
||||
@classmethod
|
||||
|
@ -156,7 +159,7 @@ class BigQueryEngineSpec(BaseEngineSpec):
|
|||
return "TIMESTAMP_MILLIS({col})"
|
||||
|
||||
@classmethod
|
||||
def df_to_sql(cls, df: pd.DataFrame, **kwargs):
|
||||
def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None:
|
||||
"""
|
||||
Upload data from a Pandas DataFrame to BigQuery. Calls
|
||||
`DataFrame.to_gbq()` which requires `pandas_gbq` to be installed.
|
||||
|
|
|
@ -18,6 +18,8 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
from urllib import parse
|
||||
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
|
||||
|
@ -59,7 +61,6 @@ class DrillEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema):
|
||||
def adjust_database_uri(cls, uri: URL, selected_schema: Optional[str]) -> None:
|
||||
if selected_schema:
|
||||
uri.database = parse.quote(selected_schema, safe="")
|
||||
return uri
|
||||
|
|
|
@ -14,8 +14,15 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import ( # pylint: disable=unused-import
|
||||
TableColumn,
|
||||
)
|
||||
|
||||
|
||||
class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
"""Engine spec for Druid.io"""
|
||||
|
@ -37,6 +44,6 @@ class DruidEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def alter_new_orm_column(cls, orm_col):
|
||||
def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
|
||||
if orm_col.column_name == "__time":
|
||||
orm_col.is_dttm = True
|
||||
|
|
|
@ -14,7 +14,7 @@
|
|||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import List, Tuple
|
||||
from typing import Any, List, Tuple
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
|
@ -39,7 +39,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
data = super().fetch_data(cursor, limit)
|
||||
# Lists of `pyodbc.Row` need to be unpacked further
|
||||
return cls.pyodbc_rows_to_tuples(data)
|
||||
|
|
|
@ -22,15 +22,19 @@ from datetime import datetime
|
|||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
from wtforms.form import Form
|
||||
|
||||
from superset import app, cache, conf
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -67,7 +71,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def patch(cls):
|
||||
def patch(cls) -> None:
|
||||
from pyhive import hive # pylint: disable=no-name-in-module
|
||||
from superset.db_engines import hive as patched_hive
|
||||
from TCLIService import (
|
||||
|
@ -83,12 +87,12 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_all_datasource_names(
|
||||
cls, database, datasource_type: str
|
||||
cls, database: "Database", datasource_type: str
|
||||
) -> List[utils.DatasourceName]:
|
||||
return BaseEngineSpec.get_all_datasource_names(database, datasource_type)
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
import pyhive
|
||||
from TCLIService import ttypes
|
||||
|
||||
|
@ -102,11 +106,11 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def create_table_from_csv( # pylint: disable=too-many-locals
|
||||
cls, form, database
|
||||
cls, form: Form, database: "Database"
|
||||
) -> None:
|
||||
"""Uploads a csv file and creates a superset datasource in Hive."""
|
||||
|
||||
def convert_to_hive_type(col_type):
|
||||
def convert_to_hive_type(col_type: str) -> str:
|
||||
"""maps tableschema's types to hive types"""
|
||||
tableschema_to_hive_types = {
|
||||
"boolean": "BOOLEAN",
|
||||
|
@ -192,13 +196,14 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> None:
|
||||
if selected_schema:
|
||||
uri.database = parse.quote(selected_schema, safe="")
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def _extract_error_message(cls, e):
|
||||
def _extract_error_message(cls, e: Exception) -> str:
|
||||
msg = str(e)
|
||||
match = re.search(r'errorMessage="(.*?)(?<!\\)"', msg)
|
||||
if match:
|
||||
|
@ -206,10 +211,10 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return msg
|
||||
|
||||
@classmethod
|
||||
def progress(cls, log_lines):
|
||||
def progress(cls, log_lines: List[str]) -> int:
|
||||
total_jobs = 1 # assuming there's at least 1 job
|
||||
current_job = 1
|
||||
stages = {}
|
||||
stages: Dict[int, float] = {}
|
||||
for line in log_lines:
|
||||
match = cls.jobs_stats_r.match(line)
|
||||
if match:
|
||||
|
@ -237,15 +242,17 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return int(progress)
|
||||
|
||||
@classmethod
|
||||
def get_tracking_url(cls, log_lines):
|
||||
def get_tracking_url(cls, log_lines: List[str]) -> Optional[str]:
|
||||
lkp = "Tracking URL = "
|
||||
for line in log_lines:
|
||||
if lkp in line:
|
||||
return line.split(lkp)[1]
|
||||
return None
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor, query, session): # pylint: disable=too-many-locals
|
||||
def handle_cursor( # pylint: disable=too-many-locals
|
||||
cls, cursor: Any, query: Query, session: Session
|
||||
) -> None:
|
||||
"""Updates progress information"""
|
||||
from pyhive import hive # pylint: disable=no-name-in-module
|
||||
|
||||
|
@ -310,7 +317,7 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
database,
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
) -> Optional[Select]:
|
||||
|
@ -335,12 +342,14 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
|
||||
|
||||
@classmethod
|
||||
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
|
||||
def latest_sub_partition(
|
||||
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
|
||||
) -> str:
|
||||
# TODO(bogdan): implement`
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _latest_partition_from_df(cls, df) -> Optional[List[str]]:
|
||||
def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
|
||||
"""Hive partitions look like ds={partition name}"""
|
||||
if not df.empty:
|
||||
return [df.ix[:, 0].max().split("=")[1]]
|
||||
|
@ -348,14 +357,19 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def _partition_query( # pylint: disable=too-many-arguments
|
||||
cls, table_name, database, limit=0, order_by=None, filters=None
|
||||
):
|
||||
cls,
|
||||
table_name: str,
|
||||
database: "Database",
|
||||
limit: int = 0,
|
||||
order_by: Optional[List[Tuple[str, bool]]] = None,
|
||||
filters: Optional[Dict[Any, Any]] = None,
|
||||
) -> str:
|
||||
return f"SHOW PARTITIONS {table_name}"
|
||||
|
||||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database,
|
||||
database: "Database",
|
||||
table_name: str,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
|
@ -381,8 +395,8 @@ class HiveEngineSpec(PrestoEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def modify_url_for_impersonation(
|
||||
cls, url, impersonate_user: bool, username: Optional[str]
|
||||
):
|
||||
cls, url: URL, impersonate_user: bool, username: Optional[str]
|
||||
) -> None:
|
||||
"""
|
||||
Modify the SQL Alchemy URL object with the user to impersonate if applicable.
|
||||
:param url: SQLAlchemy URL object
|
||||
|
|
|
@ -16,7 +16,7 @@
|
|||
# under the License.
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
|
@ -46,7 +46,7 @@ class MssqlEngineSpec(BaseEngineSpec):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls):
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
return "dateadd(S, {col}, '1970-01-01')"
|
||||
|
||||
@classmethod
|
||||
|
@ -61,7 +61,7 @@ class MssqlEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
data = super().fetch_data(cursor, limit)
|
||||
# Lists of `pyodbc.Row` need to be unpacked further
|
||||
return cls.pyodbc_rows_to_tuples(data)
|
||||
|
|
|
@ -19,6 +19,7 @@ from typing import Any, Dict, Optional
|
|||
from urllib import parse
|
||||
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
@ -59,10 +60,11 @@ class MySQLEngineSpec(BaseEngineSpec):
|
|||
return None
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> None:
|
||||
if selected_schema:
|
||||
uri.database = parse.quote(selected_schema, safe="")
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def get_datatype(cls, type_code: Any) -> Optional[str]:
|
||||
|
@ -86,7 +88,7 @@ class MySQLEngineSpec(BaseEngineSpec):
|
|||
return "from_unixtime({col})"
|
||||
|
||||
@classmethod
|
||||
def _extract_error_message(cls, e):
|
||||
def _extract_error_message(cls, e: Exception) -> str:
|
||||
"""Extract error message for queries"""
|
||||
message = str(e)
|
||||
try:
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from pytz import _FixedOffset # type: ignore
|
||||
from sqlalchemy.dialects.postgresql.base import PGInspector
|
||||
|
@ -51,7 +51,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def fetch_data(cls, cursor, limit: int) -> List[Tuple]:
|
||||
def fetch_data(cls, cursor: Any, limit: int) -> List[Tuple]:
|
||||
cursor.tzinfo_factory = FixedOffsetTimezone
|
||||
if not cursor.description:
|
||||
return []
|
||||
|
|
|
@ -25,16 +25,20 @@ from distutils.version import StrictVersion
|
|||
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
import simplejson as json
|
||||
from sqlalchemy import Column, literal_column
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
|
||||
from superset import app, cache, is_feature_enabled, security_manager
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.exceptions import SupersetTemplateException
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.models.sql_types.presto_sql_types import type_map as presto_type_map
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.utils import core as utils
|
||||
|
@ -392,7 +396,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
@classmethod
|
||||
def select_star( # pylint: disable=too-many-arguments
|
||||
cls,
|
||||
database,
|
||||
database: "Database",
|
||||
table_name: str,
|
||||
engine: Engine,
|
||||
schema: Optional[str] = None,
|
||||
|
@ -428,7 +432,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def estimate_statement_cost( # pylint: disable=too-many-locals
|
||||
cls, statement: str, database, cursor, user_name: str
|
||||
cls, statement: str, database: "Database", cursor: Any, user_name: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Run a SQL query that estimates the cost of a given statement.
|
||||
|
@ -510,7 +514,9 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return cost
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> None:
|
||||
database = uri.database
|
||||
if selected_schema and database:
|
||||
selected_schema = parse.quote(selected_schema, safe="")
|
||||
|
@ -519,7 +525,6 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
else:
|
||||
database += "/" + selected_schema
|
||||
uri.database = database
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]:
|
||||
|
@ -536,7 +541,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_all_datasource_names(
|
||||
cls, database, datasource_type: str
|
||||
cls, database: "Database", datasource_type: str
|
||||
) -> List[utils.DatasourceName]:
|
||||
datasource_df = database.get_df(
|
||||
"SELECT table_schema, table_name FROM INFORMATION_SCHEMA.{}S "
|
||||
|
@ -656,7 +661,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def extra_table_metadata(
|
||||
cls, database, table_name: str, schema_name: str
|
||||
cls, database: "Database", table_name: str, schema_name: str
|
||||
) -> Dict[str, Any]:
|
||||
metadata = {}
|
||||
|
||||
|
@ -670,10 +675,12 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
col_names, latest_parts = cls.latest_partition(
|
||||
table_name, schema_name, database, show_first=True
|
||||
)
|
||||
latest_parts = latest_parts or tuple([None] * len(col_names))
|
||||
|
||||
if not latest_parts:
|
||||
latest_parts = tuple([None] * len(col_names)) # type: ignore
|
||||
metadata["partitions"] = {
|
||||
"cols": cols,
|
||||
"latest": dict(zip(col_names, latest_parts)),
|
||||
"latest": dict(zip(col_names, latest_parts)), # type: ignore
|
||||
"partitionQuery": pql,
|
||||
}
|
||||
|
||||
|
@ -685,7 +692,9 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return metadata
|
||||
|
||||
@classmethod
|
||||
def get_create_view(cls, database, schema: str, table: str) -> Optional[str]:
|
||||
def get_create_view(
|
||||
cls, database: "Database", schema: str, table: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Return a CREATE VIEW statement, or `None` if not a view.
|
||||
|
||||
|
@ -712,7 +721,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return rows[0][0]
|
||||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor, query, session):
|
||||
def handle_cursor(cls, cursor: Any, query: Query, session: Session) -> None:
|
||||
"""Updates progress information"""
|
||||
query_id = query.id
|
||||
logger.info(f"Query {query_id}: Polling the cursor for progress")
|
||||
|
@ -753,13 +762,13 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
polled = cursor.poll()
|
||||
|
||||
@classmethod
|
||||
def _extract_error_message(cls, e):
|
||||
def _extract_error_message(cls, e: Exception) -> Optional[str]:
|
||||
if (
|
||||
hasattr(e, "orig")
|
||||
and type(e.orig).__name__ == "DatabaseError"
|
||||
and isinstance(e.orig[0], dict)
|
||||
and type(e.orig).__name__ == "DatabaseError" # type: ignore
|
||||
and isinstance(e.orig[0], dict) # type: ignore
|
||||
):
|
||||
error_dict = e.orig[0]
|
||||
error_dict = e.orig[0] # type: ignore
|
||||
return "{} at {}: {}".format(
|
||||
error_dict.get("errorName"),
|
||||
error_dict.get("errorLocation"),
|
||||
|
@ -772,8 +781,13 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals
|
||||
cls, table_name, database, limit=0, order_by=None, filters=None
|
||||
):
|
||||
cls,
|
||||
table_name: str,
|
||||
database: "Database",
|
||||
limit: int = 0,
|
||||
order_by: Optional[List[Tuple[str, bool]]] = None,
|
||||
filters: Optional[Dict[Any, Any]] = None,
|
||||
) -> str:
|
||||
"""Returns a partition query
|
||||
|
||||
:param table_name: the name of the table to get partitions from
|
||||
|
@ -827,7 +841,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
database,
|
||||
database: "Database",
|
||||
query: Select,
|
||||
columns: Optional[List] = None,
|
||||
) -> Optional[Select]:
|
||||
|
@ -850,7 +864,7 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def _latest_partition_from_df( # pylint: disable=invalid-name
|
||||
cls, df
|
||||
cls, df: pd.DataFrame
|
||||
) -> Optional[List[str]]:
|
||||
if not df.empty:
|
||||
return df.to_records(index=False)[0].item()
|
||||
|
@ -858,8 +872,12 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def latest_partition(
|
||||
cls, table_name: str, schema: Optional[str], database, show_first: bool = False
|
||||
):
|
||||
cls,
|
||||
table_name: str,
|
||||
schema: Optional[str],
|
||||
database: "Database",
|
||||
show_first: bool = False,
|
||||
) -> Tuple[List[str], Optional[List[str]]]:
|
||||
"""Returns col name and the latest (max) partition value for a table
|
||||
|
||||
:param table_name: the name of the table
|
||||
|
@ -897,7 +915,9 @@ class PrestoEngineSpec(BaseEngineSpec):
|
|||
return column_names, cls._latest_partition_from_df(df)
|
||||
|
||||
@classmethod
|
||||
def latest_sub_partition(cls, table_name, schema, database, **kwargs):
|
||||
def latest_sub_partition(
|
||||
cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
|
||||
) -> Any:
|
||||
"""Returns the latest (max) partition value for a table
|
||||
|
||||
A filtering criteria should be passed for all fields that are
|
||||
|
|
|
@ -18,6 +18,8 @@ from datetime import datetime
|
|||
from typing import Optional
|
||||
from urllib import parse
|
||||
|
||||
from sqlalchemy.engine.url import URL
|
||||
|
||||
from superset.db_engine_specs.postgres import PostgresBaseEngineSpec
|
||||
|
||||
|
||||
|
@ -47,14 +49,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
}
|
||||
|
||||
@classmethod
|
||||
def adjust_database_uri(cls, uri, selected_schema=None):
|
||||
def adjust_database_uri(
|
||||
cls, uri: URL, selected_schema: Optional[str] = None
|
||||
) -> None:
|
||||
database = uri.database
|
||||
if "/" in uri.database:
|
||||
database = uri.database.split("/")[0]
|
||||
if selected_schema:
|
||||
selected_schema = parse.quote(selected_schema, safe="")
|
||||
uri.database = database + "/" + selected_schema
|
||||
return uri
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
|
|
|
@ -49,7 +49,7 @@ class SqliteEngineSpec(BaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def get_all_datasource_names(
|
||||
cls, database, datasource_type: str
|
||||
cls, database: "Database", datasource_type: str
|
||||
) -> List[utils.DatasourceName]:
|
||||
schemas = database.get_all_schema_names(
|
||||
cache=database.schema_cache_enabled,
|
||||
|
|
|
@ -282,7 +282,7 @@ class Database(
|
|||
) -> Engine:
|
||||
extra = self.get_extra()
|
||||
sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted)
|
||||
sqlalchemy_url = self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||
self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema)
|
||||
effective_username = self.get_effective_user(sqlalchemy_url, user_name)
|
||||
# If using MySQL or Presto for example, will set url.username
|
||||
# If using Hive, will not do anything yet since that relies on a
|
||||
|
|
|
@ -51,7 +51,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
|||
|
||||
|
||||
class ParsedQuery:
|
||||
def __init__(self, sql_statement):
|
||||
def __init__(self, sql_statement: str):
|
||||
self.sql: str = sql_statement
|
||||
self._table_names: Set[str] = set()
|
||||
self._alias_names: Set[str] = set()
|
||||
|
|
|
@ -414,7 +414,7 @@ def json_dumps_w_dates(payload):
|
|||
return json.dumps(payload, default=json_int_dttm_ser)
|
||||
|
||||
|
||||
def error_msg_from_exception(e):
|
||||
def error_msg_from_exception(e: Exception) -> str:
|
||||
"""Translate exception into error message
|
||||
|
||||
Database have different ways to handle exception. This function attempts
|
||||
|
@ -430,10 +430,10 @@ def error_msg_from_exception(e):
|
|||
"""
|
||||
msg = ""
|
||||
if hasattr(e, "message"):
|
||||
if isinstance(e.message, dict):
|
||||
msg = e.message.get("message")
|
||||
elif e.message:
|
||||
msg = e.message
|
||||
if isinstance(e.message, dict): # type: ignore
|
||||
msg = e.message.get("message") # type: ignore
|
||||
elif e.message: # type: ignore
|
||||
msg = e.message # type: ignore
|
||||
return msg or str(e)
|
||||
|
||||
|
||||
|
|
|
@ -34,6 +34,6 @@ class FeatureFlagManager:
|
|||
|
||||
return self._feature_flags
|
||||
|
||||
def is_feature_enabled(self, feature):
|
||||
def is_feature_enabled(self, feature) -> bool:
|
||||
"""Utility function for checking whether a feature is turned on"""
|
||||
return self.get_feature_flags().get(feature)
|
||||
|
|
Loading…
Reference in New Issue