Make timestamp expression native SQLAlchemy element (#7131)

* Add native sqla component for time expressions

* Add unit tests and remove old tests

* Remove redundant _grains_dict method

* Clarify time_grain logic

* Add docstrings and typing

* Fix flake8 errors

* Add missing typings

* Rename to TimestampExpression

* Remove redundant tests

* Fix broken reference to db.database_name due to refactor
This commit is contained in:
Ville Brofeldt 2019-05-30 08:28:37 +03:00 committed by GitHub
parent fc3b043462
commit 34407e8962
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 121 deletions

View File

@ -18,6 +18,7 @@
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
from datetime import datetime from datetime import datetime
import logging import logging
from typing import Optional, Union
from flask import escape, Markup from flask import escape, Markup
from flask_appbuilder import Model from flask_appbuilder import Model
@ -32,11 +33,12 @@ from sqlalchemy.exc import CompileError
from sqlalchemy.orm import backref, relationship from sqlalchemy.orm import backref, relationship
from sqlalchemy.schema import UniqueConstraint from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy.sql.expression import Label, TextAsFrom
import sqlparse import sqlparse
from superset import app, db, security_manager from superset import app, db, security_manager
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.db_engine_specs import TimestampExpression
from superset.jinja_context import get_template_processor from superset.jinja_context import get_template_processor
from superset.models.annotations import Annotation from superset.models.annotations import Annotation
from superset.models.core import Database from superset.models.core import Database
@ -140,8 +142,14 @@ class TableColumn(Model, BaseColumn):
l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc))) l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc)))
return and_(*l) return and_(*l)
def get_timestamp_expression(self, time_grain): def get_timestamp_expression(self, time_grain: Optional[str]) \
"""Getting the time component of the query""" -> Union[TimestampExpression, Label]:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
:param time_grain: Optional time grain, e.g. P1Y
:return: A TimeExpression object wrapped in a Label if supported by db
"""
label = utils.DTTM_ALIAS label = utils.DTTM_ALIAS
db = self.table.database db = self.table.database
@ -150,16 +158,12 @@ class TableColumn(Model, BaseColumn):
if not self.expression and not time_grain and not is_epoch: if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=DateTime) sqla_col = column(self.column_name, type_=DateTime)
return self.table.make_sqla_column_compatible(sqla_col, label) return self.table.make_sqla_column_compatible(sqla_col, label)
grain = None if self.expression:
if time_grain: col = literal_column(self.expression)
grain = db.grains_dict().get(time_grain) else:
if not grain: col = column(self.column_name)
raise NotImplementedError( time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
f'No grain spec for {time_grain} for database {db.database_name}') return self.table.make_sqla_column_compatible(time_expr, label)
col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name)
expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
sqla_col = literal_column(expr, type_=DateTime)
return self.table.make_sqla_column_compatible(sqla_col, label)
@classmethod @classmethod
def import_obj(cls, i_column): def import_obj(cls, i_column):

View File

@ -36,19 +36,20 @@ import os
import re import re
import textwrap import textwrap
import time import time
from typing import List, Tuple from typing import Dict, List, Optional, Tuple
from urllib import parse from urllib import parse
from flask import g from flask import g
from flask_babel import lazy_gettext as _ from flask_babel import lazy_gettext as _
import pandas import pandas
import sqlalchemy as sqla import sqlalchemy as sqla
from sqlalchemy import Column, select, types from sqlalchemy import Column, DateTime, select, types
from sqlalchemy.engine import create_engine from sqlalchemy.engine import create_engine
from sqlalchemy.engine.base import Engine from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import quoted_name, text from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause from sqlalchemy.sql.expression import ColumnClause
from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy.sql.expression import TextAsFrom
@ -90,6 +91,24 @@ builtin_time_grains = {
} }
class TimestampExpression(ColumnClause):
def __init__(self, expr: str, col: ColumnClause, **kwargs):
"""Sqlalchemy class that can be can be used to render native column elements
respeting engine-specific quoting rules as part of a string-based expression.
:param expr: Sql expression with '{col}' denoting the locations where the col
object will be rendered.
:param col: the target column
"""
super().__init__(expr, **kwargs)
self.col = col
@compiles(TimestampExpression)
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
return element.name.replace('{col}', compiler.process(element.col, **kw))
def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist): def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
ret_list = [] ret_list = []
blacklist = blacklist if blacklist else [] blacklist = blacklist if blacklist else []
@ -112,7 +131,7 @@ class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations""" """Abstract class for database engine specific configurations"""
engine = 'base' # str as defined in sqlalchemy.engine.engine engine = 'base' # str as defined in sqlalchemy.engine.engine
time_grain_functions: dict = {} time_grain_functions: Dict[Optional[str], str] = {}
time_groupby_inline = False time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False time_secondary_columns = False
@ -125,16 +144,31 @@ class BaseEngineSpec(object):
try_remove_schema_from_table_name = True try_remove_schema_from_table_name = True
@classmethod @classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain): def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
"""
Construct a TimeExpression to be used in a SQLAlchemy query.
:param col: Target column for the TimeExpression
:param pdf: date format (seconds or milliseconds)
:param time_grain: time grain, e.g. P1Y for 1 year
:return: TimestampExpression object
"""
if time_grain:
time_expr = cls.time_grain_functions.get(time_grain)
if not time_expr:
raise NotImplementedError(
f'No grain spec for {time_grain} for database {cls.engine}')
else:
time_expr = '{col}'
# if epoch, translate to DATE using db specific conf # if epoch, translate to DATE using db specific conf
if pdf == 'epoch_s': if pdf == 'epoch_s':
expr = cls.epoch_to_dttm().format(col=expr) time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
elif pdf == 'epoch_ms': elif pdf == 'epoch_ms':
expr = cls.epoch_ms_to_dttm().format(col=expr) time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())
if grain: return TimestampExpression(time_expr, col, type_=DateTime)
expr = grain.function.format(col=expr)
return expr
@classmethod @classmethod
def get_time_grains(cls): def get_time_grains(cls):
@ -489,13 +523,6 @@ class BaseEngineSpec(object):
label = label[:cls.max_column_name_length] label = label[:cls.max_column_name_length]
return label return label
@staticmethod
def get_timestamp_column(expression, column_name):
"""Return the expression if defined, otherwise return column_name. Some
engines require forcing quotes around column name, in which case this method
can be overridden."""
return expression or column_name
class PostgresBaseEngineSpec(BaseEngineSpec): class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """ """ Abstract class for Postgres 'like' databases """
@ -543,16 +570,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
tables.extend(inspector.get_foreign_table_names(schema)) tables.extend(inspector.get_foreign_table_names(schema))
return sorted(tables) return sorted(tables)
@staticmethod
def get_timestamp_column(expression, column_name):
"""Postgres is unable to identify mixed case column names unless they
are quoted."""
if expression:
return expression
elif column_name.lower() != column_name:
return f'"{column_name}"'
return column_name
class SnowflakeEngineSpec(PostgresBaseEngineSpec): class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake' engine = 'snowflake'
@ -794,7 +811,7 @@ class MySQLEngineSpec(BaseEngineSpec):
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))', 'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
} }
type_code_map: dict = {} # loaded from get_datatype only if needed type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod @classmethod
def convert_dttm(cls, target_type, dttm): def convert_dttm(cls, target_type, dttm):
@ -1812,20 +1829,21 @@ class PinotEngineSpec(BaseEngineSpec):
inner_joins = False inner_joins = False
supports_column_aliases = False supports_column_aliases = False
_time_grain_to_datetimeconvert = { # Pinot does its own conversion below
time_grain_functions: Dict[Optional[str], str] = {
'PT1S': '1:SECONDS', 'PT1S': '1:SECONDS',
'PT1M': '1:MINUTES', 'PT1M': '1:MINUTES',
'PT1H': '1:HOURS', 'PT1H': '1:HOURS',
'P1D': '1:DAYS', 'P1D': '1:DAYS',
'P1Y': '1:YEARS', 'P1W': '1:WEEKS',
'P1M': '1:MONTHS', 'P1M': '1:MONTHS',
'P0.25Y': '3:MONTHS',
'P1Y': '1:YEARS',
} }
# Pinot does its own conversion below
time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()}
@classmethod @classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain): def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
is_epoch = pdf in ('epoch_s', 'epoch_ms') is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not is_epoch: if not is_epoch:
raise NotImplementedError('Pinot currently only supports epochs') raise NotImplementedError('Pinot currently only supports epochs')
@ -1834,11 +1852,12 @@ class PinotEngineSpec(BaseEngineSpec):
# We are not really converting any time units, just bucketing them. # We are not really converting any time units, just bucketing them.
seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS' seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
tf = f'1:{seconds_or_ms}:EPOCH' tf = f'1:{seconds_or_ms}:EPOCH'
granularity = cls._time_grain_to_datetimeconvert.get(time_grain) granularity = cls.time_grain_functions.get(time_grain)
if not granularity: if not granularity:
raise NotImplementedError('No pinot grain spec for ' + str(time_grain)) raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
# In pinot the output is a string since there is no timestamp column like pg # In pinot the output is a string since there is no timestamp column like pg
return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")' time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
return TimestampExpression(time_expr, col)
@classmethod @classmethod
def make_select_compatible(cls, groupby_exprs, select_exprs): def make_select_compatible(cls, groupby_exprs, select_exprs):

View File

@ -1029,21 +1029,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
"""Defines time granularity database-specific expressions. """Defines time granularity database-specific expressions.
The idea here is to make it easy for users to change the time grain The idea here is to make it easy for users to change the time grain
form a datetime (maybe the source grain is arbitrary timestamps, daily from a datetime (maybe the source grain is arbitrary timestamps, daily
or 5 minutes increments) to another, "truncated" datetime. Since or 5 minutes increments) to another, "truncated" datetime. Since
each database has slightly different but similar datetime functions, each database has slightly different but similar datetime functions,
this allows a mapping between database engines and actual functions. this allows a mapping between database engines and actual functions.
""" """
return self.db_engine_spec.get_time_grains() return self.db_engine_spec.get_time_grains()
def grains_dict(self):
"""Allowing to lookup grain by either label or duration
For backward compatibility"""
d = {grain.duration: grain for grain in self.grains()}
d.update({grain.label: grain for grain in self.grains()})
return d
def get_extra(self): def get_extra(self):
extra = {} extra = {}
if self.extra: if self.extra:

View File

@ -17,15 +17,16 @@
import inspect import inspect
from unittest import mock from unittest import mock
from sqlalchemy import column, select, table from sqlalchemy import column, literal_column, select, table
from sqlalchemy.dialects.mssql import pymssql from sqlalchemy.dialects import mssql, oracle, postgresql
from sqlalchemy.engine.result import RowProxy from sqlalchemy.engine.result import RowProxy
from sqlalchemy.types import String, UnicodeText from sqlalchemy.types import String, UnicodeText
from superset import db_engine_specs from superset import db_engine_specs
from superset.db_engine_specs import ( from superset.db_engine_specs import (
BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec, BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec, MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
PrestoEngineSpec,
) )
from superset.models.core import Database from superset.models.core import Database
from .base_tests import SupersetTestCase from .base_tests import SupersetTestCase
@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
assert_type('NTEXT', UnicodeText) assert_type('NTEXT', UnicodeText)
def test_mssql_where_clause_n_prefix(self): def test_mssql_where_clause_n_prefix(self):
dialect = pymssql.dialect() dialect = mssql.dialect()
spec = MssqlEngineSpec spec = MssqlEngineSpec
str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)')) str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT')) unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
where(unicode_col == 'abc') where(unicode_col == 'abc')
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True})) query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa query_expected = 'SELECT col, unicode_col \n' \
'FROM tbl \n' \
"WHERE col = 'abc' AND unicode_col = N'abc'"
self.assertEqual(query, query_expected) self.assertEqual(query, query_expected)
def test_get_table_names(self): def test_get_table_names(self):
@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase):
pg_result = db_engine_specs.PostgresEngineSpec.get_table_names( pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
schema='schema', inspector=inspector) schema='schema', inspector=inspector)
self.assertListEqual(pg_result_expected, pg_result) self.assertListEqual(pg_result_expected, pg_result)
def test_pg_time_expression_literal_no_grain(self):
col = literal_column('COALESCE(a, b)')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, 'COALESCE(a, b)')
def test_pg_time_expression_literal_1y_grain(self):
col = literal_column('COALESCE(a, b)')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
def test_pg_time_expression_lower_column_no_grain(self):
col = column('lower_case')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, 'lower_case')
def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
col = column('lower_case')
expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))") # noqa
def test_pg_time_expression_mixed_case_column_1y_grain(self):
col = column('MixedCase')
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=postgresql.dialect()))
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
def test_mssql_time_expression_mixed_case_column_1y_grain(self):
col = column('MixedCase')
expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
result = str(expr.compile(dialect=mssql.dialect()))
self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
col = column('decimal')
expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M')
result = str(expr.compile(dialect=oracle.dialect()))
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
def test_pinot_time_expression_sec_1m_grain(self):
col = column('tstamp')
expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M')
result = str(expr.compile())
self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa

View File

@ -109,47 +109,6 @@ class DatabaseModelTestCase(SupersetTestCase):
LIMIT 100""") LIMIT 100""")
assert sql.startswith(expected) assert sql.startswith(expected)
def test_grains_dict(self):
uri = 'mysql://root@localhost'
database = Database(sqlalchemy_uri=uri)
d = database.grains_dict()
self.assertEquals(d.get('day').function, 'DATE({col})')
self.assertEquals(d.get('P1D').function, 'DATE({col})')
self.assertEquals(d.get('Time Column').function, '{col}')
def test_postgres_expression_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', ''
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', expression)
self.assertEqual(grain_expr, grain_expr_expected)
def test_postgres_lowercase_col_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = '', 'lowercase_col'
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', column_name)
self.assertEqual(grain_expr, grain_expr_expected)
def test_postgres_mixedcase_col_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = '', 'MixedCaseCol'
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
self.assertEqual(grain_expr, grain_expr_expected)
def test_single_statement(self): def test_single_statement(self):
main_db = get_main_database(db.session) main_db = get_main_database(db.session)
@ -217,24 +176,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
ds_col.expression = prev_ds_expr ds_col.expression = prev_ds_expr
def test_get_timestamp_expression_backward(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('day')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('Time Column')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'ds')
def query_with_expr_helper(self, is_timeseries, inner_join=True): def query_with_expr_helper(self, is_timeseries, inner_join=True):
tbl = self.get_table_by_name('birth_names') tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds') ds_col = tbl.get_column('ds')