chore(sqlalchemy): Remove erroneous SQLAlchemy ORM session.merge operations (#24776)

This commit is contained in:
John Bodley 2023-11-20 17:25:41 -08:00 committed by GitHub
parent e7797b65d1
commit dd58b31cc4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 34 additions and 82 deletions

View File

@ -60,9 +60,9 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "BART lines"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -80,13 +80,13 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "dttm"
obj.database = database
obj.filter_select_enabled = True
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -27,6 +27,7 @@ def load_css_templates() -> None:
obj = db.session.query(CssTemplate).filter_by(template_name="Flat").first()
if not obj:
obj = CssTemplate(template_name="Flat")
db.session.add(obj)
css = textwrap.dedent(
"""\
.navbar {
@ -51,12 +52,12 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()
obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj:
obj = CssTemplate(template_name="Courier Black")
db.session.add(obj)
css = textwrap.dedent(
"""\
h2 {
@ -96,5 +97,4 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()

View File

@ -532,6 +532,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
js = POSITION_JSON
pos = json.loads(js)
@ -540,5 +541,4 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
dash.dashboard_title = title
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()

View File

@ -66,6 +66,7 @@ def load_energy(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Energy consumption"
tbl.database = database
tbl.filter_select_enabled = True
@ -76,7 +77,6 @@ def load_energy(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -63,10 +63,10 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
print("Done loading table!")

View File

@ -92,10 +92,10 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "datetime"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -34,6 +34,7 @@ def load_misc_dashboard() -> None:
if not dash:
dash = Dashboard()
db.session.add(dash)
js = textwrap.dedent(
"""\
{
@ -215,5 +216,4 @@ def load_misc_dashboard() -> None:
dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG
dash.slices = slices
db.session.merge(dash)
db.session.commit()

View File

@ -82,6 +82,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
@ -100,7 +101,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
col.python_date_format = dttm_and_expr[0]
col.database_expression = dttm_and_expr[1]
col.is_dttm = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -57,9 +57,9 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Map of Paris"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -67,10 +67,10 @@ def load_random_time_series_data(
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj

View File

@ -59,9 +59,9 @@ def load_sf_population_polygons(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Population density of San Francisco"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

View File

@ -33,6 +33,7 @@ def load_tabbed_dashboard(_: bool = False) -> None:
if not dash:
dash = Dashboard()
db.session.add(dash)
js = textwrap.dedent(
"""
@ -556,6 +557,4 @@ def load_tabbed_dashboard(_: bool = False) -> None:
dash.slices = slices
dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug
db.session.merge(dash)
db.session.commit()

View File

@ -87,6 +87,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = utils.readfile(
os.path.join(get_examples_folder(), "countries.md")
)
@ -110,7 +111,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
SqlMetric(metric_name=metric, expression=f"{aggr_func}({col})")
)
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
@ -126,6 +126,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
pos = dashboard_positions
slices = update_slice_ids(pos)
@ -134,7 +135,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
dash.position_json = json.dumps(pos, indent=4)
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()

View File

@ -84,7 +84,6 @@ class UpdateKeyValueCommand(BaseCommand):
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
db.session.commit()
return Key(id=entry.id, uuid=entry.uuid)

View File

@ -88,7 +88,6 @@ class UpsertKeyValueCommand(BaseCommand):
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
db.session.commit()
return Key(entry.id, entry.uuid)

View File

@ -123,7 +123,7 @@ class MigrateViz:
]
@classmethod
def upgrade_slice(cls, slc: Slice) -> Slice:
def upgrade_slice(cls, slc: Slice) -> None:
clz = cls(slc.params)
form_data_bak = copy.deepcopy(clz.data)
@ -141,10 +141,9 @@ class MigrateViz:
if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc
@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
def downgrade_slice(cls, slc: Slice) -> None:
form_data = try_load_json(slc.params)
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
@ -153,7 +152,6 @@ class MigrateViz:
if "form_data" in query_context:
query_context["form_data"] = form_data_bak
slc.query_context = json.dumps(query_context)
return slc
@classmethod
def upgrade(cls, session: Session) -> None:
@ -162,8 +160,7 @@ class MigrateViz:
slices,
lambda current, total: print(f"Upgraded {current}/{total} charts"),
):
new_viz = cls.upgrade_slice(slc)
session.merge(new_viz)
cls.upgrade_slice(slc)
@classmethod
def downgrade(cls, session: Session) -> None:
@ -177,5 +174,4 @@ class MigrateViz:
slices,
lambda current, total: print(f"Downgraded {current}/{total} charts"),
):
new_viz = cls.downgrade_slice(slc)
session.merge(new_viz)
cls.downgrade_slice(slc)

View File

@ -243,7 +243,6 @@ def migrate_roles(
if new_pvm not in role.permissions:
logger.info(f"Add {new_pvm} to {role}")
role.permissions.append(new_pvm)
session.merge(role)
# Delete old permissions
_delete_old_permissions(session, pvm_map)

View File

@ -56,7 +56,6 @@ def upgrade():
for slc in session.query(Slice).all():
if slc.datasource:
slc.perm = slc.datasource.perm
session.merge(slc)
session.commit()
db.session.close()

View File

@ -56,7 +56,6 @@ def upgrade():
slc.datasource_id = slc.druid_datasource_id
if slc.table_id:
slc.datasource_id = slc.table_id
session.merge(slc)
session.commit()
session.close()
@ -69,7 +68,6 @@ def downgrade():
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == "table":
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column("slices", "datasource_id")

View File

@ -57,7 +57,6 @@ def upgrade():
try:
d = json.loads(slc.params or "{}")
slc.params = json.dumps(d, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:

View File

@ -80,7 +80,6 @@ def upgrade():
"/".join(split[:-1]) + "/?form_data=" + parse.quote_plus(json.dumps(d))
)
url.url = newurl
session.merge(url)
session.commit()
print(f"Updating url ({i}/{urls_len})")
session.close()

View File

@ -58,7 +58,6 @@ def upgrade():
del params["latitude"]
del params["longitude"]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

View File

@ -69,7 +69,6 @@ def upgrade():
)
params["annotation_layers"] = new_layers
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()
@ -86,6 +85,5 @@ def downgrade():
if layers:
params["annotation_layers"] = [layer["value"] for layer in layers]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

View File

@ -62,7 +62,6 @@ def upgrade():
pos["v"] = 1
dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()
session.close()
@ -85,6 +84,5 @@ def downgrade():
pos["v"] = 0
dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()
pass

View File

@ -59,7 +59,6 @@ def upgrade():
params["metrics"] = [params.get("metric")]
del params["metric"]
slc.params = json.dumps(params, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:

View File

@ -647,7 +647,6 @@ def upgrade():
sorted_by_key = collections.OrderedDict(sorted(v2_layout.items()))
dashboard.position_json = json.dumps(sorted_by_key, indent=2)
session.merge(dashboard)
session.commit()
else:
print(f"Skip converted dash_id: {dashboard.id}")

View File

@ -76,7 +76,6 @@ def upgrade():
dashboard.id, len(original_text), len(text)
)
)
session.merge(dashboard)
session.commit()

View File

@ -80,7 +80,6 @@ def upgrade():
dashboard.position_json = json.dumps(
layout, indent=None, separators=(",", ":"), sort_keys=True
)
session.merge(dashboard)
except Exception as ex:
logging.exception(ex)
@ -110,7 +109,6 @@ def downgrade():
dashboard.position_json = json.dumps(
layout, indent=None, separators=(",", ":"), sort_keys=True
)
session.merge(dashboard)
except Exception as ex:
logging.exception(ex)

View File

@ -99,8 +99,6 @@ def upgrade():
)
else:
dashboard.json_metadata = None
session.merge(dashboard)
except Exception as ex:
logging.exception(f"dashboard {dashboard.id} has error: {ex}")

