Add Table performance improvements (#3509)

* Improved performance of 'Add table' function

* Got rid of a private function call

* Changed the existing-metrics lookup to be keyed on metric_name
This commit is contained in:
Jeff Niu 2017-09-25 11:35:09 -07:00 committed by Maxime Beauchemin
parent 255ea69977
commit f3146ef6f9
3 changed files with 40 additions and 42 deletions

View File

@ -10,7 +10,7 @@ from sqlalchemy import (
DateTime,
)
import sqlalchemy as sa
from sqlalchemy import asc, and_, desc, select
from sqlalchemy import asc, and_, desc, select, or_
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.orm import backref, relationship
from sqlalchemy.sql import table, literal_column, text, column
@ -588,30 +588,29 @@ class SqlaTable(Model, BaseDatasource):
table = self.get_sqla_table_object()
except Exception:
raise Exception(_(
"Table doesn't seem to exist in the specified database, "
"couldn't fetch column information"))
"Table [{}] doesn't seem to exist in the specified database, "
"couldn't fetch column information").format(self.table_name))
TC = TableColumn # noqa shortcut to class
M = SqlMetric # noqa
metrics = []
any_date_col = None
db_dialect = self.database.get_sqla_engine().dialect
db_dialect = self.database.get_dialect()
dbcols = (
db.session.query(TableColumn)
.filter(TableColumn.table == self)
.filter(or_(TableColumn.column_name == col.name
for col in table.columns)))
dbcols = {dbcol.column_name: dbcol for dbcol in dbcols}
for col in table.columns:
try:
datatype = "{}".format(col.type.compile(dialect=db_dialect)).upper()
datatype = col.type.compile(dialect=db_dialect).upper()
except Exception as e:
datatype = "UNKNOWN"
logging.error(
"Unrecognized data type in {}.{}".format(table, col.name))
logging.exception(e)
dbcol = (
db.session
.query(TC)
.filter(TC.table == self)
.filter(TC.column_name == col.name)
.first()
)
db.session.flush()
dbcol = dbcols.get(col.name, None)
if not dbcol:
dbcol = TableColumn(column_name=col.name, type=datatype)
dbcol.groupby = dbcol.is_string
@ -619,14 +618,11 @@ class SqlaTable(Model, BaseDatasource):
dbcol.sum = dbcol.is_num
dbcol.avg = dbcol.is_num
dbcol.is_dttm = dbcol.is_time
db.session.merge(self)
self.columns.append(dbcol)
if not any_date_col and dbcol.is_time:
any_date_col = col.name
quoted = "{}".format(col.compile(dialect=db_dialect))
quoted = str(col.compile(dialect=db_dialect))
if dbcol.sum:
metrics.append(M(
metric_name='sum__' + dbcol.column_name,
@ -663,8 +659,6 @@ class SqlaTable(Model, BaseDatasource):
expression="COUNT(DISTINCT {})".format(quoted)
))
dbcol.type = datatype
db.session.merge(self)
db.session.commit()
metrics.append(M(
metric_name='count',
@ -672,19 +666,18 @@ class SqlaTable(Model, BaseDatasource):
metric_type='count',
expression="COUNT(*)"
))
dbmetrics = db.session.query(M).filter(M.table_id == self.id).filter(
or_(M.metric_name == metric.metric_name for metric in metrics))
dbmetrics = {metric.metric_name: metric for metric in dbmetrics}
for metric in metrics:
m = (
db.session.query(M)
.filter(M.metric_name == metric.metric_name)
.filter(M.table_id == self.id)
.first()
)
metric.table_id = self.id
if not m:
if not dbmetrics.get(metric.metric_name, None):
db.session.add(metric)
db.session.commit()
if not self.main_dttm_col:
self.main_dttm_col = any_date_col
db.session.merge(self)
db.session.commit()
@classmethod
def import_obj(cls, i_datasource, import_time=None):

View File

@ -1,6 +1,5 @@
"""Views used by the SqlAlchemy connector"""
import logging
from past.builtins import basestring
from flask import Markup, flash, redirect
@ -229,21 +228,17 @@ class TableModelView(DatasourceModelView, DeleteMixin): # noqa
}
def pre_add(self, table):
number_of_existing_tables = db.session.query(
sa.func.count('*')).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id
).scalar()
# table object is already added to the session
if number_of_existing_tables > 1:
raise Exception(get_datasource_exist_error_mgs(table.full_name))
with db.session.no_autoflush:
table_query = db.session.query(models.SqlaTable).filter(
models.SqlaTable.table_name == table.table_name,
models.SqlaTable.schema == table.schema,
models.SqlaTable.database_id == table.database.id)
if db.session.query(table_query.exists()).scalar():
raise Exception(
get_datasource_exist_error_mgs(table.full_name))
# Fail before adding if the table can't be found
try:
table.get_sqla_table_object()
except Exception as e:
logging.exception(e)
if not table.database.has_table(table):
raise Exception(_(
"Table [{}] could not be found, "
"please double check your "

View File

@ -33,6 +33,7 @@ from sqlalchemy.orm.session import make_transient
from sqlalchemy.pool import NullPool
from sqlalchemy.sql import text
from sqlalchemy.sql.expression import TextAsFrom
from sqlalchemy.engine import url
from sqlalchemy_utils import EncryptedType
from superset import app, db, db_engine_specs, utils, sm
@ -743,6 +744,15 @@ class Database(Model, AuditMixinNullable):
return (
"[{obj.database_name}].(id:{obj.id})").format(obj=self)
def has_table(self, table):
    """Ask this database's SQLAlchemy dialect whether ``table`` exists.

    :param table: object exposing ``table_name`` and ``schema`` attributes
    :returns: whatever the dialect's ``has_table`` check returns for the
        engine obtained from ``self.get_sqla_engine()``
    """
    sqla_engine = self.get_sqla_engine()
    table_exists = sqla_engine.dialect.has_table
    # A falsy schema (e.g. an empty string) is passed on as None.
    return table_exists(sqla_engine, table.table_name, table.schema or None)
def get_dialect(self):
    """Build a dialect instance from this database's decrypted URI.

    The URI in ``self.sqlalchemy_uri_decrypted`` is parsed with
    ``url.make_url``; the object returned by the parsed URL's
    ``get_dialect()`` is then called with no arguments and the result is
    returned.
    """
    dialect_factory = url.make_url(self.sqlalchemy_uri_decrypted).get_dialect()
    return dialect_factory()
# Register set_perm as a listener for both the 'after_insert' and
# 'after_update' events of Database.
for _event_name in ('after_insert', 'after_update'):
    sqla.event.listen(Database, _event_name, set_perm)