From 8591319bde7db14ac358eae979dfefcbf450db91 Mon Sep 17 00:00:00 2001 From: grafke Date: Thu, 10 May 2018 19:32:31 +0200 Subject: [PATCH] [sql lab] Use context manager for sqllab sessions (#4927) * use session context manager * contextlib2 added to requirements.txt * Fixing error: Import statements are in the wrong order. from contextlib2 import contextmanager should be before import sqlalchemy * Fixing return inside generator * fixed C812 missing trailing comma * E501 line too long * fixed E127 continuation line over-indented for visual indent * E722 do not use bare except * reorganized imports * added context manager contextlib2.contextmanager * fixed import ordering --- .pylintrc | 2 +- requirements.txt | 1 + superset/sql_lab.py | 61 ++++++++++++++++++++++++++------------------- 3 files changed, 38 insertions(+), 26 deletions(-) diff --git a/.pylintrc b/.pylintrc index 6e213c4831..0f9710688d 100644 --- a/.pylintrc +++ b/.pylintrc @@ -292,7 +292,7 @@ generated-members= # List of decorators that produce context managers, such as # contextlib.contextmanager. Add to this list to register other decorators that # produce valid context managers. -contextmanager-decorators=contextlib.contextmanager +contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager [VARIABLES] diff --git a/requirements.txt b/requirements.txt index e3bceee7cb..05dca9daf4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -37,3 +37,4 @@ thrift==0.11.0 thrift-sasl==0.3.0 unicodecsv==0.14.1 unidecode==1.0.22 +contextlib2==0.5.5 \ No newline at end of file diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 750da1f376..856ea4880f 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -1,9 +1,6 @@ # -*- coding: utf-8 -*- # pylint: disable=C,R,W -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import absolute_import, division, print_function, unicode_literals from datetime import datetime import json @@ -12,6 +9,7 @@ from time import sleep import uuid from celery.exceptions import SoftTimeLimitExceeded +from contextlib2 import contextmanager import numpy as np import pandas as pd import sqlalchemy @@ -75,16 +73,28 @@ def get_query(query_id, session, retry_count=5): return query -def get_session(nullpool): +@contextmanager +def session_scope(nullpool): + """Provide a transactional scope around a series of operations.""" if nullpool: engine = sqlalchemy.create_engine( app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool) session_class = sessionmaker() session_class.configure(bind=engine) - return session_class() - session = db.session() - session.commit() # HACK - return session + session = session_class() + else: + session = db.session() + session.commit() # HACK + + try: + yield session + session.commit() + except Exception as e: + session.rollback() + logging.exception(e) + raise + finally: + session.close() def convert_results_to_df(cursor_description, data): @@ -109,30 +119,31 @@ def convert_results_to_df(cursor_description, data): @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT) def get_sql_results( - ctask, query_id, rendered_query, return_results=True, store_results=False, + ctask, query_id, rendered_query, return_results=True, store_results=False, user_name=None): """Executes the sql query returns the results.""" - try: - return execute_sql( - ctask, query_id, rendered_query, return_results, store_results, user_name) - except Exception as e: - logging.exception(e) - stats_logger.incr('error_sqllab_unhandled') - sesh = get_session(not ctask.request.called_directly) - query = get_query(query_id, sesh) - query.error_message = str(e) - query.status = QueryStatus.FAILED - query.tmp_table_name = None - sesh.commit() - raise + with session_scope(not ctask.request.called_directly) as session: + + try: + return execute_sql( + ctask, query_id, rendered_query, return_results, store_results, user_name, + session=session) + except Exception as e: + logging.exception(e) + stats_logger.incr('error_sqllab_unhandled') + query = get_query(query_id, session) + query.error_message = str(e) + query.status = QueryStatus.FAILED + query.tmp_table_name = None + session.commit() + raise def execute_sql( ctask, query_id, rendered_query, return_results=True, store_results=False, - user_name=None, + user_name=None, session=None, ): """Executes the sql query returns the results.""" - session = get_session(not ctask.request.called_directly) query = get_query(query_id, session) payload = dict(query_id=query_id)