mirror of https://github.com/apache/superset.git
Make timestamp expression native SQLAlchemy element (#7131)
* Add native sqla component for time expressions * Add unit tests and remove old tests * Remove redundant _grains_dict method * Clarify time_grain logic * Add docstrings and typing * Fix flake8 errors * Add missing typings * Rename to TimestampExpression * Remove redundant tests * Fix broken reference to db.database_name due to refactor
This commit is contained in:
parent
fc3b043462
commit
34407e8962
|
@ -18,6 +18,7 @@
|
|||
from collections import namedtuple, OrderedDict
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import escape, Markup
|
||||
from flask_appbuilder import Model
|
||||
|
@ -32,11 +33,12 @@ from sqlalchemy.exc import CompileError
|
|||
from sqlalchemy.orm import backref, relationship
|
||||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import column, literal_column, table, text
|
||||
from sqlalchemy.sql.expression import TextAsFrom
|
||||
from sqlalchemy.sql.expression import Label, TextAsFrom
|
||||
import sqlparse
|
||||
|
||||
from superset import app, db, security_manager
|
||||
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
|
||||
from superset.db_engine_specs import TimestampExpression
|
||||
from superset.jinja_context import get_template_processor
|
||||
from superset.models.annotations import Annotation
|
||||
from superset.models.core import Database
|
||||
|
@ -140,8 +142,14 @@ class TableColumn(Model, BaseColumn):
|
|||
l.append(col <= text(self.dttm_sql_literal(end_dttm, is_epoch_in_utc)))
|
||||
return and_(*l)
|
||||
|
||||
def get_timestamp_expression(self, time_grain):
|
||||
"""Getting the time component of the query"""
|
||||
def get_timestamp_expression(self, time_grain: Optional[str]) \
|
||||
-> Union[TimestampExpression, Label]:
|
||||
"""
|
||||
Return a SQLAlchemy Core element representation of self to be used in a query.
|
||||
|
||||
:param time_grain: Optional time grain, e.g. P1Y
|
||||
:return: A TimeExpression object wrapped in a Label if supported by db
|
||||
"""
|
||||
label = utils.DTTM_ALIAS
|
||||
|
||||
db = self.table.database
|
||||
|
@ -150,16 +158,12 @@ class TableColumn(Model, BaseColumn):
|
|||
if not self.expression and not time_grain and not is_epoch:
|
||||
sqla_col = column(self.column_name, type_=DateTime)
|
||||
return self.table.make_sqla_column_compatible(sqla_col, label)
|
||||
grain = None
|
||||
if time_grain:
|
||||
grain = db.grains_dict().get(time_grain)
|
||||
if not grain:
|
||||
raise NotImplementedError(
|
||||
f'No grain spec for {time_grain} for database {db.database_name}')
|
||||
col = db.db_engine_spec.get_timestamp_column(self.expression, self.column_name)
|
||||
expr = db.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
|
||||
sqla_col = literal_column(expr, type_=DateTime)
|
||||
return self.table.make_sqla_column_compatible(sqla_col, label)
|
||||
if self.expression:
|
||||
col = literal_column(self.expression)
|
||||
else:
|
||||
col = column(self.column_name)
|
||||
time_expr = db.db_engine_spec.get_timestamp_expr(col, pdf, time_grain)
|
||||
return self.table.make_sqla_column_compatible(time_expr, label)
|
||||
|
||||
@classmethod
|
||||
def import_obj(cls, i_column):
|
||||
|
|
|
@ -36,19 +36,20 @@ import os
|
|||
import re
|
||||
import textwrap
|
||||
import time
|
||||
from typing import List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
from urllib import parse
|
||||
|
||||
from flask import g
|
||||
from flask_babel import lazy_gettext as _
|
||||
import pandas
|
||||
import sqlalchemy as sqla
|
||||
from sqlalchemy import Column, select, types
|
||||
from sqlalchemy import Column, DateTime, select, types
|
||||
from sqlalchemy.engine import create_engine
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
from sqlalchemy.sql.expression import TextAsFrom
|
||||
|
@ -90,6 +91,24 @@ builtin_time_grains = {
|
|||
}
|
||||
|
||||
|
||||
class TimestampExpression(ColumnClause):
|
||||
def __init__(self, expr: str, col: ColumnClause, **kwargs):
|
||||
"""Sqlalchemy class that can be can be used to render native column elements
|
||||
respeting engine-specific quoting rules as part of a string-based expression.
|
||||
|
||||
:param expr: Sql expression with '{col}' denoting the locations where the col
|
||||
object will be rendered.
|
||||
:param col: the target column
|
||||
"""
|
||||
super().__init__(expr, **kwargs)
|
||||
self.col = col
|
||||
|
||||
|
||||
@compiles(TimestampExpression)
|
||||
def compile_timegrain_expression(element: TimestampExpression, compiler, **kw):
|
||||
return element.name.replace('{col}', compiler.process(element.col, **kw))
|
||||
|
||||
|
||||
def _create_time_grains_tuple(time_grains, time_grain_functions, blacklist):
|
||||
ret_list = []
|
||||
blacklist = blacklist if blacklist else []
|
||||
|
@ -112,7 +131,7 @@ class BaseEngineSpec(object):
|
|||
"""Abstract class for database engine specific configurations"""
|
||||
|
||||
engine = 'base' # str as defined in sqlalchemy.engine.engine
|
||||
time_grain_functions: dict = {}
|
||||
time_grain_functions: Dict[Optional[str], str] = {}
|
||||
time_groupby_inline = False
|
||||
limit_method = LimitMethod.FORCE_LIMIT
|
||||
time_secondary_columns = False
|
||||
|
@ -125,16 +144,31 @@ class BaseEngineSpec(object):
|
|||
try_remove_schema_from_table_name = True
|
||||
|
||||
@classmethod
|
||||
def get_time_expr(cls, expr, pdf, time_grain, grain):
|
||||
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
|
||||
time_grain: Optional[str]) -> TimestampExpression:
|
||||
"""
|
||||
Construct a TimeExpression to be used in a SQLAlchemy query.
|
||||
|
||||
:param col: Target column for the TimeExpression
|
||||
:param pdf: date format (seconds or milliseconds)
|
||||
:param time_grain: time grain, e.g. P1Y for 1 year
|
||||
:return: TimestampExpression object
|
||||
"""
|
||||
if time_grain:
|
||||
time_expr = cls.time_grain_functions.get(time_grain)
|
||||
if not time_expr:
|
||||
raise NotImplementedError(
|
||||
f'No grain spec for {time_grain} for database {cls.engine}')
|
||||
else:
|
||||
time_expr = '{col}'
|
||||
|
||||
# if epoch, translate to DATE using db specific conf
|
||||
if pdf == 'epoch_s':
|
||||
expr = cls.epoch_to_dttm().format(col=expr)
|
||||
time_expr = time_expr.replace('{col}', cls.epoch_to_dttm())
|
||||
elif pdf == 'epoch_ms':
|
||||
expr = cls.epoch_ms_to_dttm().format(col=expr)
|
||||
time_expr = time_expr.replace('{col}', cls.epoch_ms_to_dttm())
|
||||
|
||||
if grain:
|
||||
expr = grain.function.format(col=expr)
|
||||
return expr
|
||||
return TimestampExpression(time_expr, col, type_=DateTime)
|
||||
|
||||
@classmethod
|
||||
def get_time_grains(cls):
|
||||
|
@ -489,13 +523,6 @@ class BaseEngineSpec(object):
|
|||
label = label[:cls.max_column_name_length]
|
||||
return label
|
||||
|
||||
@staticmethod
|
||||
def get_timestamp_column(expression, column_name):
|
||||
"""Return the expression if defined, otherwise return column_name. Some
|
||||
engines require forcing quotes around column name, in which case this method
|
||||
can be overridden."""
|
||||
return expression or column_name
|
||||
|
||||
|
||||
class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||
""" Abstract class for Postgres 'like' databases """
|
||||
|
@ -543,16 +570,6 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
|
|||
tables.extend(inspector.get_foreign_table_names(schema))
|
||||
return sorted(tables)
|
||||
|
||||
@staticmethod
|
||||
def get_timestamp_column(expression, column_name):
|
||||
"""Postgres is unable to identify mixed case column names unless they
|
||||
are quoted."""
|
||||
if expression:
|
||||
return expression
|
||||
elif column_name.lower() != column_name:
|
||||
return f'"{column_name}"'
|
||||
return column_name
|
||||
|
||||
|
||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'snowflake'
|
||||
|
@ -794,7 +811,7 @@ class MySQLEngineSpec(BaseEngineSpec):
|
|||
'INTERVAL DAYOFWEEK(DATE_SUB({col}, INTERVAL 1 DAY)) - 1 DAY))',
|
||||
}
|
||||
|
||||
type_code_map: dict = {} # loaded from get_datatype only if needed
|
||||
type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
|
||||
|
||||
@classmethod
|
||||
def convert_dttm(cls, target_type, dttm):
|
||||
|
@ -1812,20 +1829,21 @@ class PinotEngineSpec(BaseEngineSpec):
|
|||
inner_joins = False
|
||||
supports_column_aliases = False
|
||||
|
||||
_time_grain_to_datetimeconvert = {
|
||||
# Pinot does its own conversion below
|
||||
time_grain_functions: Dict[Optional[str], str] = {
|
||||
'PT1S': '1:SECONDS',
|
||||
'PT1M': '1:MINUTES',
|
||||
'PT1H': '1:HOURS',
|
||||
'P1D': '1:DAYS',
|
||||
'P1Y': '1:YEARS',
|
||||
'P1W': '1:WEEKS',
|
||||
'P1M': '1:MONTHS',
|
||||
'P0.25Y': '3:MONTHS',
|
||||
'P1Y': '1:YEARS',
|
||||
}
|
||||
|
||||
# Pinot does its own conversion below
|
||||
time_grain_functions = {k: None for k in _time_grain_to_datetimeconvert.keys()}
|
||||
|
||||
@classmethod
|
||||
def get_time_expr(cls, expr, pdf, time_grain, grain):
|
||||
def get_timestamp_expr(cls, col: ColumnClause, pdf: Optional[str],
|
||||
time_grain: Optional[str]) -> TimestampExpression:
|
||||
is_epoch = pdf in ('epoch_s', 'epoch_ms')
|
||||
if not is_epoch:
|
||||
raise NotImplementedError('Pinot currently only supports epochs')
|
||||
|
@ -1834,11 +1852,12 @@ class PinotEngineSpec(BaseEngineSpec):
|
|||
# We are not really converting any time units, just bucketing them.
|
||||
seconds_or_ms = 'MILLISECONDS' if pdf == 'epoch_ms' else 'SECONDS'
|
||||
tf = f'1:{seconds_or_ms}:EPOCH'
|
||||
granularity = cls._time_grain_to_datetimeconvert.get(time_grain)
|
||||
granularity = cls.time_grain_functions.get(time_grain)
|
||||
if not granularity:
|
||||
raise NotImplementedError('No pinot grain spec for ' + str(time_grain))
|
||||
# In pinot the output is a string since there is no timestamp column like pg
|
||||
return f'DATETIMECONVERT({expr}, "{tf}", "{tf}", "{granularity}")'
|
||||
time_expr = f'DATETIMECONVERT({{col}}, "{tf}", "{tf}", "{granularity}")'
|
||||
return TimestampExpression(time_expr, col)
|
||||
|
||||
@classmethod
|
||||
def make_select_compatible(cls, groupby_exprs, select_exprs):
|
||||
|
|
|
@ -1029,21 +1029,13 @@ class Database(Model, AuditMixinNullable, ImportMixin):
|
|||
"""Defines time granularity database-specific expressions.
|
||||
|
||||
The idea here is to make it easy for users to change the time grain
|
||||
form a datetime (maybe the source grain is arbitrary timestamps, daily
|
||||
from a datetime (maybe the source grain is arbitrary timestamps, daily
|
||||
or 5 minutes increments) to another, "truncated" datetime. Since
|
||||
each database has slightly different but similar datetime functions,
|
||||
this allows a mapping between database engines and actual functions.
|
||||
"""
|
||||
return self.db_engine_spec.get_time_grains()
|
||||
|
||||
def grains_dict(self):
|
||||
"""Allowing to lookup grain by either label or duration
|
||||
|
||||
For backward compatibility"""
|
||||
d = {grain.duration: grain for grain in self.grains()}
|
||||
d.update({grain.label: grain for grain in self.grains()})
|
||||
return d
|
||||
|
||||
def get_extra(self):
|
||||
extra = {}
|
||||
if self.extra:
|
||||
|
|
|
@ -17,15 +17,16 @@
|
|||
import inspect
|
||||
from unittest import mock
|
||||
|
||||
from sqlalchemy import column, select, table
|
||||
from sqlalchemy.dialects.mssql import pymssql
|
||||
from sqlalchemy import column, literal_column, select, table
|
||||
from sqlalchemy.dialects import mssql, oracle, postgresql
|
||||
from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset import db_engine_specs
|
||||
from superset.db_engine_specs import (
|
||||
BaseEngineSpec, BQEngineSpec, HiveEngineSpec, MssqlEngineSpec,
|
||||
MySQLEngineSpec, OracleEngineSpec, PrestoEngineSpec,
|
||||
MySQLEngineSpec, OracleEngineSpec, PinotEngineSpec, PostgresEngineSpec,
|
||||
PrestoEngineSpec,
|
||||
)
|
||||
from superset.models.core import Database
|
||||
from .base_tests import SupersetTestCase
|
||||
|
@ -451,7 +452,7 @@ class DbEngineSpecsTestCase(SupersetTestCase):
|
|||
assert_type('NTEXT', UnicodeText)
|
||||
|
||||
def test_mssql_where_clause_n_prefix(self):
|
||||
dialect = pymssql.dialect()
|
||||
dialect = mssql.dialect()
|
||||
spec = MssqlEngineSpec
|
||||
str_col = column('col', type_=spec.get_sqla_column_type('VARCHAR(10)'))
|
||||
unicode_col = column('unicode_col', type_=spec.get_sqla_column_type('NTEXT'))
|
||||
|
@ -462,7 +463,9 @@ class DbEngineSpecsTestCase(SupersetTestCase):
|
|||
where(unicode_col == 'abc')
|
||||
|
||||
query = str(sel.compile(dialect=dialect, compile_kwargs={'literal_binds': True}))
|
||||
query_expected = "SELECT col, unicode_col \nFROM tbl \nWHERE col = 'abc' AND unicode_col = N'abc'" # noqa
|
||||
query_expected = 'SELECT col, unicode_col \n' \
|
||||
'FROM tbl \n' \
|
||||
"WHERE col = 'abc' AND unicode_col = N'abc'"
|
||||
self.assertEqual(query, query_expected)
|
||||
|
||||
def test_get_table_names(self):
|
||||
|
@ -483,3 +486,51 @@ class DbEngineSpecsTestCase(SupersetTestCase):
|
|||
pg_result = db_engine_specs.PostgresEngineSpec.get_table_names(
|
||||
schema='schema', inspector=inspector)
|
||||
self.assertListEqual(pg_result_expected, pg_result)
|
||||
|
||||
def test_pg_time_expression_literal_no_grain(self):
|
||||
col = literal_column('COALESCE(a, b)')
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
|
||||
result = str(expr.compile(dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, 'COALESCE(a, b)')
|
||||
|
||||
def test_pg_time_expression_literal_1y_grain(self):
|
||||
col = literal_column('COALESCE(a, b)')
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
|
||||
result = str(expr.compile(dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', COALESCE(a, b))")
|
||||
|
||||
def test_pg_time_expression_lower_column_no_grain(self):
|
||||
col = column('lower_case')
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, None)
|
||||
result = str(expr.compile(dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, 'lower_case')
|
||||
|
||||
def test_pg_time_expression_lower_case_column_sec_1y_grain(self):
|
||||
col = column('lower_case')
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1Y')
|
||||
result = str(expr.compile(dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', (timestamp 'epoch' + lower_case * interval '1 second'))") # noqa
|
||||
|
||||
def test_pg_time_expression_mixed_case_column_1y_grain(self):
|
||||
col = column('MixedCase')
|
||||
expr = PostgresEngineSpec.get_timestamp_expr(col, None, 'P1Y')
|
||||
result = str(expr.compile(dialect=postgresql.dialect()))
|
||||
self.assertEqual(result, "DATE_TRUNC('year', \"MixedCase\")")
|
||||
|
||||
def test_mssql_time_expression_mixed_case_column_1y_grain(self):
|
||||
col = column('MixedCase')
|
||||
expr = MssqlEngineSpec.get_timestamp_expr(col, None, 'P1Y')
|
||||
result = str(expr.compile(dialect=mssql.dialect()))
|
||||
self.assertEqual(result, 'DATEADD(year, DATEDIFF(year, 0, [MixedCase]), 0)')
|
||||
|
||||
def test_oracle_time_expression_reserved_keyword_1m_grain(self):
|
||||
col = column('decimal')
|
||||
expr = OracleEngineSpec.get_timestamp_expr(col, None, 'P1M')
|
||||
result = str(expr.compile(dialect=oracle.dialect()))
|
||||
self.assertEqual(result, "TRUNC(CAST(\"decimal\" as DATE), 'MONTH')")
|
||||
|
||||
def test_pinot_time_expression_sec_1m_grain(self):
|
||||
col = column('tstamp')
|
||||
expr = PinotEngineSpec.get_timestamp_expr(col, 'epoch_s', 'P1M')
|
||||
result = str(expr.compile())
|
||||
self.assertEqual(result, 'DATETIMECONVERT(tstamp, "1:SECONDS:EPOCH", "1:SECONDS:EPOCH", "1:MONTHS")') # noqa
|
||||
|
|
|
@ -109,47 +109,6 @@ class DatabaseModelTestCase(SupersetTestCase):
|
|||
LIMIT 100""")
|
||||
assert sql.startswith(expected)
|
||||
|
||||
def test_grains_dict(self):
|
||||
uri = 'mysql://root@localhost'
|
||||
database = Database(sqlalchemy_uri=uri)
|
||||
d = database.grains_dict()
|
||||
self.assertEquals(d.get('day').function, 'DATE({col})')
|
||||
self.assertEquals(d.get('P1D').function, 'DATE({col})')
|
||||
self.assertEquals(d.get('Time Column').function, '{col}')
|
||||
|
||||
def test_postgres_expression_time_grain(self):
|
||||
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
|
||||
database = Database(sqlalchemy_uri=uri)
|
||||
pdf, time_grain = '', 'P1D'
|
||||
expression, column_name = 'COALESCE(lowercase_col, "MixedCaseCol")', ''
|
||||
grain = database.grains_dict().get(time_grain)
|
||||
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
|
||||
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
|
||||
grain_expr_expected = grain.function.replace('{col}', expression)
|
||||
self.assertEqual(grain_expr, grain_expr_expected)
|
||||
|
||||
def test_postgres_lowercase_col_time_grain(self):
|
||||
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
|
||||
database = Database(sqlalchemy_uri=uri)
|
||||
pdf, time_grain = '', 'P1D'
|
||||
expression, column_name = '', 'lowercase_col'
|
||||
grain = database.grains_dict().get(time_grain)
|
||||
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
|
||||
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
|
||||
grain_expr_expected = grain.function.replace('{col}', column_name)
|
||||
self.assertEqual(grain_expr, grain_expr_expected)
|
||||
|
||||
def test_postgres_mixedcase_col_time_grain(self):
|
||||
uri = 'postgresql+psycopg2://uid:pwd@localhost:5432/superset'
|
||||
database = Database(sqlalchemy_uri=uri)
|
||||
pdf, time_grain = '', 'P1D'
|
||||
expression, column_name = '', 'MixedCaseCol'
|
||||
grain = database.grains_dict().get(time_grain)
|
||||
col = database.db_engine_spec.get_timestamp_column(expression, column_name)
|
||||
grain_expr = database.db_engine_spec.get_time_expr(col, pdf, time_grain, grain)
|
||||
grain_expr_expected = grain.function.replace('{col}', f'"{column_name}"')
|
||||
self.assertEqual(grain_expr, grain_expr_expected)
|
||||
|
||||
def test_single_statement(self):
|
||||
main_db = get_main_database(db.session)
|
||||
|
||||
|
@ -217,24 +176,6 @@ class SqlaTableModelTestCase(SupersetTestCase):
|
|||
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
|
||||
ds_col.expression = prev_ds_expr
|
||||
|
||||
def test_get_timestamp_expression_backward(self):
|
||||
tbl = self.get_table_by_name('birth_names')
|
||||
ds_col = tbl.get_column('ds')
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = None
|
||||
sqla_literal = ds_col.get_timestamp_expression('day')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'DATE(ds)')
|
||||
|
||||
ds_col.expression = None
|
||||
ds_col.python_date_format = None
|
||||
sqla_literal = ds_col.get_timestamp_expression('Time Column')
|
||||
compiled = '{}'.format(sqla_literal.compile())
|
||||
if tbl.database.backend == 'mysql':
|
||||
self.assertEquals(compiled, 'ds')
|
||||
|
||||
def query_with_expr_helper(self, is_timeseries, inner_join=True):
|
||||
tbl = self.get_table_by_name('birth_names')
|
||||
ds_col = tbl.get_column('ds')
|
||||
|
|
Loading…
Reference in New Issue