pass source to db api mutator (#6497)

This commit is contained in:
timifasubaa 2019-01-10 17:30:32 -08:00 committed by GitHub
parent a2ce9974cd
commit 9d70c348d3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 20 additions and 4 deletions

View File

@ -757,7 +757,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
@utils.memoized( @utils.memoized(
watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra')) watch=('impersonate_user', 'sqlalchemy_uri_decrypted', 'extra'))
def get_sqla_engine(self, schema=None, nullpool=True, user_name=None): def get_sqla_engine(self, schema=None, nullpool=True, user_name=None, source=None):
extra = self.get_extra() extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted) url = make_url(self.sqlalchemy_uri_decrypted)
url = self.db_engine_spec.adjust_database_uri(url, schema) url = self.db_engine_spec.adjust_database_uri(url, schema)
@ -790,7 +790,7 @@ class Database(Model, AuditMixinNullable, ImportMixin):
DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR') DB_CONNECTION_MUTATOR = config.get('DB_CONNECTION_MUTATOR')
if DB_CONNECTION_MUTATOR: if DB_CONNECTION_MUTATOR:
url, params = DB_CONNECTION_MUTATOR( url, params = DB_CONNECTION_MUTATOR(
url, params, effective_username, security_manager) url, params, effective_username, security_manager, source)
return create_engine(url, **params) return create_engine(url, **params)
def get_reserved_words(self): def get_reserved_words(self):
@ -801,7 +801,14 @@ class Database(Model, AuditMixinNullable, ImportMixin):
def get_df(self, sql, schema): def get_df(self, sql, schema):
sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)] sqls = [str(s).strip().strip(';') for s in sqlparse.parse(sql)]
engine = self.get_sqla_engine(schema=schema) source_key = None
if request and request.referrer:
if '/superset/dashboard/' in request.referrer:
source_key = 'dashboard'
elif '/superset/explore/' in request.referrer:
source_key = 'chart'
engine = self.get_sqla_engine(
schema=schema, source=utils.sources.get(source_key, None))
username = utils.get_username() username = utils.get_username()
def needs_conversion(df_series): def needs_conversion(df_series):
@ -860,7 +867,8 @@ class Database(Model, AuditMixinNullable, ImportMixin):
self, table_name, schema=None, limit=100, show_cols=False, self, table_name, schema=None, limit=100, show_cols=False,
indent=True, latest_partition=False, cols=None): indent=True, latest_partition=False, cols=None):
"""Generates a ``select *`` statement in the proper dialect""" """Generates a ``select *`` statement in the proper dialect"""
eng = self.get_sqla_engine(schema=schema) eng = self.get_sqla_engine(
schema=schema, source=utils.sources.get('sql_lab', None))
return self.db_engine_spec.select_star( return self.db_engine_spec.select_star(
self, table_name, schema=schema, engine=eng, self, table_name, schema=schema, engine=eng,
limit=limit, show_cols=show_cols, limit=limit, show_cols=show_cols,

View File

@ -20,6 +20,7 @@ from superset.tasks.celery_app import app as celery_app
from superset.utils.core import ( from superset.utils.core import (
json_iso_dttm_ser, json_iso_dttm_ser,
QueryStatus, QueryStatus,
sources,
zlib_compress, zlib_compress,
) )
from superset.utils.dates import now_as_float from superset.utils.dates import now_as_float
@ -226,6 +227,7 @@ def execute_sql_statements(
schema=query.schema, schema=query.schema,
nullpool=True, nullpool=True,
user_name=user_name, user_name=user_name,
source=sources.get('sql_lab', None),
) )
# Sharing a single connection and cursor across the # Sharing a single connection and cursor across the
# execution of all statements (if many) # execution of all statements (if many)

View File

@ -54,6 +54,12 @@ ADHOC_METRIC_EXPRESSION_TYPES = {
JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1 JS_MAX_INTEGER = 9007199254740991 # Largest int Java Script can handle 2^53-1
sources = {
'chart': 0,
'dashboard': 1,
'sql_lab': 2,
}
def flasher(msg, severity=None): def flasher(msg, severity=None):
"""Flask's flash if available, logging call if not""" """Flask's flash if available, logging call if not"""