diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index f17e164822..15dd600c12 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -74,28 +74,31 @@ def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) - session_class = sessionmaker(autoflush=False) session = session_class(bind=connection) - new_user = session.query(User).filter_by(id=target.id).first() - # copy template dashboard to user - template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() - dashboard = Dashboard( - dashboard_title=template.dashboard_title, - position_json=template.position_json, - description=template.description, - css=template.css, - json_metadata=template.json_metadata, - slices=template.slices, - owners=[new_user], - ) - session.add(dashboard) - session.commit() + try: + new_user = session.query(User).filter_by(id=target.id).first() - # set dashboard as the welcome dashboard - extra_attributes = UserAttribute( - user_id=target.id, welcome_dashboard_id=dashboard.id - ) - session.add(extra_attributes) - session.commit() + # copy template dashboard to user + template = session.query(Dashboard).filter_by(id=int(dashboard_id)).first() + dashboard = Dashboard( + dashboard_title=template.dashboard_title, + position_json=template.position_json, + description=template.description, + css=template.css, + json_metadata=template.json_metadata, + slices=template.slices, + owners=[new_user], + ) + session.add(dashboard) + + # set dashboard as the welcome dashboard + extra_attributes = UserAttribute( + user_id=target.id, welcome_dashboard_id=dashboard.id + ) + session.add(extra_attributes) + session.commit() + finally: + session.close() sqla.event.listen(User, "after_insert", copy_dashboard) @@ -414,13 +417,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): "native_filter_configuration", [] ) for native_filter in native_filter_configuration: - session = db.session() for target in native_filter.get("targets", []): id_ = target.get("datasetId") if id_ is None: continue datasource = DatasourceDAO.get_datasource( - session, utils.DatasourceType.TABLE, id_ + db.session, utils.DatasourceType.TABLE, id_ ) datasource_ids.add((datasource.id, datasource.type)) diff --git a/superset/tags/models.py b/superset/tags/models.py index 7e350061fa..7825f283bf 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -170,17 +170,19 @@ class ObjectUpdater: ) -> None: session = Session(bind=connection) - # add `owner:` tags - cls._add_owners(session, target) + try: + # add `owner:` tags + cls._add_owners(session, target) - # add `type:` tags - tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type) - tagged_object = TaggedObject( - tag_id=tag.id, object_id=target.id, object_type=cls.object_type - ) - session.add(tagged_object) - - session.commit() + # add `type:` tags + tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type) + tagged_object = TaggedObject( + tag_id=tag.id, object_id=target.id, object_type=cls.object_type + ) + session.add(tagged_object) + session.commit() + finally: + session.close() @classmethod def after_update( @@ -191,25 +193,27 @@ class ObjectUpdater: ) -> None: session = Session(bind=connection) - # delete current `owner:` tags - query = ( - session.query(TaggedObject.id) - .join(Tag) - .filter( - TaggedObject.object_type == cls.object_type, - TaggedObject.object_id == target.id, - Tag.type == TagTypes.owner, + try: + # delete current `owner:` tags + query = ( + session.query(TaggedObject.id) + .join(Tag) + .filter( + TaggedObject.object_type == cls.object_type, + TaggedObject.object_id == target.id, + Tag.type == TagTypes.owner, + ) + ) + ids = [row[0] for row in query] + session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( + synchronize_session=False ) - ) - ids = [row[0] for row in query] - session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( - synchronize_session=False - ) - # add `owner:` tags - cls._add_owners(session, target) - - session.commit() + # add `owner:` tags + cls._add_owners(session, target) + session.commit() + finally: + session.close() @classmethod def after_delete( @@ -220,13 +224,16 @@ class ObjectUpdater: ) -> None: session = Session(bind=connection) - # delete row from `tagged_objects` - session.query(TaggedObject).filter( - TaggedObject.object_type == cls.object_type, - TaggedObject.object_id == target.id, - ).delete() + try: + # delete row from `tagged_objects` + session.query(TaggedObject).filter( + TaggedObject.object_type == cls.object_type, + TaggedObject.object_id == target.id, + ).delete() - session.commit() + session.commit() + finally: + session.close() class ChartUpdater(ObjectUpdater): @@ -267,35 +274,40 @@ class FavStarUpdater: cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) - name = f"favorited_by:{target.user_id}" - tag = get_tag(name, session, TagTypes.favorited_by) - tagged_object = TaggedObject( - tag_id=tag.id, - object_id=target.obj_id, - object_type=get_object_type(target.class_name), - ) - session.add(tagged_object) - - session.commit() + try: + name = f"favorited_by:{target.user_id}" + tag = get_tag(name, session, TagTypes.favorited_by) + tagged_object = TaggedObject( + tag_id=tag.id, + object_id=target.obj_id, + object_type=get_object_type(target.class_name), + ) + session.add(tagged_object) + session.commit() + finally: + session.close() @classmethod def after_delete( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) - name = f"favorited_by:{target.user_id}" - query = ( - session.query(TaggedObject.id) - .join(Tag) - .filter( - TaggedObject.object_id == target.obj_id, - Tag.type == TagTypes.favorited_by, - Tag.name == name, + try: + name = f"favorited_by:{target.user_id}" + query = ( + session.query(TaggedObject.id) + .join(Tag) + .filter( + TaggedObject.object_id == target.obj_id, + Tag.type == TagTypes.favorited_by, + Tag.name == name, + ) + ) + ids = [row[0] for row in query] + session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( + synchronize_session=False ) - ) - ids = [row[0] for row in query] - session.query(TaggedObject).filter(TaggedObject.id.in_(ids)).delete( - synchronize_session=False - ) - session.commit() + session.commit() + finally: + session.close() diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index 534b00f94d..569797ba27 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -95,7 +95,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods def get_payloads(self) -> list[dict[str, int]]: session = db.create_scoped_session() - charts = session.query(Slice).all() + + try: + charts = session.query(Slice).all() + finally: + session.close() return [get_payload(chart) for chart in charts] @@ -129,20 +133,24 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method payloads = [] session = db.create_scoped_session() - records = ( - session.query(Log.dashboard_id, func.count(Log.dashboard_id)) - .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) - .group_by(Log.dashboard_id) - .order_by(func.count(Log.dashboard_id).desc()) - .limit(self.top_n) - .all() - ) - dash_ids = [record.dashboard_id for record in records] - dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() - for dashboard in dashboards: - for chart in dashboard.slices: - payloads.append(get_payload(chart, dashboard)) - + try: + records = ( + session.query(Log.dashboard_id, func.count(Log.dashboard_id)) + .filter(and_(Log.dashboard_id.isnot(None), Log.dttm >= self.since)) + .group_by(Log.dashboard_id) + .order_by(func.count(Log.dashboard_id).desc()) + .limit(self.top_n) + .all() + ) + dash_ids = [record.dashboard_id for record in records] + dashboards = ( + session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all() + ) + for dashboard in dashboards: + for chart in dashboard.slices: + payloads.append(get_payload(chart, dashboard)) + finally: + session.close() return payloads @@ -172,42 +180,46 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods payloads = [] session = db.create_scoped_session() - tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() - tag_ids = [tag.id for tag in tags] + try: + tags = session.query(Tag).filter(Tag.name.in_(self.tags)).all() + tag_ids = [tag.id for tag in tags] - # add dashboards that are tagged - tagged_objects = ( - session.query(TaggedObject) - .filter( - and_( - TaggedObject.object_type == "dashboard", - TaggedObject.tag_id.in_(tag_ids), + # add dashboards that are tagged + tagged_objects = ( + session.query(TaggedObject) + .filter( + and_( + TaggedObject.object_type == "dashboard", + TaggedObject.tag_id.in_(tag_ids), + ) ) + .all() ) - .all() - ) - dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)) - for dashboard in tagged_dashboards: - for chart in dashboard.slices: + dash_ids = [tagged_object.object_id for tagged_object in tagged_objects] + tagged_dashboards = session.query(Dashboard).filter( + Dashboard.id.in_(dash_ids) + ) + for dashboard in tagged_dashboards: + for chart in dashboard.slices: + payloads.append(get_payload(chart)) + + # add charts that are tagged + tagged_objects = ( + session.query(TaggedObject) + .filter( + and_( + TaggedObject.object_type == "chart", + TaggedObject.tag_id.in_(tag_ids), + ) + ) + .all() + ) + chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] + tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) + for chart in tagged_charts: payloads.append(get_payload(chart)) - - # add charts that are tagged - tagged_objects = ( - session.query(TaggedObject) - .filter( - and_( - TaggedObject.object_type == "chart", - TaggedObject.tag_id.in_(tag_ids), - ) - ) - .all() - ) - chart_ids = [tagged_object.object_id for tagged_object in tagged_objects] - tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids)) - for chart in tagged_charts: - payloads.append(get_payload(chart)) - + finally: + session.close() return payloads