From d583ca9ef57d1c49ae84cda8cc888ee01dcf5601 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 19 May 2023 00:37:13 -0700 Subject: [PATCH] chore: Embrace the walrus operator (#24127) --- .pre-commit-config.yaml | 4 ++ superset/charts/api.py | 6 +-- superset/charts/commands/bulk_delete.py | 3 +- superset/charts/commands/delete.py | 3 +- superset/common/query_actions.py | 3 +- superset/common/query_context_factory.py | 3 +- superset/common/query_context_processor.py | 3 +- superset/common/utils/query_cache_manager.py | 3 +- superset/connectors/sqla/models.py | 3 +- superset/dashboards/commands/bulk_delete.py | 3 +- superset/dashboards/commands/delete.py | 3 +- superset/dashboards/dao.py | 3 +- superset/databases/api.py | 3 +- superset/databases/commands/delete.py | 3 +- .../databases/commands/test_connection.py | 3 +- superset/databases/commands/validate.py | 3 +- superset/datasets/api.py | 3 +- .../datasets/commands/importers/v1/utils.py | 3 +- superset/datasets/commands/update.py | 6 +-- superset/db_engine_specs/base.py | 6 +-- superset/db_engine_specs/bigquery.py | 3 +- superset/db_engine_specs/databricks.py | 3 +- superset/db_engine_specs/presto.py | 6 +-- superset/db_engine_specs/snowflake.py | 3 +- superset/db_engine_specs/trino.py | 6 +-- superset/errors.py | 3 +- superset/initialization/__init__.py | 3 +- superset/migrations/env.py | 5 +-- .../migrations/shared/migrate_viz/base.py | 6 +-- .../shared/migrate_viz/processors.py | 6 +-- ...f3fed1fe_convert_dashboard_v1_positions.py | 3 +- ...95_migrate_native_filters_to_new_schema.py | 6 +-- ...-25_31b2a1039d4a_drop_tables_constraint.py | 3 +- superset/models/core.py | 3 +- superset/models/helpers.py | 3 +- superset/models/slice.py | 3 +- superset/queries/saved_queries/api.py | 3 +- superset/reports/commands/execute.py | 3 +- superset/result_set.py | 6 +-- superset/security/manager.py | 12 ++--- superset/sql_lab.py | 3 +- superset/sql_validators/presto_db.py | 3 +- superset/utils/core.py | 3 +- superset/utils/screenshots.py | 3 +- superset/views/api.py | 3 +- superset/views/base_api.py | 6 +-- superset/views/core.py | 45 +++++++------------ superset/views/database/validators.py | 3 +- superset/views/log/api.py | 3 +- superset/views/utils.py | 7 ++- superset/viz.py | 26 +++++------ .../data_loading/pandas/pandas_data_loader.py | 3 +- tests/integration_tests/csv_upload_tests.py | 9 ++-- tests/integration_tests/datasets/api_tests.py | 9 ++-- 54 files changed, 100 insertions(+), 185 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2ff79da9e9..3f524b3658 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,6 +19,10 @@ repos: rev: 5.12.0 hooks: - id: isort + - repo: https://github.com/MarcoGorelli/auto-walrus + rev: v0.2.2 + hooks: + - id: auto-walrus - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.3.0 hooks: diff --git a/superset/charts/api.py b/superset/charts/api.py index 6a4bf04aa1..2c50a8d163 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -649,8 +649,7 @@ class ChartRestApi(BaseSupersetModelRestApi): return self.response_404() # fetch the chart screenshot using the current user and cache if set - img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest) - if img: + if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest): return Response( FileWrapper(img), mimetype="image/png", direct_passthrough=True ) @@ -783,7 +782,6 @@ class ChartRestApi(BaseSupersetModelRestApi): 500: $ref: 
'#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"chart_export_{timestamp}" @@ -805,7 +803,7 @@ class ChartRestApi(BaseSupersetModelRestApi): as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index caf8fe0399..c252f0be4c 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -55,8 +55,7 @@ class BulkDeleteChartCommand(BaseCommand): if not self._models or len(self._models) != len(self._model_ids): raise ChartNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids) - if reports: + if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids): report_names = [report.name for report in reports] raise ChartBulkDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py index 4c636f0433..11f6e59257 100644 --- a/superset/charts/commands/delete.py +++ b/superset/charts/commands/delete.py @@ -64,8 +64,7 @@ class DeleteChartCommand(BaseCommand): if not self._model: raise ChartNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_chart_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_chart_id(self._model_id): report_names = [report.name for report in reports] raise ChartDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 38526475b9..f6f5a5cd62 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -221,8 +221,7 @@ def get_query_results( :raises QueryObjectValidationError: if an unsupported result type is requested :return: JSON serializable result payload """ - result_func = _result_type_functions.get(result_type) - if result_func: + if result_func := _result_type_functions.get(result_type): return result_func(query_context, query_obj, force_cached) raise QueryObjectValidationError( _("Invalid result type: %(result_type)s", result_type=result_type) diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index a42d1d4ba7..84c0415722 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -125,10 +125,9 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods for column in datasource.columns if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm) } - granularity = query_object.granularity x_axis = form_data and form_data.get("x_axis") - if granularity: + if granularity := query_object.granularity: filter_to_remove = None if x_axis and x_axis in temporal_columns: filter_to_remove = x_axis diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 56f07dcb64..85a2b5d97a 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -500,8 +500,7 @@ class QueryContextProcessor: return return_value def get_cache_timeout(self) -> int: - cache_timeout_rv = self._query_context.get_cache_timeout() - 
if cache_timeout_rv: + if cache_timeout_rv := self._query_context.get_cache_timeout(): return cache_timeout_rv if ( data_cache_timeout := config["DATA_CACHE_CONFIG"].get( diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 7143fcc201..6c1b268f46 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -148,8 +148,7 @@ class QueryCacheManager: if not key or not _cache[region] or force_query: return query_cache - cache_value = _cache[region].get(key) - if cache_value: + if cache_value := _cache[region].get(key): logger.debug("Cache key: %s", key) stats_logger.incr("loading_from_cache") try: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index e339f7b1f4..5f487f60f6 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -993,11 +993,10 @@ class SqlaTable( schema=self.schema, template_processor=template_processor, ) - col_in_metadata = self.get_column(expression) time_grain = col.get("timeGrain") has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain is_dttm = False - if col_in_metadata: + if col_in_metadata := self.get_column(expression): sqla_column = col_in_metadata.get_sqla_col( template_processor=template_processor ) diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py index 52f5998438..13541cd946 100644 --- a/superset/dashboards/commands/bulk_delete.py +++ b/superset/dashboards/commands/bulk_delete.py @@ -56,8 +56,7 @@ class BulkDeleteDashboardCommand(BaseCommand): if not self._models or len(self._models) != len(self._model_ids): raise DashboardNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids) - if reports: + if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids): report_names = [report.name for report in reports] raise DashboardBulkDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/dashboards/commands/delete.py b/superset/dashboards/commands/delete.py index 7af2fdf4ce..8ce7cb0cbf 100644 --- a/superset/dashboards/commands/delete.py +++ b/superset/dashboards/commands/delete.py @@ -57,8 +57,7 @@ class DeleteDashboardCommand(BaseCommand): if not self._model: raise DashboardNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id): report_names = [report.name for report in reports] raise DashboardDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index 89fca4619a..5355d602be 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -200,11 +200,10 @@ class DashboardDAO(BaseDAO): old_to_new_slice_ids: Optional[Dict[int, int]] = None, commit: bool = False, ) -> Dashboard: - positions = data.get("positions") new_filter_scopes = {} md = dashboard.params_dict - if positions is not None: + if (positions := data.get("positions")) is not None: # find slices in the position data slice_ids = [ value.get("meta", {}).get("chartId") diff --git a/superset/databases/api.py b/superset/databases/api.py index 4997edc073..8e444a84d8 100644 --- a/superset/databases/api.py +++ 
b/superset/databases/api.py @@ -1036,7 +1036,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi): 500: $ref: '#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"database_export_{timestamp}" @@ -1060,7 +1059,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi): as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/databases/commands/delete.py b/superset/databases/commands/delete.py index ebdd543570..825b126218 100644 --- a/superset/databases/commands/delete.py +++ b/superset/databases/commands/delete.py @@ -55,9 +55,8 @@ class DeleteDatabaseCommand(BaseCommand): if not self._model: raise DatabaseNotFoundError() # Check there are no associated ReportSchedules - reports = ReportScheduleDAO.find_by_database_id(self._model_id) - if reports: + if reports := ReportScheduleDAO.find_by_database_id(self._model_id): report_names = [report.name for report in reports] raise DatabaseDeleteFailedReportsExistError( _("There are associated alerts or reports: %s" % ",".join(report_names)) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index cbc1240905..9809641d5c 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -228,6 +228,5 @@ class TestConnectionDatabaseCommand(BaseCommand): raise DatabaseTestConnectionUnexpectedError(errors) from ex def validate(self) -> None: - database_name = self._properties.get("database_name") - if database_name is not None: + if (database_name := self._properties.get("database_name")) is not None: self._model = DatabaseDAO.get_database_by_name(database_name) diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 8c58ef5de0..2a624e32c7 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -128,6 +128,5 @@ class ValidateDatabaseParametersCommand(BaseCommand): ) def validate(self) -> None: - database_id = self._properties.get("id") - if database_id is not None: + if (database_id := self._properties.get("id")) is not None: self._model = DatabaseDAO.find_by_id(database_id) diff --git a/superset/datasets/api.py b/superset/datasets/api.py index d52e622793..6568ba3793 100644 --- a/superset/datasets/api.py +++ b/superset/datasets/api.py @@ -977,8 +977,7 @@ class DatasetRestApi(BaseSupersetModelRestApi): return self.response(400, message=ex.messages) table_name = body["table_name"] database_id = body["database_id"] - table = DatasetDAO.get_table_by_name(database_id, table_name) - if table: + if table := DatasetDAO.get_table_by_name(database_id, table_name): return self.response(200, result={"table_id": table.id}) body["database"] = database_id diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 2df85cdfa2..52f46829b5 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType: if native_type.upper() in type_map: return type_map[native_type.upper()] - match = VARCHAR.match(native_type) - if match: + if match := VARCHAR.match(native_type): size = int(match.group(1)) return String(size) diff --git a/superset/datasets/commands/update.py 
b/superset/datasets/commands/update.py index a2e483ba93..b6bf1256d1 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -114,13 +114,11 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): exceptions.append(DatasetEndpointUnsafeValidationError()) # Validate columns - columns = self._properties.get("columns") - if columns: + if columns := self._properties.get("columns"): self._validate_columns(columns, exceptions) # Validate metrics - metrics = self._properties.get("metrics") - if metrics: + if metrics := self._properties.get("metrics"): self._validate_metrics(metrics, exceptions) if exceptions: diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 98fb60c275..221872f544 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1704,8 +1704,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param source: Type coming from the database table or cursor description :return: ColumnSpec object """ - col_types = cls.get_column_types(native_type) - if col_types: + if col_types := cls.get_column_types(native_type): column_type, generic_type = col_types is_dttm = generic_type == GenericDataType.TEMPORAL return ColumnSpec( @@ -1996,9 +1995,8 @@ class BasicParametersMixin: required = {"host", "port", "username", "database"} parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 1f2ee51068..1f5068ad04 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -384,9 +384,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met } # Add credentials if they are set on the SQLAlchemy dialect. 
- creds = engine.dialect.credentials_info - if creds: + if creds := engine.dialect.credentials_info: to_gbq_kwargs[ "credentials" ] = service_account.Credentials.from_service_account_info(creds) diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index f39e43aa60..5f12f3174d 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -285,9 +285,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) parameters["http_path"] = connect_args.get("http_path") present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index e047923f92..42f6ed9af6 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -1213,8 +1213,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): ) -> Dict[str, Any]: metadata = {} - indexes = database.get_indexes(table_name, schema_name) - if indexes: + if indexes := database.get_indexes(table_name, schema_name): col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) @@ -1278,8 +1277,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None: """Updates progress information""" - tracking_url = cls.get_tracking_url(cursor) - if tracking_url: + if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url session.commit() diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index c7049ae71d..69ccf55931 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -312,9 +312,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): } parameters = properties.get("parameters", {}) present = {key for key in parameters if parameters.get(key, ())} - missing = sorted(required - present) - if missing: + if missing := sorted(required - present): errors.append( SupersetError( message=f'One or more parameters are missing: {", ".join(missing)}', diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 6cca83be06..0fa4d05cbc 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -57,8 +57,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): ) -> Dict[str, Any]: metadata = {} - indexes = database.get_indexes(table_name, schema_name) - if indexes: + if indexes := database.get_indexes(table_name, schema_name): col_names, latest_parts = cls.latest_partition( table_name, schema_name, database, show_first=True ) @@ -150,8 +149,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): @classmethod def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: - tracking_url = cls.get_tracking_url(cursor) - if tracking_url: + if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url # Adds the executed query id to the extra payload so the query can be cancelled diff --git a/superset/errors.py b/superset/errors.py index 2df0eb82b2..5261848687 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -211,8 +211,7 @@ class SupersetError: Mutates the extra params with user facing error codes that map to backend errors. 
""" - issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type) - if issue_codes: + if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type): self.extra = self.extra or {} self.extra.update( { diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 90c1653d0b..c489cc323c 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -453,8 +453,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods # Hook that provides administrators a handle on the Flask APP # after initialization - flask_app_mutator = self.config["FLASK_APP_MUTATOR"] - if flask_app_mutator: + if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]: flask_app_mutator(self.superset_app) if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"): diff --git a/superset/migrations/env.py b/superset/migrations/env.py index 90561beea4..e3779bb65b 100755 --- a/superset/migrations/env.py +++ b/superset/migrations/env.py @@ -103,8 +103,7 @@ def run_migrations_online() -> None: kwargs = {} if engine.name in ("sqlite", "mysql"): kwargs = {"transaction_per_migration": True, "transactional_ddl": True} - configure_args = current_app.extensions["migrate"].configure_args - if configure_args: + if configure_args := current_app.extensions["migrate"].configure_args: kwargs.update(configure_args) context.configure( @@ -112,7 +111,7 @@ def run_migrations_online() -> None: target_metadata=target_metadata, # compare_type=True, process_revision_directives=process_revision_directives, - **kwargs + **kwargs, ) try: diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 19bb7cc2a9..5ea23551ea 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -130,8 +130,7 @@ class MigrateViz: # only backup params slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak}) - query_context = try_load_json(slc.query_context) - if "form_data" in query_context: + if "form_data" in (query_context := try_load_json(slc.query_context)): query_context["form_data"] = clz.data slc.query_context = json.dumps(query_context) return slc @@ -139,8 +138,7 @@ class MigrateViz: @classmethod def downgrade_slice(cls, slc: Slice) -> Slice: form_data = try_load_json(slc.params) - form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {}) - if "viz_type" in form_data_bak: + if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})): slc.params = json.dumps(form_data_bak) slc.viz_type = form_data_bak.get("viz_type") query_context = try_load_json(slc.query_context) diff --git a/superset/migrations/shared/migrate_viz/processors.py b/superset/migrations/shared/migrate_viz/processors.py index 3584856beb..6d35a974db 100644 --- a/superset/migrations/shared/migrate_viz/processors.py +++ b/superset/migrations/shared/migrate_viz/processors.py @@ -40,8 +40,7 @@ class MigrateAreaChart(MigrateViz): if self.data.get("contribution"): self.data["contributionMode"] = "row" - stacked = self.data.get("stacked_style") - if stacked: + if stacked := self.data.get("stacked_style"): stacked_map = { "expand": "Expand", "stack": "Stack", @@ -49,7 +48,6 @@ class MigrateAreaChart(MigrateViz): self.data["show_extra_controls"] = True self.data["stack"] = stacked_map.get(stacked) - x_axis_label = self.data.get("x_axis_label") - if x_axis_label: + if x_axis_label := self.data.get("x_axis_label"): self.data["x_axis_title"] = x_axis_label 
self.data["x_axis_title_margin"] = 30 diff --git a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py index 865a8e59a0..13c4e61718 100644 --- a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py +++ b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py @@ -193,13 +193,12 @@ def get_chart_holder(position): size_y = position["size_y"] slice_id = position["slice_id"] slice_name = position.get("slice_name") - code = position.get("code") width = max(GRID_MIN_COLUMN_COUNT, int(round(size_x / GRID_RATIO))) height = max( GRID_MIN_ROW_UNITS, int(round(((size_y / GRID_RATIO) * 100) / ROW_HEIGHT)) ) - if code is not None: + if (code := position.get("code")) is not None: markdown_content = " " # white-space markdown if len(code): markdown_content = code diff --git a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py index 46b8e5f958..ec8f8e1cc0 100644 --- a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py +++ b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py @@ -80,8 +80,7 @@ def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata # upgrade native select filter metadata - native_filters = dashboard.get("native_filter_configuration") - if native_filters: + if native_filters := dashboard.get("native_filter_configuration"): changed_filters += upgrade_filters(native_filters) # upgrade filter sets @@ -123,8 +122,7 @@ def upgrade(): def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata - native_filters = dashboard.get("native_filter_configuration") - if native_filters: + if native_filters := dashboard.get("native_filter_configuration"): changed_filters += downgrade_filters(native_filters) # upgrade filter sets diff --git a/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py b/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py index 8f07ba1ae3..9773851ae9 100644 --- a/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py +++ b/superset/migrations/versions/2021-07-27_08-25_31b2a1039d4a_drop_tables_constraint.py @@ -40,9 +40,8 @@ def upgrade(): insp = engine.reflection.Inspector.from_engine(bind) # Drop the uniqueness constraint if it exists. - constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp) - if constraint: + if constraint := generic_find_uq_constraint_name("tables", {"table_name"}, insp): with op.batch_alter_table("tables", naming_convention=conv) as batch_op: batch_op.drop_constraint(constraint, type_="unique") diff --git a/superset/models/core.py b/superset/models/core.py index 43d12900e6..592207faba 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -277,9 +277,8 @@ class Database( # When returning the parameters we should use the masked SQLAlchemy URI and the # masked ``encrypted_extra`` to prevent exposing sensitive credentials. 
masked_uri = make_url_safe(self.sqlalchemy_uri) - masked_encrypted_extra = self.masked_encrypted_extra encrypted_config = {} - if masked_encrypted_extra is not None: + if (masked_encrypted_extra := self.masked_encrypted_extra) is not None: try: encrypted_config = json.loads(masked_encrypted_extra) except (TypeError, json.JSONDecodeError): diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 4022bcbc13..558ad15fc9 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -880,8 +880,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods """Apply config's SQL_QUERY_MUTATOR Typically adds comments to the query with context""" - sql_query_mutator = config["SQL_QUERY_MUTATOR"] - if sql_query_mutator: + if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0. diff --git a/superset/models/slice.py b/superset/models/slice.py index d08e345d82..6835215338 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -378,8 +378,7 @@ class Slice( # pylint: disable=too-many-public-methods def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None: src_class = target.cls_model - id_ = target.datasource_id - if id_: + if id_ := target.datasource_id: ds = db.session.query(src_class).filter_by(id=int(id_)).first() if ds: target.perm = ds.perm diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index 0c28e31a52..c6e980c5de 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -262,7 +262,6 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): 500: $ref: '#/components/responses/500' """ - token = request.args.get("token") requested_ids = kwargs["rison"] timestamp = datetime.now().strftime("%Y%m%dT%H%M%S") root = f"saved_query_export_{timestamp}" @@ -286,7 +285,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi): as_attachment=True, download_name=filename, ) - if token: + if token := request.args.get("token"): response.set_cookie(token, "done", max_age=600) return response diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index 78ac8aa3a8..61f72d4790 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -176,8 +176,7 @@ class BaseReportState: ) # If we need to render dashboard in a specific state, use stateful permalink - dashboard_state = self._report_schedule.extra.get("dashboard") - if dashboard_state: + if dashboard_state := self._report_schedule.extra.get("dashboard"): permalink_key = CreateDashboardPermalinkCommand( dashboard_id=str(self._report_schedule.dashboard.uuid), state=dashboard_state, diff --git a/superset/result_set.py b/superset/result_set.py index 170de1869c..9aa06bba09 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -225,12 +225,10 @@ class SupersetResultSet: def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]: """Given a pyarrow data type, Returns a generic database type""" - set_type = self._type_dict.get(col_name) - if set_type: + if set_type := self._type_dict.get(col_name): return set_type - mapped_type = self.convert_pa_dtype(pa_dtype) - if mapped_type: + if mapped_type := self.convert_pa_dtype(pa_dtype): return mapped_type return None diff --git a/superset/security/manager.py b/superset/security/manager.py index c54fdac87a..db6e631d91 100644 --- a/superset/security/manager.py +++ 
b/superset/security/manager.py @@ -589,8 +589,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return {s.name for s in view_menu_names} # Properly treat anonymous user - public_role = self.get_public_role() - if public_role: + if public_role := self.get_public_role(): # filter by public role view_menu_names = ( base_query.filter(self.role_model.id == public_role.id).filter( @@ -639,8 +638,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods } # datasource_access - perms = self.user_view_menu_names("datasource_access") - if perms: + if perms := self.user_view_menu_names("datasource_access"): tables = ( self.get_session.query(SqlaTable.schema) .filter(SqlaTable.database_id == database.id) @@ -770,9 +768,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods == None, ) ) - deleted_count = pvms.delete() sesh.commit() - if deleted_count: + if deleted_count := pvms.delete(): logger.info("Deleted %i faulty permissions", deleted_count) def sync_role_definitions(self) -> None: @@ -1916,8 +1913,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param dataset: The dataset to check against :return: A list of filters """ - guest_user = self.get_current_guest_user_if_guest() - if guest_user: + if guest_user := self.get_current_guest_user_if_guest(): return [ rule for rule in guest_user.rls diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 149feb1639..0f373a3514 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -95,7 +95,6 @@ def handle_query_error( """Local method handling error while processing the SQL""" payload = payload or {} msg = f"{prefix_message} {str(ex)}".strip() - troubleshooting_link = config["TROUBLESHOOTING_LINK"] query.error_message = msg query.tmp_table_name = None query.status = QueryStatus.FAILED @@ -119,7 +118,7 @@ def handle_query_error( session.commit() payload.update({"status": query.status, "error": msg, "errors": errors_payload}) - if troubleshooting_link: + if troubleshooting_link := config["TROUBLESHOOTING_LINK"]: payload["link"] = troubleshooting_link return payload diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 5bc844751b..10ef1fc1e1 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -54,8 +54,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): sql = parsed_query.stripped() # Hook to allow environment-specific mutation (usually comments) to the SQL - sql_query_mutator = config["SQL_QUERY_MUTATOR"] - if sql_query_mutator: + if sql_query_mutator := config["SQL_QUERY_MUTATOR"]: sql = sql_query_mutator( sql, user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. 
diff --git a/superset/utils/core.py b/superset/utils/core.py index 8451eaaa6f..c537abf459 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1800,8 +1800,7 @@ def get_time_filter_status( } applied: List[Dict[str, str]] = [] rejected: List[Dict[str, str]] = [] - time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL) - if time_column: + if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL): if time_column in temporal_columns: applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL}) else: diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index a904b7dc43..88b97901b2 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -121,8 +121,7 @@ class BaseScreenshot: @staticmethod def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: logger.info("Attempting to get from cache: %s", cache_key) - payload = cache.get(cache_key) - if payload: + if payload := cache.get(cache_key): return BytesIO(payload) logger.info("Failed at getting from cache: %s", cache_key) return None diff --git a/superset/views/api.py b/superset/views/api.py index 2884ac997f..84c27d2fac 100644 --- a/superset/views/api.py +++ b/superset/views/api.py @@ -79,8 +79,7 @@ class Api(BaseSupersetView): params: slice_id: integer """ form_data = {} - slice_id = request.args.get("slice_id") - if slice_id: + if slice_id := request.args.get("slice_id"): slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none() if slc: form_data = slc.form_data.copy() diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 2e069c196d..30d25382f3 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -380,8 +380,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): filter_field = cast(RelatedFieldFilter, filter_field) search_columns = [filter_field.field_name] if filter_field else None filters = datamodel.get_filters(search_columns) - base_filters = self.base_related_field_filters.get(column_name) - if base_filters: + if base_filters := self.base_related_field_filters.get(column_name): filters.add_filter_list(base_filters) if value and filter_field: filters.add_filter( @@ -588,8 +587,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): return self.response_404() page, page_size = self._sanitize_page_args(page, page_size) # handle ordering - order_field = self.order_rel_fields.get(column_name) - if order_field: + if order_field := self.order_rel_fields.get(column_name): order_column, order_direction = order_field else: order_column, order_direction = "", "" diff --git a/superset/views/core.py b/superset/views/core.py index b473172399..24bc16c310 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -765,8 +765,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """ redirect_url = request.url.replace("/superset/explore", "/explore") form_data_key = None - request_form_data = request.args.get("form_data") - if request_form_data: + if request_form_data := request.args.get("form_data"): parsed_form_data = loads_request_json(request_form_data) slice_id = parsed_form_data.get( "slice_id", int(request.args.get("slice_id", 0)) @@ -1498,8 +1497,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @deprecated(new_target="/api/v1/log/recent_activity//") def recent_activity(self, user_id: int) -> FlaskResponse: """Recent activity (actions) for a given user""" - error_obj = 
self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj limit = request.args.get("limit") @@ -1543,8 +1541,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/fave_dashboards//", methods=("GET",)) @deprecated(new_target="api/v1/dashboard/favorite_status/") def fave_dashboards(self, user_id: int) -> FlaskResponse: - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Dashboard, FavStar.dttm) @@ -1580,8 +1577,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/created_dashboards//", methods=("GET",)) @deprecated(new_target="api/v1/dashboard/") def created_dashboards(self, user_id: int) -> FlaskResponse: - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Dashboard) @@ -1615,8 +1611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """List of slices a user owns, created, modified or faved""" if not user_id: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj owner_ids_query = ( @@ -1669,8 +1664,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """List of slices created by this user""" if not user_id: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Slice) @@ -1701,8 +1695,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods """Favorite slices for a user""" if user_id is None: user_id = cast(int, get_user_id()) - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj qry = ( db.session.query(Slice, FavStar.dttm) @@ -1965,8 +1958,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return json_error_response(_("permalink state not found"), status=404) dashboard_id, state = value["dashboardId"], value.get("state", {}) url = f"/superset/dashboard/{dashboard_id}?permalink_key={key}" - url_params = state.get("urlParams") - if url_params: + if url_params := state.get("urlParams"): params = parse.urlencode(url_params) url = f"{url}&{params}" hash_ = state.get("anchor", state.get("hash")) @@ -2125,8 +2117,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods mydb = db.session.query(Database).get(database_id) sql = json.loads(request.form.get("sql", '""')) - template_params = json.loads(request.form.get("templateParams") or "{}") - if template_params: + if template_params := json.loads(request.form.get("templateParams") or "{}"): template_processor = get_template_processor(mydb) sql = template_processor.process_template(sql, **template_params) @@ -2393,8 +2384,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/sql_json/", methods=("POST",)) @deprecated(new_target="/api/v1/sqllab/execute/") def sql_json(self) -> FlaskResponse: - errors = SqlJsonPayloadSchema().validate(request.json) - if errors: + if errors := 
SqlJsonPayloadSchema().validate(request.json): return json_error_response(status=400, payload=errors) try: @@ -2621,10 +2611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods search_user_id = get_user_id() database_id = request.args.get("database_id") search_text = request.args.get("search_text") - status = request.args.get("status") # From and To time stamp should be Epoch timestamp in seconds - from_time = request.args.get("from") - to_time = request.args.get("to") query = db.session.query(Query) if search_user_id: @@ -2635,7 +2622,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # Filter on db Id query = query.filter(Query.database_id == database_id) - if status: + if status := request.args.get("status"): # Filter on status query = query.filter(Query.status == status) @@ -2643,10 +2630,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # Filter on search text query = query.filter(Query.sql.like(f"%{search_text}%")) - if from_time: + if from_time := request.args.get("from"): query = query.filter(Query.start_time > int(from_time)) - if to_time: + if to_time := request.args.get("to"): query = query.filter(Query.start_time < int(to_time)) query_limit = config["QUERY_SEARCH_LIMIT"] @@ -2709,8 +2696,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods user_id = -1 if not user else user.id # Prevent unauthorized access to other user's profiles, # unless configured to do so on with ENABLE_BROAD_ACTIVITY_ACCESS - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj payload = { @@ -2789,8 +2775,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods **self._get_sqllab_tabs(get_user_id()), } - form_data = request.form.get("form_data") - if form_data: + if form_data := request.form.get("form_data"): try: payload["requested_query"] = json.loads(form_data) except json.JSONDecodeError: diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py index 93723ac38b..29d80611a2 100644 --- a/superset/views/database/validators.py +++ b/superset/views/database/validators.py @@ -51,7 +51,6 @@ def sqlalchemy_uri_validator( def schema_allows_file_upload(database: Database, schema: Optional[str]) -> bool: if not database.allow_file_upload: return False - schemas = database.get_schema_access_for_file_upload() - if schemas: + if schemas := database.get_schema_access_for_file_upload(): return schema in schemas return security_manager.can_access_database(database) diff --git a/superset/views/log/api.py b/superset/views/log/api.py index b94af731c4..e218792c25 100644 --- a/superset/views/log/api.py +++ b/superset/views/log/api.py @@ -125,8 +125,7 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi): 500: $ref: '#/components/responses/500' """ - error_obj = self.get_user_activity_access_error(user_id) - if error_obj: + if error_obj := self.get_user_activity_access_error(user_id): return error_obj args = kwargs["rison"] diff --git a/superset/views/utils.py b/superset/views/utils.py index 35a39fdc9c..a53e750040 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -261,9 +261,8 @@ def get_datasource_info( :raises SupersetException: If the datasource no longer exists """ - datasource = form_data.get("datasource", "") - - if "__" in datasource: + # pylint: disable=superfluous-parens + if "__" in (datasource := 
form_data.get("datasource", "")): datasource_id, datasource_type = datasource.split("__") # The case where the datasource has been deleted if datasource_id == "None": @@ -462,7 +461,7 @@ def check_datasource_perms( _self: Any, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None, - **kwargs: Any + **kwargs: Any, ) -> None: """ Check if user can access a cached response from explore_json. diff --git a/superset/viz.py b/superset/viz.py index d605b8b006..8abb6038e8 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -365,10 +365,11 @@ class BaseViz: # pylint: disable=too-many-public-methods metrics = self.all_metrics or [] groupby = self.dedup_columns(self.groupby, self.form_data.get("columns")) - groupby_labels = get_column_names(groupby) is_timeseries = self.is_timeseries - if DTTM_ALIAS in groupby_labels: + + # pylint: disable=superfluous-parens + if DTTM_ALIAS in (groupby_labels := get_column_names(groupby)): del groupby[groupby_labels.index(DTTM_ALIAS)] is_timeseries = True @@ -959,8 +960,7 @@ class PivotTableViz(BaseViz): if len(deduped_cols) < (len(groupby) + len(columns)): raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap")) - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -1077,8 +1077,7 @@ class TreemapViz(BaseViz): @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -1880,8 +1879,7 @@ class DistributionBarViz(BaseViz): if not self.form_data.get("groupby"): raise QueryObjectValidationError(_("Pick at least one field for [Series]")) - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -2310,8 +2308,7 @@ class ParallelCoordinatesViz(BaseViz): def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() query_obj["groupby"] = [self.form_data.get("series")] - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) @@ -2679,8 +2676,7 @@ class BaseDeckGLViz(BaseViz): if self.form_data.get("adhoc_filters") is None: self.form_data["adhoc_filters"] = [] - line_column = self.form_data.get("line_column") - if line_column: + if line_column := self.form_data.get("line_column"): spatial_columns.add(line_column) for column in sorted(spatial_columns): @@ -2706,13 +2702,12 @@ class BaseDeckGLViz(BaseViz): if self.form_data.get("js_columns"): group_by += self.form_data.get("js_columns") or [] - metrics = self.get_metrics() # Ensure this value is sorted so that it does not # cause the cache key generation (which hashes the # query object) to generate different keys for values # that should be considered 
the same. group_by = sorted(set(group_by)) - if metrics: + if metrics := self.get_metrics(): query_obj["groupby"] = group_by query_obj["metrics"] = metrics query_obj["columns"] = [] @@ -3097,8 +3092,7 @@ class PairedTTestViz(BaseViz): @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: query_obj = super().query_obj() - sort_by = self.form_data.get("timeseries_limit_metric") - if sort_by: + if sort_by := self.form_data.get("timeseries_limit_metric"): sort_by_label = utils.get_metric_name(sort_by) if sort_by_label not in utils.get_metric_names(query_obj["metrics"]): query_obj["metrics"].append(sort_by) diff --git a/tests/example_data/data_loading/pandas/pandas_data_loader.py b/tests/example_data/data_loading/pandas/pandas_data_loader.py index 00f3f775ca..7f41602054 100644 --- a/tests/example_data/data_loading/pandas/pandas_data_loader.py +++ b/tests/example_data/data_loading/pandas/pandas_data_loader.py @@ -67,8 +67,7 @@ class PandasDataLoader(DataLoader): return inspect(self._db_engine).default_schema_name def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]: - metadata_table = table.table_metadata - if metadata_table: + if metadata_table := table.table_metadata: types = metadata_table.types if types: return types diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 70f984775d..91a76f97cf 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -136,7 +136,6 @@ def upload_csv( dtype: Union[str, None] = None, ): csv_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "csv_file": open(filename, "rb"), "delimiter": ",", @@ -146,7 +145,7 @@ def upload_csv( "index_label": "test_label", "overwrite_duplicate": False, } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) @@ -159,7 +158,6 @@ def upload_excel( filename: str, table_name: str, extra: Optional[Dict[str, str]] = None ): excel_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "excel_file": open(filename, "rb"), "name": table_name, @@ -169,7 +167,7 @@ def upload_excel( "index_label": "test_label", "mangle_dupe_cols": False, } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) @@ -180,7 +178,6 @@ def upload_columnar( filename: str, table_name: str, extra: Optional[Dict[str, str]] = None ): columnar_upload_db_id = get_upload_db().id - schema = utils.get_example_default_schema() form_data = { "columnar_file": open(filename, "rb"), "name": table_name, @@ -188,7 +185,7 @@ def upload_columnar( "if_exists": "fail", "index_label": "test_label", } - if schema: + if schema := utils.get_example_default_schema(): form_data["schema"] = schema if extra: form_data.update(extra) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 055cf4779e..f87bafcd37 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -641,14 +641,13 @@ class TestDatasetApi(SupersetTestCase): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="admin") table_data = { "database": energy_usage_ds.database_id, "table_name": energy_usage_ds.table_name, } - if schema: + if schema := 
get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 422 @@ -665,7 +664,6 @@ class TestDatasetApi(SupersetTestCase): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="admin") table_data = { @@ -673,7 +671,7 @@ class TestDatasetApi(SupersetTestCase): "table_name": energy_usage_ds.table_name, "sql": "select * from energy_usage", } - if schema: + if schema := get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 422 @@ -690,7 +688,6 @@ class TestDatasetApi(SupersetTestCase): if backend() == "sqlite": return - schema = get_example_default_schema() energy_usage_ds = self.get_energy_usage_dataset() self.login(username="alpha") admin = self.get_user("admin") @@ -701,7 +698,7 @@ class TestDatasetApi(SupersetTestCase): "sql": "select * from energy_usage", "owners": [admin.id], } - if schema: + if schema := get_example_default_schema(): table_data["schema"] = schema rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post") assert rv.status_code == 201
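
For readers skimming the patch above, a minimal sketch of the pattern it applies: each hunk replaces an assignment followed by a truthiness check with an assignment expression (PEP 572, Python 3.8+), and the auto-walrus pre-commit hook added in .pre-commit-config.yaml (rev v0.2.2) keeps future code in that style. The ReportDAO class and the function names below are illustrative stand-ins for the refactored call sites, not code taken from Superset.

from typing import List


class ReportDAO:
    # Hypothetical stand-in for a ReportScheduleDAO-style lookup; not Superset code.
    def find_by_chart_ids(self, chart_ids: List[int]) -> List[str]:
        return ["weekly-sales"] if chart_ids else []


def report_names_before(dao: ReportDAO, chart_ids: List[int]) -> List[str]:
    # Pre-walrus style: assign on one line, test truthiness on the next.
    reports = dao.find_by_chart_ids(chart_ids)
    if reports:
        return reports
    return []


def report_names_after(dao: ReportDAO, chart_ids: List[int]) -> List[str]:
    # Walrus style (PEP 572): bind and test in a single expression.
    if reports := dao.find_by_chart_ids(chart_ids):
        return reports
    return []


if __name__ == "__main__":
    dao = ReportDAO()
    # Both variants return the same result; only the shape of the code changes.
    assert report_names_before(dao, [1]) == report_names_after(dao, [1]) == ["weekly-sales"]
    assert report_names_before(dao, []) == report_names_after(dao, []) == []

Note the guard the patch keeps when falsiness is not the right sentinel: values that may legitimately be empty or zero, such as data.get("positions") or self._properties.get("database_name"), retain an explicit (name := expr) is not None comparison so the conversion does not change behavior.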