From b08e21efd906d13994414b39bfa7f6e98466d4cb Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 3 Jun 2022 15:27:35 -0700 Subject: [PATCH] [requirements] Resolve rebase conflicts (#20152) Co-authored-by: John Bodley --- requirements/development.in | 1 - requirements/development.txt | 8 +- setup.py | 5 +- superset/db_engine_specs/trino.py | 136 +----------------- superset/jinja_context.py | 19 ++- .../db_engine_specs/trino_tests.py | 19 --- tests/integration_tests/test_jinja_context.py | 17 +++ .../unit_tests/db_engine_specs/test_trino.py | 56 -------- 8 files changed, 44 insertions(+), 217 deletions(-) delete mode 100644 tests/unit_tests/db_engine_specs/test_trino.py diff --git a/requirements/development.in b/requirements/development.in index 477fff3376..2baae3faeb 100644 --- a/requirements/development.in +++ b/requirements/development.in @@ -18,7 +18,6 @@ # -r base.in -e .[cors,druid,hive,mysql,postgres,thumbnails] -flask-cors>=2.0.0 ipython progress>=1.5,<2 pyinstrument>=4.0.2,<5 diff --git a/requirements/development.txt b/requirements/development.txt index 1beebbc9a4..75af963cf8 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -1,4 +1,4 @@ -# SHA1:2bd0d7307aeb633b7d97b510eb467285210e783a +# SHA1:80db36131ba9a8df7c34810cee7788f03cfb68b8 # # This file is autogenerated by pip-compile-multi # To update, run: @@ -35,9 +35,7 @@ et-xmlfile==1.1.0 executing==0.8.3 # via stack-data flask-cors==3.0.10 - # via - # -r requirements/development.in - # apache-superset + # via apache-superset future==0.18.2 # via pyhive ijson==3.1.4 @@ -82,7 +80,7 @@ pydruid==0.6.2 # via apache-superset pygments==2.12.0 # via ipython -pyhive[hive]==0.6.4 +pyhive[hive]==0.6.5 # via apache-superset pyinstrument==4.0.2 # via -r requirements/development.in diff --git a/setup.py b/setup.py index 4d1fd3ab43..53b322b200 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,6 @@ setup( zip_safe=False, entry_points={ "console_scripts": ["superset=superset.cli.main:superset"], - "sqlalchemy.dialects": ["trinonative = trino.sqlalchemy.dialect:TrinoDialect"], }, install_requires=[ "backoff>=1.8.0", @@ -142,7 +141,7 @@ setup( "firebolt": ["firebolt-sqlalchemy>=0.0.1"], "gsheets": ["shillelagh[gsheetsapi]>=1.0.14, <2"], "hana": ["hdbcli==2.4.162", "sqlalchemy_hana==0.4.0"], - "hive": ["pyhive[hive]>=0.6.1", "tableschema", "thrift>=0.11.0, <1.0.0"], + "hive": ["pyhive[hive]>=0.6.5", "tableschema", "thrift>=0.11.0, <1.0.0"], "impala": ["impyla>0.16.2, <0.17"], "kusto": ["sqlalchemy-kusto>=1.0.1, <2"], "kylin": ["kylinpy>=2.8.1, <2.9"], @@ -151,7 +150,7 @@ setup( "oracle": ["cx-Oracle>8.0.0, <8.1"], "pinot": ["pinotdb>=0.3.3, <0.4"], "postgres": ["psycopg2-binary==2.9.1"], - "presto": ["pyhive[presto]>=0.4.0"], + "presto": ["pyhive[presto]>=0.6.5"], "trino": ["trino>=0.313.0"], "prophet": ["prophet>=1.0.1, <1.1", "pystan<3.0"], "redshift": ["sqlalchemy-redshift>=0.8.1, < 0.9"], diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 9941c57d98..46e3ed55de 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from datetime import datetime -from typing import Any, Dict, List, Optional, TYPE_CHECKING -from urllib import parse +from typing import Any, Dict, Optional, TYPE_CHECKING import simplejson as json from flask import current_app @@ -25,6 +23,7 @@ from sqlalchemy.engine.url import URL from superset.databases.utils import make_url_safe from superset.db_engine_specs.base import BaseEngineSpec +from superset.db_engine_specs.presto import PrestoEngineSpec from superset.utils import core as utils if TYPE_CHECKING: @@ -33,66 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class TrinoEngineSpec(BaseEngineSpec): +class TrinoEngineSpec(PrestoEngineSpec): engine = "trino" - engine_aliases = {"trinonative"} + engine_aliases = {"trinonative"} # Required for backwards compatibility. engine_name = "Trino" - _time_grain_expressions = { - None: "{col}", - "PT1S": "date_trunc('second', CAST({col} AS TIMESTAMP))", - "PT1M": "date_trunc('minute', CAST({col} AS TIMESTAMP))", - "PT1H": "date_trunc('hour', CAST({col} AS TIMESTAMP))", - "P1D": "date_trunc('day', CAST({col} AS TIMESTAMP))", - "P1W": "date_trunc('week', CAST({col} AS TIMESTAMP))", - "P1M": "date_trunc('month', CAST({col} AS TIMESTAMP))", - "P3M": "date_trunc('quarter', CAST({col} AS TIMESTAMP))", - "P1Y": "date_trunc('year', CAST({col} AS TIMESTAMP))", - # "1969-12-28T00:00:00Z/P1W", # Week starting Sunday - # "1969-12-29T00:00:00Z/P1W", # Week starting Monday - # "P1W/1970-01-03T00:00:00Z", # Week ending Saturday - # "P1W/1970-01-04T00:00:00Z", # Week ending Sunday - } - - @classmethod - def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: - """ - Convert a Python `datetime` object to a SQL expression. - - :param target_type: The target type of expression - :param dttm: The datetime object - :param db_extra: The database extra object - :return: The SQL expression - - Superset only defines time zone naive `datetime` objects, though this method - handles both time zone naive and aware conversions. - """ - tt = target_type.upper() - if tt == utils.TemporalType.DATE: - return f"DATE '{dttm.date().isoformat()}'" - if tt in ( - utils.TemporalType.TIMESTAMP, - utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE, - ): - return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds", sep=" ")}'""" - return None - - @classmethod - def epoch_to_dttm(cls) -> str: - return "from_unixtime({col})" - - @classmethod - def adjust_database_uri( - cls, uri: URL, selected_schema: Optional[str] = None - ) -> None: - database = uri.database - if selected_schema and database: - selected_schema = parse.quote(selected_schema, safe="") - database = database.split("/")[0] + "/" + selected_schema - uri.database = database - @classmethod def update_impersonation_config( cls, @@ -133,78 +77,6 @@ class TrinoEngineSpec(BaseEngineSpec): def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: return True - @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: - """ - Run a SQL query that estimates the cost of a given statement. - - :param statement: A single SQL statement - :param cursor: Cursor instance - :return: JSON response from Trino - """ - sql = f"EXPLAIN (TYPE IO, FORMAT JSON) {statement}" - cursor.execute(sql) - - # the output from Trino is a single column and a single row containing - # JSON: - # - # { - # ... - # "estimate" : { - # "outputRowCount" : 8.73265878E8, - # "outputSizeInBytes" : 3.41425774958E11, - # "cpuCost" : 3.41425774958E11, - # "maxMemory" : 0.0, - # "networkCost" : 3.41425774958E11 - # } - # } - result = json.loads(cursor.fetchone()[0]) - return result - - @classmethod - def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: - """ - Format cost estimate. - - :param raw_cost: JSON estimate from Trino - :return: Human readable cost estimate - """ - - def humanize(value: Any, suffix: str) -> str: - try: - value = int(value) - except ValueError: - return str(value) - - prefixes = ["K", "M", "G", "T", "P", "E", "Z", "Y"] - prefix = "" - to_next_prefix = 1000 - while value > to_next_prefix and prefixes: - prefix = prefixes.pop(0) - value //= to_next_prefix - - return f"{value} {prefix}{suffix}" - - cost = [] - columns = [ - ("outputRowCount", "Output count", " rows"), - ("outputSizeInBytes", "Output size", "B"), - ("cpuCost", "CPU cost", ""), - ("maxMemory", "Max memory", "B"), - ("networkCost", "Network cost", ""), - ] - for row in raw_cost: - estimate: Dict[str, float] = row.get("estimate", {}) - statement_cost = {} - for key, label, suffix in columns: - if key in estimate: - statement_cost[label] = humanize(estimate[key], suffix).strip() - cost.append(statement_cost) - - return cost - @staticmethod def get_extra_params(database: "Database") -> Dict[str, Any]: """ diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 42f6809c74..4ee250673f 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -577,7 +577,24 @@ class HiveTemplateProcessor(PrestoTemplateProcessor): engine = "hive" -DEFAULT_PROCESSORS = {"presto": PrestoTemplateProcessor, "hive": HiveTemplateProcessor} +class TrinoTemplateProcessor(PrestoTemplateProcessor): + engine = "trino" + + def process_template(self, sql: str, **kwargs: Any) -> str: + template = self._env.from_string(sql) + kwargs.update(self._context) + + # Backwards compatibility if migrating from Presto. + context = validate_template_context(self.engine, kwargs) + context["presto"] = context["trino"] + return template.render(context) + + +DEFAULT_PROCESSORS = { + "presto": PrestoTemplateProcessor, + "hive": HiveTemplateProcessor, + "trino": TrinoTemplateProcessor, +} @memoized diff --git a/tests/integration_tests/db_engine_specs/trino_tests.py b/tests/integration_tests/db_engine_specs/trino_tests.py index b2235ae0c7..fc83b8c64c 100644 --- a/tests/integration_tests/db_engine_specs/trino_tests.py +++ b/tests/integration_tests/db_engine_specs/trino_tests.py @@ -27,25 +27,6 @@ from tests.integration_tests.db_engine_specs.base_tests import TestDbEngineSpec class TestTrinoDbEngineSpec(TestDbEngineSpec): - def test_adjust_database_uri(self): - url = URL(drivername="trino", database="hive") - TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") - self.assertEqual(url.database, "hive/foobar") - - def test_adjust_database_uri_when_database_contain_schema(self): - url = URL(drivername="trino", database="hive/default") - TrinoEngineSpec.adjust_database_uri(url, selected_schema="foobar") - self.assertEqual(url.database, "hive/foobar") - - def test_adjust_database_uri_when_selected_schema_is_none(self): - url = URL(drivername="trino", database="hive") - TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) - self.assertEqual(url.database, "hive") - - url.database = "hive/default" - TrinoEngineSpec.adjust_database_uri(url, selected_schema=None) - self.assertEqual(url.database, "hive/default") - def test_get_extra_params(self): database = Mock() diff --git a/tests/integration_tests/test_jinja_context.py b/tests/integration_tests/test_jinja_context.py index 879881a299..8c2db6920d 100644 --- a/tests/integration_tests/test_jinja_context.py +++ b/tests/integration_tests/test_jinja_context.py @@ -121,6 +121,23 @@ def test_template_hive(app_context: AppContext, mocker: MockFixture) -> None: assert tp.process_template(template) == "the_latest" +def test_template_trino(app_context: AppContext, mocker: MockFixture) -> None: + lp_mock = mocker.patch( + "superset.jinja_context.TrinoTemplateProcessor.latest_partition" + ) + lp_mock.return_value = "the_latest" + db = mock.Mock() + db.backend = "trino" + template = "{{ trino.latest_partition('my_table') }}" + tp = get_template_processor(database=db) + assert tp.process_template(template) == "the_latest" + + # Backwards compatibility if migrating from Presto. + template = "{{ presto.latest_partition('my_table') }}" + tp = get_template_processor(database=db) + assert tp.process_template(template) == "the_latest" + + def test_template_context_addons(app_context: AppContext, mocker: MockFixture) -> None: addons_mock = mocker.patch("superset.jinja_context.context_addons") addons_mock.return_value = {"datetime": datetime} diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py deleted file mode 100644 index 692fe875da..0000000000 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ /dev/null @@ -1,56 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from datetime import datetime -from typing import Optional - -import pytest -import pytz -from flask.ctx import AppContext - - -@pytest.mark.parametrize( - "target_type,dttm,result", - [ - ("VARCHAR", datetime(2022, 1, 1), None), - ("DATE", datetime(2022, 1, 1), "DATE '2022-01-01'"), - ( - "TIMESTAMP", - datetime(2022, 1, 1, 1, 23, 45, 600000), - "TIMESTAMP '2022-01-01 01:23:45.600000'", - ), - ( - "TIMESTAMP WITH TIME ZONE", - datetime(2022, 1, 1, 1, 23, 45, 600000), - "TIMESTAMP '2022-01-01 01:23:45.600000'", - ), - ( - "TIMESTAMP WITH TIME ZONE", - datetime(2022, 1, 1, 1, 23, 45, 600000, tzinfo=pytz.UTC), - "TIMESTAMP '2022-01-01 01:23:45.600000+00:00'", - ), - ], -) -def test_convert_dttm( - app_context: AppContext, - target_type: str, - dttm: datetime, - result: Optional[str], -) -> None: - from superset.db_engine_specs.trino import TrinoEngineSpec - - for case in (str.lower, str.upper): - assert TrinoEngineSpec.convert_dttm(case(target_type), dttm) == result