fix: Ensure SQLAlchemy sessions are closed (#25031)

This commit is contained in:
John Bodley 2023-08-23 11:57:36 -07:00 committed by GitHub
parent 0dadf06245
commit adaab3550c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 126 deletions

View File

@ -74,6 +74,8 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
try:
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
@ -88,7 +90,6 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
owners=[new_user],
)
session.add(dashboard)
session.commit()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
@ -96,6 +97,8 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
)
session.add(extra_attributes)
session.commit()
finally:
session.close()
sqla.event.listen(User, "after_insert", copy_dashboard)
@ -414,13 +417,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
"native_filter_configuration", []
)
for native_filter in native_filter_configuration:
session = db.session()
for target in native_filter.get("targets", []):
id_ = target.get("datasetId")
if id_ is None:
continue
datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_
db.session, utils.DatasourceType.TABLE, id_
)
datasource_ids.add((datasource.id, datasource.type))

View File

@ -170,6 +170,7 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
try:
# add `owner:` tags
cls._add_owners(session, target)
@ -179,8 +180,9 @@ class ObjectUpdater:
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
finally:
session.close()
@classmethod
def after_update(
@ -191,6 +193,7 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
try:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
@ -208,8 +211,9 @@ class ObjectUpdater:
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
finally:
session.close()
@classmethod
def after_delete(
@ -220,6 +224,7 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
try:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
@ -227,6 +232,8 @@ class ObjectUpdater:
).delete()
session.commit()
finally:
session.close()
class ChartUpdater(ObjectUpdater):
@ -267,6 +274,7 @@ class FavStarUpdater:
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
try:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
@ -275,14 +283,16 @@ class FavStarUpdater:
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
finally:
session.close()
@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
try:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
@ -299,3 +309,5 @@ class FavStarUpdater:
)
session.commit()
finally:
session.close()

View File

@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()
try:
charts = session.query(Slice).all()
finally:
session.close()
return [get_payload(chart) for chart in charts]
@ -129,6 +133,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
payloads = []
session = db.create_scoped_session()
try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
@ -138,11 +143,14 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads
@ -172,6 +180,7 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
payloads = []
session = db.create_scoped_session()
try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
@ -187,7 +196,9 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))
@ -207,7 +218,8 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
finally:
session.close()
return payloads