chore: Embrace the walrus operator (#24127)

Authored by John Bodley on 2023-05-19 00:37:13 -07:00; committed by GitHub
parent 6b5459121f
commit d583ca9ef5
54 changed files with 100 additions and 185 deletions
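In a nutshell, the change folds a one-shot assignment into the if statement that tests it, using the PEP 572 assignment expression. A minimal, self-contained sketch of the pattern (the cache and key names below are illustrative, not Superset's):

from typing import Optional

_CACHE = {"chart:42": b"fake-png-bytes"}  # illustrative stand-in for a real cache

def get_from_cache(key: str) -> Optional[bytes]:
    return _CACHE.get(key)

# before: assign, then test
payload = get_from_cache("chart:42")
if payload:
    print(f"cache hit, {len(payload)} bytes")

# after: assignment expression, one line, same behaviour
if payload := get_from_cache("chart:42"):
    print(f"cache hit, {len(payload)} bytes")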

View File

@ -19,6 +19,10 @@ repos:
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:

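A rough sketch of what the new auto-walrus hook automates. The behaviour below is an assumption inferred from the hunks in this commit rather than the tool's documented spec: an assignment is folded into the immediately following if when the name is not needed afterwards, and is left alone otherwise. Locally it should be runnable with: pre-commit run auto-walrus --all-files

def lookup(table, key):
    return table.get(key)

def folded(table):
    # single use right after the assignment: a candidate for the := rewrite
    if row := lookup(table, "a"):
        return row
    return None

def left_alone(table):
    # the name is reused after the if block, so the two-line form stays
    row = lookup(table, "a")
    if row:
        print("found")
    return row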
View File

@ -649,8 +649,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_404()
# fetch the chart screenshot using the current user and cache if set
img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest)
if img:
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
return Response(
FileWrapper(img), mimetype="image/png", direct_passthrough=True
)
@ -783,7 +782,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"chart_export_{timestamp}"
@ -805,7 +803,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
as_attachment=True,
download_name=filename,
)
if token:
if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response
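Worth noting in hunks like the one above: the original token = request.args.get("token") sat several lines before the if, so folding it in also moves the point of evaluation later. That is only safe because the expression has no side effects and nothing in between reads the name. A small sketch with made-up names, not Superset's API:

def build_response():
    return {"status": 200}

def query_param(args, key):
    return args.get(key)

args = {"token": "abc123"}
response = build_response()
# the lookup now happens here, after build_response(), rather than before it
if token := query_param(args, "token"):
    response["cookie"] = (token, "done", 600)
print(response)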

View File

@ -55,8 +55,7 @@ class BulkDeleteChartCommand(BaseCommand):
if not self._models or len(self._models) != len(self._model_ids):
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids)
if reports:
if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids):
report_names = [report.name for report in reports]
raise ChartBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))

View File

@ -64,8 +64,7 @@ class DeleteChartCommand(BaseCommand):
if not self._model:
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_chart_id(self._model_id)
if reports:
if reports := ReportScheduleDAO.find_by_chart_id(self._model_id):
report_names = [report.name for report in reports]
raise ChartDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))

View File

@ -221,8 +221,7 @@ def get_query_results(
:raises QueryObjectValidationError: if an unsupported result type is requested
:return: JSON serializable result payload
"""
result_func = _result_type_functions.get(result_type)
if result_func:
if result_func := _result_type_functions.get(result_type):
return result_func(query_context, query_obj, force_cached)
raise QueryObjectValidationError(
_("Invalid result type: %(result_type)s", result_type=result_type)

View File

@ -125,10 +125,9 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
for column in datasource.columns
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
}
granularity = query_object.granularity
x_axis = form_data and form_data.get("x_axis")
if granularity:
if granularity := query_object.granularity:
filter_to_remove = None
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis

View File

@ -500,8 +500,7 @@ class QueryContextProcessor:
return return_value
def get_cache_timeout(self) -> int:
cache_timeout_rv = self._query_context.get_cache_timeout()
if cache_timeout_rv:
if cache_timeout_rv := self._query_context.get_cache_timeout():
return cache_timeout_rv
if (
data_cache_timeout := config["DATA_CACHE_CONFIG"].get(

View File

@ -148,8 +148,7 @@ class QueryCacheManager:
if not key or not _cache[region] or force_query:
return query_cache
cache_value = _cache[region].get(key)
if cache_value:
if cache_value := _cache[region].get(key):
logger.debug("Cache key: %s", key)
stats_logger.incr("loading_from_cache")
try:

View File

@ -993,11 +993,10 @@ class SqlaTable(
schema=self.schema,
template_processor=template_processor,
)
col_in_metadata = self.get_column(expression)
time_grain = col.get("timeGrain")
has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
is_dttm = False
if col_in_metadata:
if col_in_metadata := self.get_column(expression):
sqla_column = col_in_metadata.get_sqla_col(
template_processor=template_processor
)

View File

@ -56,8 +56,7 @@ class BulkDeleteDashboardCommand(BaseCommand):
if not self._models or len(self._models) != len(self._model_ids):
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids)
if reports:
if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids):
report_names = [report.name for report in reports]
raise DashboardBulkDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))

View File

@ -57,8 +57,7 @@ class DeleteDashboardCommand(BaseCommand):
if not self._model:
raise DashboardNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id)
if reports:
if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id):
report_names = [report.name for report in reports]
raise DashboardDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))

View File

@ -200,11 +200,10 @@ class DashboardDAO(BaseDAO):
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
positions = data.get("positions")
new_filter_scopes = {}
md = dashboard.params_dict
if positions is not None:
if (positions := data.get("positions")) is not None:
# find slices in the position data
slice_ids = [
value.get("meta", {}).get("chartId")

View File

@ -1036,7 +1036,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"database_export_{timestamp}"
@ -1060,7 +1059,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
as_attachment=True,
download_name=filename,
)
if token:
if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

View File

@ -55,9 +55,8 @@ class DeleteDatabaseCommand(BaseCommand):
if not self._model:
raise DatabaseNotFoundError()
# Check there are no associated ReportSchedules
reports = ReportScheduleDAO.find_by_database_id(self._model_id)
if reports:
if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
report_names = [report.name for report in reports]
raise DatabaseDeleteFailedReportsExistError(
_("There are associated alerts or reports: %s" % ",".join(report_names))

View File

@ -228,6 +228,5 @@ class TestConnectionDatabaseCommand(BaseCommand):
raise DatabaseTestConnectionUnexpectedError(errors) from ex
def validate(self) -> None:
database_name = self._properties.get("database_name")
if database_name is not None:
if (database_name := self._properties.get("database_name")) is not None:
self._model = DatabaseDAO.get_database_by_name(database_name)
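When the original test is "is not None" rather than plain truthiness, as in the validate() hunk above, the assignment expression has to be parenthesised, and the distinction matters for values that are falsy but present. A minimal sketch with illustrative data:

properties = {"database_name": ""}

# parentheses are required around the walrus target here
if (database_name := properties.get("database_name")) is not None:
    # an empty string is falsy but not None, so this branch still runs
    print(f"validating database {database_name!r}")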

View File

@ -128,6 +128,5 @@ class ValidateDatabaseParametersCommand(BaseCommand):
)
def validate(self) -> None:
database_id = self._properties.get("id")
if database_id is not None:
if (database_id := self._properties.get("id")) is not None:
self._model = DatabaseDAO.find_by_id(database_id)

View File

@ -977,8 +977,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
return self.response(400, message=ex.messages)
table_name = body["table_name"]
database_id = body["database_id"]
table = DatasetDAO.get_table_by_name(database_id, table_name)
if table:
if table := DatasetDAO.get_table_by_name(database_id, table_name):
return self.response(200, result={"table_id": table.id})
body["database"] = database_id

View File

@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
if native_type.upper() in type_map:
return type_map[native_type.upper()]
match = VARCHAR.match(native_type)
if match:
if match := VARCHAR.match(native_type):
size = int(match.group(1))
return String(size)

View File

@ -114,13 +114,11 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
exceptions.append(DatasetEndpointUnsafeValidationError())
# Validate columns
columns = self._properties.get("columns")
if columns:
if columns := self._properties.get("columns"):
self._validate_columns(columns, exceptions)
# Validate metrics
metrics = self._properties.get("metrics")
if metrics:
if metrics := self._properties.get("metrics"):
self._validate_metrics(metrics, exceptions)
if exceptions:

View File

@ -1704,8 +1704,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param source: Type coming from the database table or cursor description
:return: ColumnSpec object
"""
col_types = cls.get_column_types(native_type)
if col_types:
if col_types := cls.get_column_types(native_type):
column_type, generic_type = col_types
is_dttm = generic_type == GenericDataType.TEMPORAL
return ColumnSpec(
@ -1996,9 +1995,8 @@ class BasicParametersMixin:
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)
if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',

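The walrus target can also bind the result of a computed expression, not just a dictionary lookup, as in the missing-parameters hunks. A short sketch reusing the same shape:

required = {"host", "port", "username", "database"}
parameters = {"host": "localhost", "database": "examples"}
present = {key for key in parameters if parameters.get(key)}

if missing := sorted(required - present):
    print(f'One or more parameters are missing: {", ".join(missing)}')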
View File

@ -384,9 +384,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
}
# Add credentials if they are set on the SQLAlchemy dialect.
creds = engine.dialect.credentials_info
if creds:
if creds := engine.dialect.credentials_info:
to_gbq_kwargs[
"credentials"
] = service_account.Credentials.from_service_account_info(creds)

View File

@ -285,9 +285,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
parameters["http_path"] = connect_args.get("http_path")
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)
if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',

View File

@ -1213,8 +1213,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
) -> Dict[str, Any]:
metadata = {}
indexes = database.get_indexes(table_name, schema_name)
if indexes:
if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
@ -1278,8 +1277,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
"""Updates progress information"""
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
session.commit()

View File

@ -312,9 +312,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
}
parameters = properties.get("parameters", {})
present = {key for key in parameters if parameters.get(key, ())}
missing = sorted(required - present)
if missing:
if missing := sorted(required - present):
errors.append(
SupersetError(
message=f'One or more parameters are missing: {", ".join(missing)}',

View File

@ -57,8 +57,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
) -> Dict[str, Any]:
metadata = {}
indexes = database.get_indexes(table_name, schema_name)
if indexes:
if indexes := database.get_indexes(table_name, schema_name):
col_names, latest_parts = cls.latest_partition(
table_name, schema_name, database, show_first=True
)
@ -150,8 +149,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
tracking_url = cls.get_tracking_url(cursor)
if tracking_url:
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
# Adds the executed query id to the extra payload so the query can be cancelled

View File

@ -211,8 +211,7 @@ class SupersetError:
Mutates the extra params with user facing error codes that map to backend
errors.
"""
issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type)
if issue_codes:
if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type):
self.extra = self.extra or {}
self.extra.update(
{

View File

@ -453,8 +453,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
# Hook that provides administrators a handle on the Flask APP
# after initialization
flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
if flask_app_mutator:
if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]:
flask_app_mutator(self.superset_app)
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):

View File

@ -103,8 +103,7 @@ def run_migrations_online() -> None:
kwargs = {}
if engine.name in ("sqlite", "mysql"):
kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
configure_args = current_app.extensions["migrate"].configure_args
if configure_args:
if configure_args := current_app.extensions["migrate"].configure_args:
kwargs.update(configure_args)
context.configure(
@ -112,7 +111,7 @@ def run_migrations_online() -> None:
target_metadata=target_metadata,
# compare_type=True,
process_revision_directives=process_revision_directives,
**kwargs
**kwargs,
)
try:

View File

@ -130,8 +130,7 @@ class MigrateViz:
# only backup params
slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak})
query_context = try_load_json(slc.query_context)
if "form_data" in query_context:
if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc
@ -139,8 +138,7 @@ class MigrateViz:
@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
form_data = try_load_json(slc.params)
form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
if "viz_type" in form_data_bak:
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
slc.viz_type = form_data_bak.get("viz_type")
query_context = try_load_json(slc.query_context)
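The membership form used in the MigrateViz hunk above also needs parentheses around the walrus target, which is why some hunks in this commit add a superfluous-parens pylint disable. A self-contained sketch with illustrative data:

import json

def try_load_json(raw):
    # forgiving loader in the spirit of the code above, illustrative only
    try:
        return json.loads(raw) if raw else {}
    except (TypeError, json.JSONDecodeError):
        return {}

raw = '{"form_data": {"viz_type": "area"}}'
# pylint: disable=superfluous-parens
if "form_data" in (query_context := try_load_json(raw)):
    query_context["form_data"] = {"viz_type": "echarts_area"}
    print(json.dumps(query_context))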

View File

@ -40,8 +40,7 @@ class MigrateAreaChart(MigrateViz):
if self.data.get("contribution"):
self.data["contributionMode"] = "row"
stacked = self.data.get("stacked_style")
if stacked:
if stacked := self.data.get("stacked_style"):
stacked_map = {
"expand": "Expand",
"stack": "Stack",
@ -49,7 +48,6 @@ class MigrateAreaChart(MigrateViz):
self.data["show_extra_controls"] = True
self.data["stack"] = stacked_map.get(stacked)
x_axis_label = self.data.get("x_axis_label")
if x_axis_label:
if x_axis_label := self.data.get("x_axis_label"):
self.data["x_axis_title"] = x_axis_label
self.data["x_axis_title_margin"] = 30

View File

@ -193,13 +193,12 @@ def get_chart_holder(position):
size_y = position["size_y"]
slice_id = position["slice_id"]
slice_name = position.get("slice_name")
code = position.get("code")
width = max(GRID_MIN_COLUMN_COUNT, int(round(size_x / GRID_RATIO)))
height = max(
GRID_MIN_ROW_UNITS, int(round(((size_y / GRID_RATIO) * 100) / ROW_HEIGHT))
)
if code is not None:
if (code := position.get("code")) is not None:
markdown_content = " " # white-space markdown
if len(code):
markdown_content = code

View File

@ -80,8 +80,7 @@ def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
changed_filters, changed_filter_sets = 0, 0
# upgrade native select filter metadata
native_filters = dashboard.get("native_filter_configuration")
if native_filters:
if native_filters := dashboard.get("native_filter_configuration"):
changed_filters += upgrade_filters(native_filters)
# upgrade filter sets
@ -123,8 +122,7 @@ def upgrade():
def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
changed_filters, changed_filter_sets = 0, 0
# upgrade native select filter metadata
native_filters = dashboard.get("native_filter_configuration")
if native_filters:
if native_filters := dashboard.get("native_filter_configuration"):
changed_filters += downgrade_filters(native_filters)
# upgrade filter sets

View File

@ -40,9 +40,8 @@ def upgrade():
insp = engine.reflection.Inspector.from_engine(bind)
# Drop the uniqueness constraint if it exists.
constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp)
if constraint:
if constraint := generic_find_uq_constraint_name("tables", {"table_name"}, insp):
with op.batch_alter_table("tables", naming_convention=conv) as batch_op:
batch_op.drop_constraint(constraint, type_="unique")

View File

@ -277,9 +277,8 @@ class Database(
# When returning the parameters we should use the masked SQLAlchemy URI and the
# masked ``encrypted_extra`` to prevent exposing sensitive credentials.
masked_uri = make_url_safe(self.sqlalchemy_uri)
masked_encrypted_extra = self.masked_encrypted_extra
encrypted_config = {}
if masked_encrypted_extra is not None:
if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
try:
encrypted_config = json.loads(masked_encrypted_extra)
except (TypeError, json.JSONDecodeError):

View File

@ -880,8 +880,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
"""Apply config's SQL_QUERY_MUTATOR
Typically adds comments to the query with context"""
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
sql = sql_query_mutator(
sql,
user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0.

View File

@ -378,8 +378,7 @@ class Slice( # pylint: disable=too-many-public-methods
def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None:
src_class = target.cls_model
id_ = target.datasource_id
if id_:
if id_ := target.datasource_id:
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
if ds:
target.perm = ds.perm

View File

@ -262,7 +262,6 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
token = request.args.get("token")
requested_ids = kwargs["rison"]
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
root = f"saved_query_export_{timestamp}"
@ -286,7 +285,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
as_attachment=True,
download_name=filename,
)
if token:
if token := request.args.get("token"):
response.set_cookie(token, "done", max_age=600)
return response

View File

@ -176,8 +176,7 @@ class BaseReportState:
)
# If we need to render dashboard in a specific state, use stateful permalink
dashboard_state = self._report_schedule.extra.get("dashboard")
if dashboard_state:
if dashboard_state := self._report_schedule.extra.get("dashboard"):
permalink_key = CreateDashboardPermalinkCommand(
dashboard_id=str(self._report_schedule.dashboard.uuid),
state=dashboard_state,

View File

@ -225,12 +225,10 @@ class SupersetResultSet:
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
"""Given a pyarrow data type, Returns a generic database type"""
set_type = self._type_dict.get(col_name)
if set_type:
if set_type := self._type_dict.get(col_name):
return set_type
mapped_type = self.convert_pa_dtype(pa_dtype)
if mapped_type:
if mapped_type := self.convert_pa_dtype(pa_dtype):
return mapped_type
return None

View File

@ -589,8 +589,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return {s.name for s in view_menu_names}
# Properly treat anonymous user
public_role = self.get_public_role()
if public_role:
if public_role := self.get_public_role():
# filter by public role
view_menu_names = (
base_query.filter(self.role_model.id == public_role.id).filter(
@ -639,8 +638,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
}
# datasource_access
perms = self.user_view_menu_names("datasource_access")
if perms:
if perms := self.user_view_menu_names("datasource_access"):
tables = (
self.get_session.query(SqlaTable.schema)
.filter(SqlaTable.database_id == database.id)
@ -770,9 +768,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
== None,
)
)
deleted_count = pvms.delete()
sesh.commit()
if deleted_count:
if deleted_count := pvms.delete():
logger.info("Deleted %i faulty permissions", deleted_count)
def sync_role_definitions(self) -> None:
@ -1916,8 +1913,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
:param dataset: The dataset to check against
:return: A list of filters
"""
guest_user = self.get_current_guest_user_if_guest()
if guest_user:
if guest_user := self.get_current_guest_user_if_guest():
return [
rule
for rule in guest_user.rls

View File

@ -95,7 +95,6 @@ def handle_query_error(
"""Local method handling error while processing the SQL"""
payload = payload or {}
msg = f"{prefix_message} {str(ex)}".strip()
troubleshooting_link = config["TROUBLESHOOTING_LINK"]
query.error_message = msg
query.tmp_table_name = None
query.status = QueryStatus.FAILED
@ -119,7 +118,7 @@ def handle_query_error(
session.commit()
payload.update({"status": query.status, "error": msg, "errors": errors_payload})
if troubleshooting_link:
if troubleshooting_link := config["TROUBLESHOOTING_LINK"]:
payload["link"] = troubleshooting_link
return payload

View File

@ -54,8 +54,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
sql = parsed_query.stripped()
# Hook to allow environment-specific mutation (usually comments) to the SQL
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
if sql_query_mutator:
if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
sql = sql_query_mutator(
sql,
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.

View File

@ -1800,8 +1800,7 @@ def get_time_filter_status(
}
applied: List[Dict[str, str]] = []
rejected: List[Dict[str, str]] = []
time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
if time_column:
if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL):
if time_column in temporal_columns:
applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL})
else:

View File

@ -121,8 +121,7 @@ class BaseScreenshot:
@staticmethod
def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]:
logger.info("Attempting to get from cache: %s", cache_key)
payload = cache.get(cache_key)
if payload:
if payload := cache.get(cache_key):
return BytesIO(payload)
logger.info("Failed at getting from cache: %s", cache_key)
return None

View File

@ -79,8 +79,7 @@ class Api(BaseSupersetView):
params: slice_id: integer
"""
form_data = {}
slice_id = request.args.get("slice_id")
if slice_id:
if slice_id := request.args.get("slice_id"):
slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none()
if slc:
form_data = slc.form_data.copy()

View File

@ -380,8 +380,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
filter_field = cast(RelatedFieldFilter, filter_field)
search_columns = [filter_field.field_name] if filter_field else None
filters = datamodel.get_filters(search_columns)
base_filters = self.base_related_field_filters.get(column_name)
if base_filters:
if base_filters := self.base_related_field_filters.get(column_name):
filters.add_filter_list(base_filters)
if value and filter_field:
filters.add_filter(
@ -588,8 +587,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
return self.response_404()
page, page_size = self._sanitize_page_args(page, page_size)
# handle ordering
order_field = self.order_rel_fields.get(column_name)
if order_field:
if order_field := self.order_rel_fields.get(column_name):
order_column, order_direction = order_field
else:
order_column, order_direction = "", ""

View File

@ -765,8 +765,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""
redirect_url = request.url.replace("/superset/explore", "/explore")
form_data_key = None
request_form_data = request.args.get("form_data")
if request_form_data:
if request_form_data := request.args.get("form_data"):
parsed_form_data = loads_request_json(request_form_data)
slice_id = parsed_form_data.get(
"slice_id", int(request.args.get("slice_id", 0))
@ -1498,8 +1497,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@deprecated(new_target="/api/v1/log/recent_activity/<user_id>/")
def recent_activity(self, user_id: int) -> FlaskResponse:
"""Recent activity (actions) for a given user"""
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
limit = request.args.get("limit")
@ -1543,8 +1541,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/fave_dashboards/<int:user_id>/", methods=("GET",))
@deprecated(new_target="api/v1/dashboard/favorite_status/")
def fave_dashboards(self, user_id: int) -> FlaskResponse:
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
qry = (
db.session.query(Dashboard, FavStar.dttm)
@ -1580,8 +1577,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/created_dashboards/<int:user_id>/", methods=("GET",))
@deprecated(new_target="api/v1/dashboard/")
def created_dashboards(self, user_id: int) -> FlaskResponse:
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
qry = (
db.session.query(Dashboard)
@ -1615,8 +1611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""List of slices a user owns, created, modified or faved"""
if not user_id:
user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
owner_ids_query = (
@ -1669,8 +1664,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""List of slices created by this user"""
if not user_id:
user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
qry = (
db.session.query(Slice)
@ -1701,8 +1695,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
"""Favorite slices for a user"""
if user_id is None:
user_id = cast(int, get_user_id())
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
qry = (
db.session.query(Slice, FavStar.dttm)
@ -1965,8 +1958,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_error_response(_("permalink state not found"), status=404)
dashboard_id, state = value["dashboardId"], value.get("state", {})
url = f"/superset/dashboard/{dashboard_id}?permalink_key={key}"
url_params = state.get("urlParams")
if url_params:
if url_params := state.get("urlParams"):
params = parse.urlencode(url_params)
url = f"{url}&{params}"
hash_ = state.get("anchor", state.get("hash"))
@ -2125,8 +2117,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
mydb = db.session.query(Database).get(database_id)
sql = json.loads(request.form.get("sql", '""'))
template_params = json.loads(request.form.get("templateParams") or "{}")
if template_params:
if template_params := json.loads(request.form.get("templateParams") or "{}"):
template_processor = get_template_processor(mydb)
sql = template_processor.process_template(sql, **template_params)
@ -2393,8 +2384,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/sql_json/", methods=("POST",))
@deprecated(new_target="/api/v1/sqllab/execute/")
def sql_json(self) -> FlaskResponse:
errors = SqlJsonPayloadSchema().validate(request.json)
if errors:
if errors := SqlJsonPayloadSchema().validate(request.json):
return json_error_response(status=400, payload=errors)
try:
@ -2621,10 +2611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
search_user_id = get_user_id()
database_id = request.args.get("database_id")
search_text = request.args.get("search_text")
status = request.args.get("status")
# From and To time stamp should be Epoch timestamp in seconds
from_time = request.args.get("from")
to_time = request.args.get("to")
query = db.session.query(Query)
if search_user_id:
@ -2635,7 +2622,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# Filter on db Id
query = query.filter(Query.database_id == database_id)
if status:
if status := request.args.get("status"):
# Filter on status
query = query.filter(Query.status == status)
@ -2643,10 +2630,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# Filter on search text
query = query.filter(Query.sql.like(f"%{search_text}%"))
if from_time:
if from_time := request.args.get("from"):
query = query.filter(Query.start_time > int(from_time))
if to_time:
if to_time := request.args.get("to"):
query = query.filter(Query.start_time < int(to_time))
query_limit = config["QUERY_SEARCH_LIMIT"]
@ -2709,8 +2696,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
user_id = -1 if not user else user.id
# Prevent unauthorized access to other user's profiles,
# unless configured to do so on with ENABLE_BROAD_ACTIVITY_ACCESS
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
payload = {
@ -2789,8 +2775,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
**self._get_sqllab_tabs(get_user_id()),
}
form_data = request.form.get("form_data")
if form_data:
if form_data := request.form.get("form_data"):
try:
payload["requested_query"] = json.loads(form_data)
except json.JSONDecodeError:

View File

@ -51,7 +51,6 @@ def sqlalchemy_uri_validator(
def schema_allows_file_upload(database: Database, schema: Optional[str]) -> bool:
if not database.allow_file_upload:
return False
schemas = database.get_schema_access_for_file_upload()
if schemas:
if schemas := database.get_schema_access_for_file_upload():
return schema in schemas
return security_manager.can_access_database(database)

View File

@ -125,8 +125,7 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
error_obj = self.get_user_activity_access_error(user_id)
if error_obj:
if error_obj := self.get_user_activity_access_error(user_id):
return error_obj
args = kwargs["rison"]

View File

@ -261,9 +261,8 @@ def get_datasource_info(
:raises SupersetException: If the datasource no longer exists
"""
datasource = form_data.get("datasource", "")
if "__" in datasource:
# pylint: disable=superfluous-parens
if "__" in (datasource := form_data.get("datasource", "")):
datasource_id, datasource_type = datasource.split("__")
# The case where the datasource has been deleted
if datasource_id == "None":
@ -462,7 +461,7 @@ def check_datasource_perms(
_self: Any,
datasource_type: Optional[str] = None,
datasource_id: Optional[int] = None,
**kwargs: Any
**kwargs: Any,
) -> None:
"""
Check if user can access a cached response from explore_json.

View File

@ -365,10 +365,11 @@ class BaseViz: # pylint: disable=too-many-public-methods
metrics = self.all_metrics or []
groupby = self.dedup_columns(self.groupby, self.form_data.get("columns"))
groupby_labels = get_column_names(groupby)
is_timeseries = self.is_timeseries
if DTTM_ALIAS in groupby_labels:
# pylint: disable=superfluous-parens
if DTTM_ALIAS in (groupby_labels := get_column_names(groupby)):
del groupby[groupby_labels.index(DTTM_ALIAS)]
is_timeseries = True
@ -959,8 +960,7 @@ class PivotTableViz(BaseViz):
if len(deduped_cols) < (len(groupby) + len(columns)):
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
if sort_by := self.form_data.get("timeseries_limit_metric"):
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
query_obj["metrics"].append(sort_by)
@ -1077,8 +1077,7 @@ class TreemapViz(BaseViz):
@deprecated(deprecated_in="3.0")
def query_obj(self) -> QueryObjectDict:
query_obj = super().query_obj()
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
if sort_by := self.form_data.get("timeseries_limit_metric"):
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
query_obj["metrics"].append(sort_by)
@ -1880,8 +1879,7 @@ class DistributionBarViz(BaseViz):
if not self.form_data.get("groupby"):
raise QueryObjectValidationError(_("Pick at least one field for [Series]"))
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
if sort_by := self.form_data.get("timeseries_limit_metric"):
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
query_obj["metrics"].append(sort_by)
@ -2310,8 +2308,7 @@ class ParallelCoordinatesViz(BaseViz):
def query_obj(self) -> QueryObjectDict:
query_obj = super().query_obj()
query_obj["groupby"] = [self.form_data.get("series")]
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
if sort_by := self.form_data.get("timeseries_limit_metric"):
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
query_obj["metrics"].append(sort_by)
@ -2679,8 +2676,7 @@ class BaseDeckGLViz(BaseViz):
if self.form_data.get("adhoc_filters") is None:
self.form_data["adhoc_filters"] = []
line_column = self.form_data.get("line_column")
if line_column:
if line_column := self.form_data.get("line_column"):
spatial_columns.add(line_column)
for column in sorted(spatial_columns):
@ -2706,13 +2702,12 @@ class BaseDeckGLViz(BaseViz):
if self.form_data.get("js_columns"):
group_by += self.form_data.get("js_columns") or []
metrics = self.get_metrics()
# Ensure this value is sorted so that it does not
# cause the cache key generation (which hashes the
# query object) to generate different keys for values
# that should be considered the same.
group_by = sorted(set(group_by))
if metrics:
if metrics := self.get_metrics():
query_obj["groupby"] = group_by
query_obj["metrics"] = metrics
query_obj["columns"] = []
@ -3097,8 +3092,7 @@ class PairedTTestViz(BaseViz):
@deprecated(deprecated_in="3.0")
def query_obj(self) -> QueryObjectDict:
query_obj = super().query_obj()
sort_by = self.form_data.get("timeseries_limit_metric")
if sort_by:
if sort_by := self.form_data.get("timeseries_limit_metric"):
sort_by_label = utils.get_metric_name(sort_by)
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
query_obj["metrics"].append(sort_by)

View File

@ -67,8 +67,7 @@ class PandasDataLoader(DataLoader):
return inspect(self._db_engine).default_schema_name
def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]:
metadata_table = table.table_metadata
if metadata_table:
if metadata_table := table.table_metadata:
types = metadata_table.types
if types:
return types

View File

@ -136,7 +136,6 @@ def upload_csv(
dtype: Union[str, None] = None,
):
csv_upload_db_id = get_upload_db().id
schema = utils.get_example_default_schema()
form_data = {
"csv_file": open(filename, "rb"),
"delimiter": ",",
@ -146,7 +145,7 @@ def upload_csv(
"index_label": "test_label",
"overwrite_duplicate": False,
}
if schema:
if schema := utils.get_example_default_schema():
form_data["schema"] = schema
if extra:
form_data.update(extra)
@ -159,7 +158,6 @@ def upload_excel(
filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
):
excel_upload_db_id = get_upload_db().id
schema = utils.get_example_default_schema()
form_data = {
"excel_file": open(filename, "rb"),
"name": table_name,
@ -169,7 +167,7 @@ def upload_excel(
"index_label": "test_label",
"mangle_dupe_cols": False,
}
if schema:
if schema := utils.get_example_default_schema():
form_data["schema"] = schema
if extra:
form_data.update(extra)
@ -180,7 +178,6 @@ def upload_columnar(
filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
):
columnar_upload_db_id = get_upload_db().id
schema = utils.get_example_default_schema()
form_data = {
"columnar_file": open(filename, "rb"),
"name": table_name,
@ -188,7 +185,7 @@ def upload_columnar(
"if_exists": "fail",
"index_label": "test_label",
}
if schema:
if schema := utils.get_example_default_schema():
form_data["schema"] = schema
if extra:
form_data.update(extra)

View File

@ -641,14 +641,13 @@ class TestDatasetApi(SupersetTestCase):
if backend() == "sqlite":
return
schema = get_example_default_schema()
energy_usage_ds = self.get_energy_usage_dataset()
self.login(username="admin")
table_data = {
"database": energy_usage_ds.database_id,
"table_name": energy_usage_ds.table_name,
}
if schema:
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
@ -665,7 +664,6 @@ class TestDatasetApi(SupersetTestCase):
if backend() == "sqlite":
return
schema = get_example_default_schema()
energy_usage_ds = self.get_energy_usage_dataset()
self.login(username="admin")
table_data = {
@ -673,7 +671,7 @@ class TestDatasetApi(SupersetTestCase):
"table_name": energy_usage_ds.table_name,
"sql": "select * from energy_usage",
}
if schema:
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 422
@ -690,7 +688,6 @@ class TestDatasetApi(SupersetTestCase):
if backend() == "sqlite":
return
schema = get_example_default_schema()
energy_usage_ds = self.get_energy_usage_dataset()
self.login(username="alpha")
admin = self.get_user("admin")
@ -701,7 +698,7 @@ class TestDatasetApi(SupersetTestCase):
"sql": "select * from energy_usage",
"owners": [admin.id],
}
if schema:
if schema := get_example_default_schema():
table_data["schema"] = schema
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
assert rv.status_code == 201