fix: Ensure SQLAlchemy sessions are closed (#25031)

This commit is contained in:
John Bodley 2023-08-23 11:57:36 -07:00 committed by GitHub
parent 0dadf06245
commit adaab3550c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 152 additions and 126 deletions

View File

@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -
session_class = sessionmaker(autoflush=False)
session = session_class(bind=connection)
new_user = session.query(User).filter_by(id=target.id).first()
# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
session.commit()
try:
new_user = session.query(User).filter_by(id=target.id).first()
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
# copy template dashboard to user
template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first()
dashboard = Dashboard(
dashboard_title=template.dashboard_title,
position_json=template.position_json,
description=template.description,
css=template.css,
json_metadata=template.json_metadata,
slices=template.slices,
owners=[new_user],
)
session.add(dashboard)
# set dashboard as the welcome dashboard
extra_attributes = UserAttribute(
user_id=target.id, welcome_dashboard_id=dashboard.id
)
session.add(extra_attributes)
session.commit()
finally:
session.close()
sqla.event.listen(User, "after_insert", copy_dashboard)
@ -414,13 +417,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
"native_filter_configuration", []
)
for native_filter in native_filter_configuration:
session = db.session()
for target in native_filter.get("targets", []):
id_ = target.get("datasetId")
if id_ is None:
continue
datasource = DatasourceDAO.get_datasource(
session, utils.DatasourceType.TABLE, id_
db.session, utils.DatasourceType.TABLE, id_
)
datasource_ids.add((datasource.id, datasource.type))

View File

@ -170,17 +170,19 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
# add `owner:` tags
cls._add_owners(session, target)
try:
# add `owner:` tags
cls._add_owners(session, target)
# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
# add `type:` tags
tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
session.add(tagged_object)
session.commit()
finally:
session.close()
@classmethod
def after_update(
@ -191,25 +193,27 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
try:
# delete current `owner:` tags
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
Tag.type == TagTypes.owner,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
# add `owner:` tags
cls._add_owners(session, target)
session.commit()
finally:
session.close()
@classmethod
def after_delete(
@ -220,13 +224,16 @@ class ObjectUpdater:
) -> None:
session = Session(bind=connection)
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
try:
# delete row from `tagged_objects`
session.query(TaggedObject).filter(
TaggedObject.object_type == cls.object_type,
TaggedObject.object_id == target.id,
).delete()
session.commit()
session.commit()
finally:
session.close()
class ChartUpdater(ObjectUpdater):
@ -267,35 +274,40 @@ class FavStarUpdater:
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
try:
name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
object_id=target.obj_id,
object_type=get_object_type(target.class_name),
)
session.add(tagged_object)
session.commit()
finally:
session.close()
@classmethod
def after_delete(
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
try:
name = f"favorited_by:{target.user_id}"
query = (
session.query(TaggedObject.id)
.join(Tag)
.filter(
TaggedObject.object_id == target.obj_id,
Tag.type == TagTypes.favorited_by,
Tag.name == name,
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
)
ids = [row[0] for row in query]
session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete(
synchronize_session=False
)
session.commit()
session.commit()
finally:
session.close()

View File

@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
def get_payloads(self) -> list[dict[str, int]]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
try:
charts = session.query(Slice).all()
finally:
session.close()
return [get_payload(chart) for chart in charts]
@ -129,20 +133,24 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
payloads = []
session = db.create_scoped_session()
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
try:
records = (
session.query(Log.dashboard_id, func.count(Log.dashboard_id))
.filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since))
.group_by(Log.dashboard_id)
.order_by(func.count(Log.dashboard_id).desc())
.limit(self.top_n)
.all()
)
dash_ids = [record.dashboard_id for record in records]
dashboards = (
session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
)
for dashboard in dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart, dashboard))
finally:
session.close()
return payloads
@ -172,42 +180,46 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
payloads = []
session = db.create_scoped_session()
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
try:
tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all()
tag_ids = [tag.id for tag in tags]
# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
# add dashboards that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "dashboard",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
.all()
)
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(
Dashboard.id.in_(dash_ids)
)
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
payloads.append(get_payload(chart))
# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
payloads.append(get_payload(chart))
finally:
session.close()
return payloads