[sql lab] Use context manager for sqllab sessions (#4927)

* use session context manager

* contextlib2 added to requirements.txt

* Fixing error: import statements are in the wrong order; `from contextlib2 import contextmanager` should come before `import sqlalchemy`

* Fixing return inside generator

* fixed C812 missing trailing comma

* E501 line too long

* fixed E127 continuation line over-indented for visual indent

* E722 do not use bare except

* reorganized imports

* added contextlib2.contextmanager to pylint's contextmanager-decorators setting

* fixed import ordering
This commit is contained in:
grafke 2018-05-10 19:32:31 +02:00 committed by Maxime Beauchemin
parent af4dd59661
commit 8591319bde
3 changed files with 38 additions and 26 deletions

View File

@ -292,7 +292,7 @@ generated-members=
# List of decorators that produce context managers, such as # List of decorators that produce context managers, such as
# contextlib.contextmanager. Add to this list to register other decorators that # contextlib.contextmanager. Add to this list to register other decorators that
# produce valid context managers. # produce valid context managers.
contextmanager-decorators=contextlib.contextmanager contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager
[VARIABLES] [VARIABLES]

View File

@ -37,3 +37,4 @@ thrift==0.11.0
thrift-sasl==0.3.0 thrift-sasl==0.3.0
unicodecsv==0.14.1 unicodecsv==0.14.1
unidecode==1.0.22 unidecode==1.0.22
contextlib2==0.5.5

View File

@ -1,9 +1,6 @@
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
# pylint: disable=C,R,W # pylint: disable=C,R,W
from __future__ import absolute_import from __future__ import absolute_import, division, print_function, unicode_literals
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from datetime import datetime from datetime import datetime
import json import json
@ -12,6 +9,7 @@ from time import sleep
import uuid import uuid
from celery.exceptions import SoftTimeLimitExceeded from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import sqlalchemy import sqlalchemy
@ -75,16 +73,28 @@ def get_query(query_id, session, retry_count=5):
return query return query
def get_session(nullpool): @contextmanager
def session_scope(nullpool):
"""Provide a transactional scope around a series of operations."""
if nullpool: if nullpool:
engine = sqlalchemy.create_engine( engine = sqlalchemy.create_engine(
app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool) app.config.get('SQLALCHEMY_DATABASE_URI'), poolclass=NullPool)
session_class = sessionmaker() session_class = sessionmaker()
session_class.configure(bind=engine) session_class.configure(bind=engine)
return session_class() session = session_class()
session = db.session() else:
session.commit() # HACK session = db.session()
return session session.commit() # HACK
try:
yield session
session.commit()
except Exception as e:
session.rollback()
logging.exception(e)
raise
finally:
session.close()
def convert_results_to_df(cursor_description, data): def convert_results_to_df(cursor_description, data):
@ -109,30 +119,31 @@ def convert_results_to_df(cursor_description, data):
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT) @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results( def get_sql_results(
ctask, query_id, rendered_query, return_results=True, store_results=False, ctask, query_id, rendered_query, return_results=True, store_results=False,
user_name=None): user_name=None):
"""Executes the sql query returns the results.""" """Executes the sql query returns the results."""
try: with session_scope(not ctask.request.called_directly) as session:
return execute_sql(
ctask, query_id, rendered_query, return_results, store_results, user_name) try:
except Exception as e: return execute_sql(
logging.exception(e) ctask, query_id, rendered_query, return_results, store_results, user_name,
stats_logger.incr('error_sqllab_unhandled') session=session)
sesh = get_session(not ctask.request.called_directly) except Exception as e:
query = get_query(query_id, sesh) logging.exception(e)
query.error_message = str(e) stats_logger.incr('error_sqllab_unhandled')
query.status = QueryStatus.FAILED query = get_query(query_id, session)
query.tmp_table_name = None query.error_message = str(e)
sesh.commit() query.status = QueryStatus.FAILED
raise query.tmp_table_name = None
session.commit()
raise
def execute_sql( def execute_sql(
ctask, query_id, rendered_query, return_results=True, store_results=False, ctask, query_id, rendered_query, return_results=True, store_results=False,
user_name=None, user_name=None, session=None,
): ):
"""Executes the sql query returns the results.""" """Executes the sql query returns the results."""
session = get_session(not ctask.request.called_directly)
query = get_query(query_id, session) query = get_query(query_id, session)
payload = dict(query_id=query_id) payload = dict(query_id=query_id)