Make timestamp expression native SQLAlchemy element (#7131)

* Add native sqla component for time expressions

* Add unit tests and remove old tests

* Remove redundant _grains_dict method

* Clarify time_grain logic

* Add docstrings and typing

* Fix flake8 errors

* Add missing typings

* Rename to TimestampExpression

* Remove redundant tests

* Fix broken reference to db.database_name due to refactor
This commit is contained in:
Ville Brofeldt 2019-05-30 08:28:37 +03:00 committed by GitHub
parent fc3b043462
commit 34407e8962
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 128 additions and 121 deletions

View File

@ -18,6 +18,7 @@
from collections import namedtuple, OrderedDict
from datetime import datetime
import logging
from typing import Optional, Union
from flask import escape, Markup
from flask_appbuilder import Model
@ -32,11 +33,12 @@ from sqlalchemy.exc import CompileError
from sqlalchemy.orm import backref, relationship
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import column, literal_column, table, text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.sql.expression import Label, TextAsFrom
import sqlparse
from superset import app, db, security_manager
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.db_engine_specs import TimestampExpression
from superset.jinja_context import get_template_processor
from superset.models.annotations import Annotation
from superset.models.core import Database
@ -140,8 +142,14 @@ class TableColumn(Model, BaseColumn):
l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc)))
return and_(*l)
def get_timestamp_expression(self, time_grain):
"""Getting the time component of the query"""
def get_timestamp_expression(self, time_grain: Optional[str]) \
-> Union[TimestampExpression, Label]:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
:param time_grain: Optional time grain, e.g. P1Y
:return: A TimestampExpression object wrapped in a Label if supported by db
"""
label = utils.DTTM_ALIAS
db = self.table.database
@ -150,16 +158,12 @@ class TableColumn(Model, BaseColumn):
if not self.expression and not time_grain and not is_epoch:
sqla_col = column(self.column_name, type_=DateTime)
return self.table.make_sqla_column_compatible(sqla_col, label)
grain = None
if time_grain:
grain = db.grains_dict().get(time_grain)
if not grain:
raise NotImplementedError(
f'No grain spec for {time_grain} for database {db.database_name}')
col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name)
expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
sqla_col = literal_column(expr, type_=DateTime)
return self.table.make_sqla_column_compatible(sqla_col, label)
if self.expression:
col = literal_column(self.expression)
else:
col = column(self.column_name)
time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
return self.table.make_sqla_column_compatible(time_expr, label)
@classmethod
def import_obj(cls, i_column):

View File

@ -36,19 +36,20 @@ import os
import re
import textwrap
import time
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple
from urllib import parse
from flask import g
from flask_babel import lazy_gettext as _
import pandas
import sqlalchemy as sqla
from sqlalchemy import Column, select, types
from sqlalchemy import Column, DateTime, select, types
from sqlalchemy.engine import create_engine
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause
from sqlalchemy.sql.expression import TextAsFrom
@ -90,6 +91,24 @@ builtin_time_grains = {
}
class TimestampExpression(ColumnClause):
    def __init__(self, expr: str, col: ColumnClause, **kwargs):
        """SQLAlchemy class that can be used to render native column elements
        respecting engine-specific quoting rules as part of a string-based
        expression.

        The target column is stored on ``self.col`` as a SQLAlchemy element
        instead of being formatted into the expression string here.

        :param expr: SQL expression with '{col}' denoting the locations where
            the col object will be rendered.
        :param col: the target column
        :param kwargs: forwarded unchanged to the parent constructor
        """
        super().__init__(expr, **kwargs)
        self.col = col
@compiles(TimestampExpression)
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
    """Compile a TimestampExpression into its SQL string.

    The wrapped column is rendered by the active compiler and the result is
    substituted for every '{col}' placeholder in the expression held in
    ``element.name``.
    """
    rendered_col = compiler.process(element.col, **kw)
    return element.name.replace('{col}', rendered_col)
def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
ret_list = []
blacklist = blacklist if blacklist else []
@ -112,7 +131,7 @@ class BaseEngineSpec(object):
"""Abstract class for database engine specific configurations"""
engine = 'base' # str as defined in sqlalchemy.engine.engine
time_grain_functions: dict = {}
time_grain_functions: Dict[Optional[str], str] = {}
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
@ -125,16 +144,31 @@ class BaseEngineSpec(object):
try_remove_schema_from_table_name = True
@classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain):
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
"""
Construct a TimestampExpression to be used in a SQLAlchemy query.
:param col: Target column for the TimestampExpression
:param pdf: date format (seconds or milliseconds)
:param time_grain: time grain, e.g. P1Y for 1 year
:return: TimestampExpression object
"""
if time_grain:
time_expr = cls.time_grain_functions.get(time_grain)
if not time_expr:
raise NotImplementedError(
f'No grain spec for {time_grain} for database {cls.engine}')
else:
time_expr = '{col}'
# if epoch, translate to DATE using db specific conf
if pdf == 'epoch_s':
expr = cls.epoch_to_dttm().format(col=expr)
time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
elif pdf == 'epoch_ms':
expr = cls.epoch_ms_to_dttm().format(col=expr)
time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())
if grain:
expr = grain.function.format(col=expr)
return expr
return TimestampExpression(time_expr, col, type_=DateTime)
@classmethod
def get_time_grains(cls):
@ -489,13 +523,6 @@ class BaseEngineSpec(object):
label = label[:cls.max_column_name_length]
return label
@staticmethod
def get_timestamp_column(expression, column_name):
"""Return the expression if defined, otherwise return column_name. Some
engines require forcing quotes around column name, in which case this method
can be overridden."""
return expression or column_name
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@ -543,16 +570,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
tables.extend(inspector.get_foreign_table_names(schema))
return sorted(tables)
@staticmethod
def get_timestamp_column(expression, column_name):
"""Postgres is unable to identify mixed case column names unless they
are quoted."""
if expression:
return expression
elif column_name.lower() != column_name:
return f'"{column_name}"'
return column_name
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
@ -794,7 +811,7 @@ class MySQLEngineSpec(BaseEngineSpec):
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
}
type_code_map: dict = {} # loaded from get_datatype only if needed
type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
def convert_dttm(cls, target_type, dttm):
@ -1812,20 +1829,21 @@ class PinotEngineSpec(BaseEngineSpec):
inner_joins = False
supports_column_aliases = False
_time_grain_to_datetimeconvert = {
# Pinot does its own conversion below
time_grain_functions: Dict[Optional[str], str] = {
'PT1S': '1:SECONDS',
'PT1M': '1:MINUTES',
'PT1H': '1:HOURS',
'P1D': '1:DAYS',
'P1Y': '1:YEARS',
'P1W': '1:WEEKS',
'P1M': '1:MONTHS',
'P0.25Y': '3:MONTHS',
'P1Y': '1:YEARS',
}
# Pinot does its own conversion below
time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()}
@classmethod
def get_time_expr(cls, expr, pdf, time_grain, grain):
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
time_grain: Optional[str]) -> TimestampExpression:
is_epoch = pdf in ('epoch_s', 'epoch_ms')
if not is_epoch:
raise NotImplementedError('Pinot currently only supports epochs')
@ -1834,11 +1852,12 @@ class PinotEngineSpec(BaseEngineSpec):
# We are not really converting any time units, just bucketing them.
seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
tf = f'1:{seconds_or_ms}:EPOCH'
granularity = cls._time_grain_to_datetimeconvert.get(time_grain)
granularity = cls.time_grain_functions.get(time_grain)
if not granularity:
raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
# In pinot the output is a string since there is no timestamp column like pg
return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")'
time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
return TimestampExpression(time_expr, col)
@classmethod
def make_select_compatible(cls, groupby_exprs, select_exprs):

View File

@ -1029,21 +1029,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
"""Defines time granularity database-specific expressions.
The idea here is to make it easy for users to change the time grain
form a datetime (maybe the source grain is arbitrary timestamps, daily
from a datetime (maybe the source grain is arbitrary timestamps, daily
or 5 minutes increments) to another, "truncated" datetime. Since
each database has slightly different but similar datetime functions,
this allows a mapping between database engines and actual functions.
"""
return self.db_engine_spec.get_time_grains()
def grains_dict(self):
"""Allowing to lookup grain by either label or duration
For backward compatibility"""
d = {grain.duration: grain for grain in self.grains()}
d.update({grain.label: grain for grain in self.grains()})
return d
def get_extra(self):
extra = {}
if self.extra:

View File

@ -17,15 +17,16 @@
import inspect
from unittest import mock
from sqlalchemy import column, select, table
from sqlalchemy.dialects.mssql import pymssql
from sqlalchemy import column, literal_column, select, table
from sqlalchemy.dialects import mssql, oracle, postgresql
from sqlalchemy.engine.result import RowProxy
from sqlalchemy.types import String, UnicodeText
from superset import db_engine_specs
from superset.db_engine_specs import (
BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
PrestoEngineSpec,
)
from superset.models.core import Database
from .base_tests import SupersetTestCase
@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
assert_type('NTEXT', UnicodeText)
def test_mssql_where_clause_n_prefix(self):
dialect = pymssql.dialect()
dialect = mssql.dialect()
spec = MssqlEngineSpec
str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
where(unicode_col == 'abc')
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa
query_expected = 'SELECT col, unicode_col \n' \
'FROM tbl \n' \
"WHERE col = 'abc' AND unicode_col = N'abc'"
self.assertEqual(query, query_expected)
def test_get_table_names(self):
@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase):
pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
schema='schema', inspector=inspector)
self.assertListEqual(pg_result_expected, pg_result)
def test_pg_time_expression_literal_no_grain(self):
    # Literal SQL expression, no epoch format and no grain: the compiled
    # output is expected to be the literal itself.
    literal = literal_column('COALESCE(a, b)')
    time_expr = PostgresEngineSpec.get_timestamp_expr(literal, None, None)
    compiled = time_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), 'COALESCE(a, b)')
def test_pg_time_expression_literal_1y_grain(self):
    # Literal SQL expression with a one-year grain (P1Y): the literal is
    # expected to appear inside the DATE_TRUNC call.
    literal = literal_column('COALESCE(a, b)')
    time_expr = PostgresEngineSpec.get_timestamp_expr(literal, None, 'P1Y')
    compiled = time_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "DATE_TRUNC('year', COALESCE(a, b))")
def test_pg_time_expression_lower_column_no_grain(self):
    # Lower-case column name without a grain: expected to compile to the
    # bare, unquoted column name.
    plain_col = column('lower_case')
    time_expr = PostgresEngineSpec.get_timestamp_expr(plain_col, None, None)
    compiled = time_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), 'lower_case')
def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
    # Epoch-seconds column with a one-year grain: the expected SQL has the
    # epoch conversion nested inside the DATE_TRUNC call.
    plain_col = column('lower_case')
    time_expr = PostgresEngineSpec.get_timestamp_expr(plain_col, 'epoch_s', 'P1Y')
    compiled = time_expr.compile(dialect=postgresql.dialect())
    expected = (
        "DATE_TRUNC('year', "
        "(timestamp 'epoch' + lower_case * interval '1 second'))"
    )
    self.assertEqual(str(compiled), expected)
def test_pg_time_expression_mixed_case_column_1y_grain(self):
    # Mixed-case column name on the Postgres dialect: the expected SQL has
    # the name wrapped in double quotes.
    mixed_col = column('MixedCase')
    time_expr = PostgresEngineSpec.get_timestamp_expr(mixed_col, None, 'P1Y')
    compiled = time_expr.compile(dialect=postgresql.dialect())
    self.assertEqual(str(compiled), "DATE_TRUNC('year', \"MixedCase\")")
def test_mssql_time_expression_mixed_case_column_1y_grain(self):
    # Mixed-case column name on the MSSQL dialect: the expected SQL has the
    # name wrapped in square brackets.
    mixed_col = column('MixedCase')
    time_expr = MssqlEngineSpec.get_timestamp_expr(mixed_col, None, 'P1Y')
    compiled = time_expr.compile(dialect=mssql.dialect())
    self.assertEqual(
        str(compiled),
        'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)',
    )
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
    # Column named 'decimal' (a reserved keyword, per the test name) on the
    # Oracle dialect: the expected SQL has the name double-quoted.
    keyword_col = column('decimal')
    time_expr = OracleEngineSpec.get_timestamp_expr(keyword_col, None, 'P1M')
    compiled = time_expr.compile(dialect=oracle.dialect())
    self.assertEqual(str(compiled), "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
def test_pinot_time_expression_sec_1m_grain(self):
    # Epoch-seconds column with a one-month grain on Pinot, compiled without
    # an explicit dialect: expected to yield a DATETIMECONVERT call.
    epoch_col = column('tstamp')
    time_expr = PinotEngineSpec.get_timestamp_expr(epoch_col, 'epoch_s', 'P1M')
    compiled = time_expr.compile()
    expected = (
        'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", '
        '"1:SECONDS:EPOCH", "1:MONTHS")'
    )
    self.assertEqual(str(compiled), expected)

View File

@ -109,47 +109,6 @@ class DatabaseModelTestCase(SupersetTestCase):
LIMIT 100""")
assert sql.startswith(expected)
def test_grains_dict(self):
uri = 'mysql://root@localhost'
database = Database(sqlalchemy_uri=uri)
d = database.grains_dict()
self.assertEquals(d.get('day').function, 'DATE({col})')
self.assertEquals(d.get('P1D').function, 'DATE({col})')
self.assertEquals(d.get('Time Column').function, '{col}')
def test_postgres_expression_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', ''
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', expression)
self.assertEqual(grain_expr, grain_expr_expected)
def test_postgres_lowercase_col_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = '', 'lowercase_col'
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', column_name)
self.assertEqual(grain_expr, grain_expr_expected)
def test_postgres_mixedcase_col_time_grain(self):
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
database = Database(sqlalchemy_uri=uri)
pdf, time_grain = '', 'P1D'
expression, column_name = '', 'MixedCaseCol'
grain = database.grains_dict().get(time_grain)
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
self.assertEqual(grain_expr, grain_expr_expected)
def test_single_statement(self):
main_db = get_main_database(db.session)
@ -217,24 +176,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
ds_col.expression = prev_ds_expr
def test_get_timestamp_expression_backward(self):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('day')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')
ds_col.expression = None
ds_col.python_date_format = None
sqla_literal = ds_col.get_timestamp_expression('Time Column')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'ds')
def query_with_expr_helper(self, is_timeseries, inner_join=True):
tbl = self.get_table_by_name('birth_names')
ds_col = tbl.get_column('ds')