View File

@ -163,7 +163,6 @@ def upgrade():
separators=(",", ":"),
sort_keys=True,
)
session.merge(dashboard)
# remove iframe, separator and markup charts
slices_to_remove = (

View File

@ -96,7 +96,6 @@ def update_position_json(dashboard, session, uuid_map):
del object_["meta"]["uuid"]
dashboard.position_json = json.dumps(layout, indent=4)
session.merge(dashboard)
def update_dashboards(session, uuid_map):

View File

@ -70,7 +70,6 @@ def upgrade():
slc.params = json.dumps(params)
slc.viz_type = "graph_chart"
session.merge(slc)
session.commit()
session.close()
@ -100,6 +99,5 @@ def downgrade():
slc.params = json.dumps(params)
slc.viz_type = "directed_force"
session.merge(slc)
session.commit()
session.close()

View File

@ -72,7 +72,6 @@ def upgrade():
new_conditional_formatting.append(formatter)
params["conditional_formatting"] = new_conditional_formatting
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

View File

@ -123,8 +123,6 @@ class BaseReportState:
self._report_schedule.last_state = state
self._report_schedule.last_eval_dttm = datetime.utcnow()
self._session.merge(self._report_schedule)
self._session.commit()
def create_log(self, error_message: Optional[str] = None) -> None:

View File

@ -876,7 +876,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
):
role_from_permissions.append(permission_view)
role_to.permissions = role_from_permissions
self.get_session.merge(role_to)
self.get_session.commit()
def set_role(
@ -898,7 +897,6 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
permission_view for permission_view in pvms if pvm_check(permission_view)
]
role.permissions = role_pvms
self.get_session.merge(role)
self.get_session.commit()
def _is_admin_only(self, pvm: PermissionView) -> bool:

