Add Table performance improvements (#3509)

* Improved performance of 'Add table' function

* Got rid of a private function call

* Changed the metric object lookup to key on metric_name
This commit is contained in:
Jeff Niu 2017-09-25 11:35:09 -07:00 committed by Maxime Beauchemin
parent 255ea69977
commit f3146ef6f9
3 changed files with 40 additions and 42 deletions

View File

@ -10,7 +10,7 @@ from sqlalchemy import (
DateTime, DateTime,
) )
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import asc, and_, desc, select from sqlalchemy import asc, and_, desc, select, or_
from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.orm import backref, relationship from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import table, literal_column, text, column from sqlalchemy.sql import table, literal_column, text, column
@ -588,30 +588,29 @@ class SqlaTable(Model, BaseDatasource):
table = self.get_sqla_table_object() table = self.get_sqla_table_object()
except Exception: except Exception:
raise Exception(_( raise Exception(_(
"Table doesn't seem to exist in the specified database, " "Table [{}] doesn't seem to exist in the specified database, "
"couldn't fetch column information")) "couldn't fetch column information").format(self.table_name))
TC = TableColumn # noqa shortcut to class
M = SqlMetric # noqa M = SqlMetric # noqa
metrics = [] metrics = []
any_date_col = None any_date_col = None
db_dialect = self.database.get_sqla_engine().dialect db_dialect = self.database.get_dialect()
dbcols = (
db.session.query(TableColumn)
.filter(TableColumn.table == self)
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
for col in table.columns: for col in table.columns:
try: try:
datatype = "{}".format(col.type.compile(dialect=db_dialect)).upper() datatype = col.type.compile(dialect=db_dialect).upper()
except Exception as e: except Exception as e:
datatype = "UNKNOWN" datatype = "UNKNOWN"
logging.error( logging.error(
"Unrecognized data type in {}.{}".format(table, col.name)) "Unrecognized data type in {}.{}".format(table, col.name))
logging.exception(e) logging.exception(e)
dbcol = ( dbcol = dbcols.get(col.name, None)
db.session
.query(TC)
.filter(TC.table == self)
.filter(TC.column_name == col.name)
.first()
)
db.session.flush()
if not dbcol: if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype) dbcol = TableColumn(column_name=col.name, type=datatype)
dbcol.groupby = dbcol.is_string dbcol.groupby = dbcol.is_string
@ -619,14 +618,11 @@ class SqlaTable(Model, BaseDatasource):
dbcol.sum = dbcol.is_num dbcol.sum = dbcol.is_num
dbcol.avg = dbcol.is_num dbcol.avg = dbcol.is_num
dbcol.is_dttm = dbcol.is_time dbcol.is_dttm = dbcol.is_time
db.session.merge(self)
self.columns.append(dbcol) self.columns.append(dbcol)
if not any_date_col and dbcol.is_time: if not any_date_col and dbcol.is_time:
any_date_col = col.name any_date_col = col.name
quoted = "{}".format(col.compile(dialect=db_dialect)) quoted = str(col.compile(dialect=db_dialect))
if dbcol.sum: if dbcol.sum:
metrics.append(M( metrics.append(M(
metric_name='sum__' + dbcol.column_name, metric_name='sum__' + dbcol.column_name,
@ -663,8 +659,6 @@ class SqlaTable(Model, BaseDatasource):
expression="COUNT(DISTINCT {})".format(quoted) expression="COUNT(DISTINCT {})".format(quoted)
)) ))
dbcol.type = datatype dbcol.type = datatype
db.session.merge(self)
db.session.commit()
metrics.append(M( metrics.append(M(
metric_name='count', metric_name='count',
@ -672,19 +666,18 @@ class SqlaTable(Model, BaseDatasource):
metric_type='count', metric_type='count',
expression="COUNT(*)" expression="COUNT(*)"
)) ))
dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter(
or_(M.metric_name == metric.metric_name for metric in metrics))
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics: for metric in metrics:
m = (
db.session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.table_id == self.id)
.first()
)
metric.table_id = self.id metric.table_id = self.id
if not m: if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric) db.session.add(metric)
db.session.commit()
if not self.main_dttm_col: if not self.main_dttm_col:
self.main_dttm_col = any_date_col self.main_dttm_col = any_date_col
db.session.merge(self)
db.session.commit()
@classmethod @classmethod
def import_obj(cls, i_datasource, import_time=None): def import_obj(cls, i_datasource, import_time=None):

View File

@ -1,6 +1,5 @@
"""Views used by the SqlAlchemy connector""" """Views used by the SqlAlchemy connector"""
import logging import logging
from past.builtins import basestring from past.builtins import basestring
from flask import Markup, flash, redirect from flask import Markup, flash, redirect
@ -229,21 +228,17 @@ class TableModelView(DatasourceModelView, DeleteMixin): # noqa
} }
def pre_add(self, table): def pre_add(self, table):
number_of_existing_tables = db.session.query( with db.session.no_autoflush:
sa.func.count('*')).filter( table_query = db.session.query(models.SqlaTable).filter(
models.SqlaTable.table_name == table.table_name, models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema, models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id models.SqlaTable.database_id == table.database.id)
).scalar() if db.session.query(table_query.exists()).scalar():
# table object is already added to the session raise Exception(
if number_of_existing_tables > 1: get_datasource_exist_error_mgs(table.full_name))
raise Exception(get_datasource_exist_error_mgs(table.full_name))
# Fail before adding if the table can't be found # Fail before adding if the table can't be found
try: if not table.database.has_table(table):
table.get_sqla_table_object()
except Exception as e:
logging.exception(e)
raise Exception(_( raise Exception(_(
"Table [{}] could not be found, " "Table [{}] could not be found, "
"please double check your " "please double check your "

View File

@ -33,6 +33,7 @@ from sqlalchemy.orm.session import make_transient
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.engine import url
from sqlalchemy_utils import EncryptedType from sqlalchemy_utils import EncryptedType
from superset import app, db, db_engine_specs, utils, sm from superset import app, db, db_engine_specs, utils, sm
@ -743,6 +744,15 @@ class Database(Model, AuditMixinNullable):
return ( return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self) "[{obj.database_name}].(id:{obj.id})").format(obj=self)
def has_table(self, table):
    """Return whether ``table`` exists in this database.

    Obtains an engine from ``self.get_sqla_engine()`` and delegates to that
    engine's dialect ``has_table`` check, passing ``table.table_name`` and
    ``table.schema``. A falsy schema (e.g. an empty string) is passed as
    ``None``.

    :param table: object exposing ``table_name`` and ``schema`` attributes
    :returns: whatever the dialect's ``has_table`` returns
    """
    sqla_engine = self.get_sqla_engine()
    dialect = sqla_engine.dialect
    # Normalise empty schema values to None before handing them to the dialect.
    schema = table.schema or None
    return dialect.has_table(sqla_engine, table.table_name, schema)
def get_dialect(self):
    """Return a dialect instance for this database's SQLAlchemy URI.

    Parses ``self.sqlalchemy_uri_decrypted`` with ``url.make_url``, looks up
    the dialect class for the resulting URL and instantiates it with no
    arguments.
    """
    parsed_url = url.make_url(self.sqlalchemy_uri_decrypted)
    dialect_cls = parsed_url.get_dialect()
    return dialect_cls()
sqla.event.listen(Database, 'after_insert', set_perm) sqla.event.listen(Database, 'after_insert', set_perm)
sqla.event.listen(Database, 'after_update', set_perm) sqla.event.listen(Database, 'after_update', set_perm)