fix: Ensure SQLAlchemy sessions are closed (#25031)

This commit is contained in:
John Bodley 2023-08-23 11:57:36 -07:00 committed by GitHub
parent 0dadf06245
commit adaab3550c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 126 deletions

View File

@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
session_class = sessionmaker(autoflush=False) session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection) session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user try:
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() new_user = session.query(User).filter_by(id=target.id).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()
# set dashboard as the welcome dashboard # copy template dashboard to user
extra_attributes = UserAttribute( template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
user_id=target.id, welcome_dashboard_id=dashboard.id dashboard = Dashboard(
) dashboard_title=template.dashboard_title,
session.add(extra_attributes) position_json=template.position_json,
session.commit() description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
finally:
session.close()
sqla.event.listen(User, "after_insert", copy_dashboard) sqla.event.listen(User, "after_insert", copy_dashboard)
@ -414,13 +417,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
"native_filter_configuration", [] "native_filter_configuration", []
) )
for native_filter in native_filter_configuration: for native_filter in native_filter_configuration:
session = db.session()
for target in native_filter.get("targets", []): for target in native_filter.get("targets", []):
id_ = target.get("datasetId") id_ = target.get("datasetId")
if id_ is None: if id_ is None:
continue continue
datasource = DatasourceDAO.get_datasource( datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_ db.session, utils.DatasourceType.TABLE, id_
) )
datasource_ids.add((datasource.id, datasource.type)) datasource_ids.add((datasource.id, datasource.type))

View File

@ -170,17 +170,19 @@ class ObjectUpdater:
) -> None: ) -> None:
session = Session(bind=connection) session = Session(bind=connection)
# add `owner:` tags try:
cls._add_owners(session, target) # add `owner:` tags
cls._add_owners(session, target)
# add `type:` tags # add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type) tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject( tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type tag_id=tag.id, object_id=target.id, object_type=cls.object_type
) )
session.add(tagged_object) session.add(tagged_object)
session.commit()
session.commit() finally:
session.close()
@classmethod @classmethod
def after_update( def after_update(
@ -191,25 +193,27 @@ class ObjectUpdater:
) -> None: ) -> None:
session = Session(bind=connection) session = Session(bind=connection)
# delete current `owner:` tags try:
query = ( # delete current `owner:` tags
session.query(TaggedObject.id) query = (
.join(Tag) session.query(TaggedObject.id)
.filter( .join(Tag)
TaggedObject.object_type == cls.object_type, .filter(
TaggedObject.object_id == target.id, TaggedObject.object_type == cls.object_type,
Tag.type == TagTypes.owner, TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
) )
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
# add `owner:` tags # add `owner:` tags
cls._add_owners(session, target) cls._add_owners(session, target)
session.commit()
session.commit() finally:
session.close()
@classmethod @classmethod
def after_delete( def after_delete(
@ -220,13 +224,16 @@ class ObjectUpdater:
) -> None: ) -> None:
session = Session(bind=connection) session = Session(bind=connection)
# delete row from `tagged_objects` try:
session.query(TaggedObject).filter( # delete row from `tagged_objects`
TaggedObject.object_type == cls.object_type, session.query(TaggedObject).filter(
TaggedObject.object_id == target.id, TaggedObject.object_type == cls.object_type,
).delete() TaggedObject.object_id == target.id,
).delete()
session.commit() session.commit()
finally:
session.close()
class ChartUpdater(ObjectUpdater): class ChartUpdater(ObjectUpdater):
@ -267,35 +274,40 @@ class FavStarUpdater:
cls, _mapper: Mapper, connection: Connection, target: FavStar cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None: ) -> None:
session = Session(bind=connection) session = Session(bind=connection)
name = f"favorited_by:{target.user_id}" try:
tag = get_tag(name, session, TagTypes.favorited_by) name = f"favorited_by:{target.user_id}"
tagged_object = TaggedObject( tag = get_tag(name, session, TagTypes.favorited_by)
tag_id=tag.id, tagged_object = TaggedObject(
object_id=target.obj_id, tag_id=tag.id,
object_type=get_object_type(target.class_name), object_id=target.obj_id,
) object_type=get_object_type(target.class_name),
session.add(tagged_object) )
session.add(tagged_object)
session.commit() session.commit()
finally:
session.close()
@classmethod @classmethod
def after_delete( def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None: ) -> None:
session = Session(bind=connection) session = Session(bind=connection)
name = f"favorited_by:{target.user_id}" try:
query = ( name = f"favorited_by:{target.user_id}"
session.query(TaggedObject.id) query = (
.join(Tag) session.query(TaggedObject.id)
.filter( .join(Tag)
TaggedObject.object_id == target.obj_id, .filter(
Tag.type == TagTypes.favorited_by, TaggedObject.object_id == target.obj_id,
Tag.name == name, Tag.type == TagTypes.favorited_by,
Tag.name == name,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
) )
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
session.commit() session.commit()
finally:
session.close()

View File

@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
def get_payloads(self) -> list[dict[str, int]]: def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session() session = db.create_scoped_session()
charts = session.query(Slice).all()
try:
charts = session.query(Slice).all()
finally:
session.close()
return [get_payload(chart) for chart in charts] return [get_payload(chart) for chart in charts]
@ -129,20 +133,24 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
payloads = [] payloads = []
session = db.create_scoped_session() session = db.create_scoped_session()
records = ( try:
session.query(Log.dashboard_id, func.count(Log.dashboard_id)) records = (
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.group_by(Log.dashboard_id) .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.order_by(func.count(Log.dashboard_id).desc()) .group_by(Log.dashboard_id)
.limit(self.top_n) .order_by(func.count(Log.dashboard_id).desc())
.all() .limit(self.top_n)
) .all()
dash_ids = [record.dashboard_id for record in records] )
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() dash_ids = [record.dashboard_id for record in records]
for dashboard in dashboards: dashboards = (
for chart in dashboard.slices: session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
payloads.append(get_payload(chart, dashboard)) )
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads return payloads
@ -172,42 +180,46 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
payloads = [] payloads = []
session = db.create_scoped_session() session = db.create_scoped_session()
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() try:
tag_ids = [tag.id for tag in tags] tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
# add dashboards that are tagged # add dashboards that are tagged
tagged_objects = ( tagged_objects = (
session.query(TaggedObject) session.query(TaggedObject)
.filter( .filter(
and_( and_(
TaggedObject.object_type == "dashboard", TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids), TaggedObject.tag_id.in_(tag_ids),
)
) )
.all()
) )
.all() dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
) tagged_dashboards = session.query(Dashboard).filter(
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] Dashboard.id.in_(dash_ids)
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)) )
for dashboard in tagged_dashboards: for dashboard in tagged_dashboards:
for chart in dashboard.slices: for chart in dashboard.slices:
payloads.append(get_payload(chart))
# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart)) payloads.append(get_payload(chart))
finally:
# add charts that are tagged session.close()
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
return payloads return payloads