View File

@ -1293,7 +1293,6 @@ def test_chart_cache_timeout(
slice_with_cache_timeout = load_energy_table_with_slice[0]
slice_with_cache_timeout.cache_timeout = 20
db.session.merge(slice_with_cache_timeout)
datasource: SqlaTable = (
db.session.query(SqlaTable)
@ -1301,7 +1300,6 @@ def test_chart_cache_timeout(
.first()
)
datasource.cache_timeout = 1254
db.session.merge(datasource)
db.session.commit()
@ -1331,7 +1329,6 @@ def test_chart_cache_timeout_not_present(
.first()
)
datasource.cache_timeout = 1980
db.session.merge(datasource)
db.session.commit()
rv = test_client.post(CHART_DATA_URI, json=physical_query_context)

View File

@ -326,7 +326,8 @@ def virtual_dataset():
TableColumn(column_name="col5", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
db.session.add(dataset)
db.session.commit()
yield dataset
@ -390,7 +391,7 @@ def physical_dataset():
table=dataset,
)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
db.session.add(dataset)
db.session.commit()
yield dataset
@ -425,7 +426,8 @@ def virtual_dataset_comma_in_column_value():
TableColumn(column_name="col2", type="VARCHAR(255)", table=dataset)
SqlMetric(metric_name="count", expression="count(*)", table=dataset)
db.session.merge(dataset)
db.session.add(dataset)
db.session.commit()
yield dataset

View File

@ -78,8 +78,8 @@ class TestDashboard(SupersetTestCase):
hidden_dash.slices = [slice]
hidden_dash.published = False
db.session.merge(published_dash)
db.session.merge(hidden_dash)
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
self.revoke_public_access_to_table(table)
@ -137,8 +137,6 @@ class TestDashboard(SupersetTestCase):
# Make the births dash published so it can be seen
births_dash = db.session.query(Dashboard).filter_by(slug="births").one()
births_dash.published = True
db.session.merge(births_dash)
db.session.commit()
# Try access before adding appropriate permissions.
@ -180,7 +178,6 @@ class TestDashboard(SupersetTestCase):
dash = db.session.query(Dashboard).filter_by(slug="births").first()
dash.owners = [security_manager.find_user("admin")]
dash.created_by = security_manager.find_user("admin")
db.session.merge(dash)
db.session.commit()
res: Response = self.client.get("/superset/dashboard/births/")

View File

@ -59,11 +59,11 @@ def create_table_metadata(
normalize_columns=False,
always_filter_main_dttm=False,
)
db.session.add(table)
if fetch_values_predicate:
table.fetch_values_predicate = fetch_values_predicate
table.database = database
table.description = table_description
db.session.merge(table)
db.session.commit()
return table

View File

@ -113,7 +113,6 @@ class TestDashboardDAO(SupersetTestCase):
data.update({"foo": "bar"})
DashboardDAO.set_dash_metadata(dashboard, data)
db.session.merge(dashboard)
db.session.commit()
new_changed_on = DashboardDAO.get_dashboard_changed_on(dashboard)
assert old_changed_on.replace(microsecond=0) < new_changed_on
@ -125,7 +124,6 @@ class TestDashboardDAO(SupersetTestCase):
)
DashboardDAO.set_dash_metadata(dashboard, original_data)
db.session.merge(dashboard)
db.session.commit()
@pytest.mark.usefixtures("load_world_bank_dashboard_with_slices")

View File

@ -110,12 +110,10 @@ def random_str():
def grant_access_to_dashboard(dashboard, role_name):
role = security_manager.find_role(role_name)
dashboard.roles.append(role)
db.session.merge(dashboard)
db.session.commit()
def revoke_access_to_dashboard(dashboard, role_name):
role = security_manager.find_role(role_name)
dashboard.roles.remove(role)
db.session.merge(dashboard)
db.session.commit()

View File

@ -61,8 +61,8 @@ class TestDashboardDatasetSecurity(DashboardTestCase):
hidden_dash.slices = [slice]
hidden_dash.published = False
db.session.merge(published_dash)
db.session.merge(hidden_dash)
db.session.add(published_dash)
db.session.add(hidden_dash)
yield db.session.commit()
self.revoke_public_access_to_table(table)

View File

@ -550,7 +550,6 @@ def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_data
table=virtual_dataset,
expression="INCORRECT SQL",
)
db.session.merge(virtual_dataset)
uri = (
f"/datasource/samples?datasource_id={virtual_dataset.id}&datasource_type=table"

View File

@ -82,8 +82,6 @@ def _create_energy_table() -> list[Slice]:
table.metrics.append(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)
db.session.merge(table)
db.session.commit()
table.fetch_metadata()
slices = []

View File

@ -68,7 +68,7 @@ def test_treemap_migrate(app_context: SupersetApp) -> None:
query_context=f'{{"form_data": {treemap_form_data}}}',
)
slc = MigrateTreeMap.upgrade_slice(slc)
MigrateTreeMap.upgrade_slice(slc)
assert slc.viz_type == MigrateTreeMap.target_viz_type
# verify form_data
new_form_data = json.loads(slc.params)
@ -84,7 +84,7 @@ def test_treemap_migrate(app_context: SupersetApp) -> None:
assert new_query_context["form_data"]["viz_type"] == "treemap_v2"
# downgrade
slc = MigrateTreeMap.downgrade_slice(slc)
MigrateTreeMap.downgrade_slice(slc)
assert slc.viz_type == MigrateTreeMap.source_viz_type
assert json.dumps(json.loads(slc.params), sort_keys=True) == json.dumps(
json.loads(treemap_form_data), sort_keys=True

View File

@ -1919,7 +1919,6 @@ def test_grace_period_error_flap(
# Change report_schedule to valid
create_invalid_sql_alert_email_chart.sql = "SELECT 1 AS metric"
create_invalid_sql_alert_email_chart.grace_period = 0
db.session.merge(create_invalid_sql_alert_email_chart)
db.session.commit()
with freeze_time("2020-01-01T00:31:00Z"):
@ -1936,7 +1935,6 @@ def test_grace_period_error_flap(
create_invalid_sql_alert_email_chart.sql = "SELECT 'first'"
create_invalid_sql_alert_email_chart.grace_period = 10
db.session.merge(create_invalid_sql_alert_email_chart)
db.session.commit()
# assert that after a success, when back to error we send the error notification

View File

@ -62,7 +62,6 @@ def create_old_role(pvm_map: PvmMigrationMapType, external_pvms):
db.session.query(Role).filter(Role.name == "Dummy Role").one_or_none()
)
new_role.permissions = []
db.session.merge(new_role)
for old_pvm, new_pvms in pvm_map.items():
security_manager.del_permission_view_menu(old_pvm.permission, old_pvm.view)
for new_pvm in new_pvms:

View File

@ -79,7 +79,7 @@ def migrate_and_assert(
)
# upgrade
slc = cls.upgrade_slice(slc)
cls.upgrade_slice(slc)
# verify form_data
new_form_data = json.loads(slc.params)
@ -91,6 +91,6 @@ def migrate_and_assert(
assert new_query_context["form_data"]["viz_type"] == cls.target_viz_type
# downgrade
slc = cls.downgrade_slice(slc)
cls.downgrade_slice(slc)
assert slc.viz_type == cls.source_viz_type
assert json.loads(slc.params) == source