diff --git a/superset/models/core.py b/superset/models/core.py index 8448c7ba54..e36c717299 100644 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -690,9 +690,13 @@ class Database(Model, AuditMixinNullable, ImportMixin): return self.get_dialect().identifier_preparer.quote def get_df(self, sql, schema): - sql = sql.strip().strip(';') + sqls = [x.strip() for x in sql.strip().strip(';').split(';')] eng = self.get_sqla_engine(schema=schema) - df = pd.read_sql_query(sql, eng) + + for i in range(len(sqls) - 1): + eng.execute(sqls[i]) + + df = pd.read_sql_query(sqls[-1], eng) def needs_conversion(df_series): if df_series.empty: diff --git a/tests/model_tests.py b/tests/model_tests.py index 8af104f57c..45ee61edd6 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -106,6 +106,20 @@ class DatabaseModelTestCase(SupersetTestCase): self.assertEquals(d.get('P1D').function, 'DATE({col})') self.assertEquals(d.get('Time Column').function, '{col}') + def test_single_statement(self): + main_db = self.get_main_database(db.session) + + if main_db.backend == 'mysql': + df = main_db.get_df('SELECT 1', None) + self.assertEquals(df.iat[0, 0], 1) + + def test_multi_statement(self): + main_db = self.get_main_database(db.session) + + if main_db.backend == 'mysql': + df = main_db.get_df('USE superset; SELECT 1', None) + self.assertEquals(df.iat[0, 0], 1) + class SqlaTableModelTestCase(SupersetTestCase):