fix: always denorm column value before querying values (#25919)

This commit is contained in:
Hugh A. Miles II 2023-11-13 13:18:28 -05:00 committed by GitHub
parent 943696a87f
commit 8d8e1bb637
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 31 additions and 65 deletions

View File

@ -496,13 +496,6 @@ class BaseDatasource(
"""
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
    """Return the distinct values found in the given column.

    The explore view uses these values to fill the dropdown that lists
    the available filter values.

    :param column_name: name of the column whose values are wanted
    :param limit: maximum number of values to return
    :raises NotImplementedError: always; this implementation is a placeholder
    """
    raise NotImplementedError()
@staticmethod
def default_query(qry: Query) -> Query:
    """Return ``qry`` unchanged.

    :param qry: the query object to pass through
    :returns: the very same object that was passed in
    """
    return qry

View File

@ -46,7 +46,6 @@ from sqlalchemy import (
inspect,
Integer,
or_,
select,
String,
Table,
Text,
@ -793,34 +792,6 @@ class SqlaTable(
)
) from ex
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
    """Query the datasource for distinct sample values of one column.

    :param column_name: name of the column, looked up among ``self.columns``
    :param limit: maximum number of rows to fetch; a falsy value means the
        query is issued without a LIMIT clause
    :returns: the values of the ``column_name`` column of the result frame
    :raises KeyError: if no entry of ``self.columns`` is named ``column_name``
    """
    columns_by_name = {column.column_name: column for column in self.columns}
    wanted_column = columns_by_name[column_name]

    template_processor = self.get_template_processor()
    from_clause, cte = self.get_from_clause(template_processor)

    sqla_column = wanted_column.get_sqla_col(template_processor=template_processor)
    query = select([sqla_column]).select_from(from_clause).distinct()
    if limit:
        query = query.limit(limit)

    # Only restrict the query when a predicate is configured.
    if self.fetch_values_predicate:
        predicate = self.get_fetch_values_predicate(
            template_processor=template_processor
        )
        query = query.where(predicate)

    with self.database.get_sqla_engine_with_context() as engine:
        # Inline bound parameters so the statement can be handled as text.
        statement = query.compile(engine, compile_kwargs={"literal_binds": True})
        statement = self._apply_cte(statement, cte)
        statement = self.mutate_query_from_config(statement)

        frame = pd.read_sql_query(sql=statement, con=engine)
        return frame[column_name].to_list()
def mutate_query_from_config(self, sql: str) -> str:
"""Apply config's SQL_QUERY_MUTATOR

View File

@ -120,6 +120,10 @@ class DatasourceRestApi(BaseSupersetApi):
column_name=column_name, limit=row_limit
)
return self.response(200, result=payload)
except KeyError:
return self.response(
400, message=f"Column name {column_name} does not exist"
)
except NotImplementedError:
return self.response(
400,

View File

@ -705,10 +705,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"MIN": sa.func.MIN,
"MAX": sa.func.MAX,
}
@property
def fetch_value_predicate(self) -> str:
    """Return a fixed placeholder string.

    NOTE(review): the name is singular, whereas the query code in this file
    reads ``fetch_values_predicate`` (plural) -- confirm whether anything
    still relies on this property.
    """
    return "fix this!"
# No fetch-values predicate by default; the query code in this file only adds
# the predicate's WHERE clause when this attribute is truthy.
fetch_values_predicate = None
@property
def type(self) -> str:
@ -785,17 +782,20 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def columns(self) -> list[Any]:
raise NotImplementedError()
def get_fetch_values_predicate(
    self, template_processor: Optional[BaseTemplateProcessor] = None
) -> TextClause:
    """Return the predicate applied when fetching column values.

    :param template_processor: optional template processor for the predicate
    :raises NotImplementedError: always; this implementation is a placeholder
    """
    raise NotImplementedError()
def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]:
    """Return extra hashable cache-key components for ``query_obj``.

    :param query_obj: the query object, as a dictionary
    :raises NotImplementedError: always; this implementation is a placeholder
    """
    raise NotImplementedError()
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
    """Return a template processor built from ``kwargs``.

    :param kwargs: arbitrary keyword arguments
    :raises NotImplementedError: always; this implementation is a placeholder
    """
    raise NotImplementedError()
def get_fetch_values_predicate(
    self,
    # pylint: disable=unused-argument
    template_processor: Optional[BaseTemplateProcessor] = None,
) -> TextClause:
    """Return the ``fetch_values_predicate`` attribute as-is.

    :param template_processor: accepted but not used by this implementation
    :returns: whatever ``self.fetch_values_predicate`` currently holds
    """
    predicate = self.fetch_values_predicate
    return predicate
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
@ -1341,36 +1341,34 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return and_(*l)
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
    """Run a query to retrieve distinct sample values for the given column.

    The requested name is passed through the engine spec's
    ``denormalize_name`` first; the denormalized name is then used both to
    find the column among ``self.columns`` and to read the result frame.

    :param column_name: name of the column as supplied by the caller
    :param limit: maximum number of rows to fetch; a falsy value means the
        query is issued without a LIMIT clause
    :returns: the values of the queried column of the result frame
    :raises KeyError: if no entry of ``self.columns`` carries the
        denormalized name
    """
    # always denormalize column name before querying for values
    db_dialect = self.database.get_dialect()
    denormalized_col_name = self.database.db_engine_spec.denormalize_name(
        db_dialect, column_name
    )
    cols = {col.column_name: col for col in self.columns}
    target_col = cols[denormalized_col_name]
    tp = self.get_template_processor()
    tbl, cte = self.get_from_clause(tp)

    qry = (
        sa.select([target_col.get_sqla_col(template_processor=tp)])
        .select_from(tbl)
        .distinct()
    )
    if limit:
        qry = qry.limit(limit)

    # Only restrict the query when a predicate is configured.
    if self.fetch_values_predicate:
        qry = qry.where(self.get_fetch_values_predicate(template_processor=tp))

    with self.database.get_sqla_engine_with_context() as engine:
        # Inline bound parameters so the statement can be handled as text.
        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
        sql = self._apply_cte(sql, cte)
        sql = self.mutate_query_from_config(sql)

        df = pd.read_sql_query(sql=sql, con=engine)
        return df[denormalized_col_name].to_list()
def get_timestamp_expression(
self,
@ -1942,7 +1940,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
)
having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate: # type: ignore
if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(
self.get_fetch_values_predicate(template_processor=template_processor)
)