mirror of https://github.com/apache/superset.git
chore: Embrace the walrus operator (#24127)
This commit is contained in:
parent
6b5459121f
commit
d583ca9ef5
|
@ -19,6 +19,10 @@ repos:
|
|||
rev: 5.12.0
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/MarcoGorelli/auto-walrus
|
||||
rev: v0.2.2
|
||||
hooks:
|
||||
- id: auto-walrus
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.3.0
|
||||
hooks:
|
||||
|
|
|
@ -649,8 +649,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
|
|||
return self.response_404()
|
||||
|
||||
# fetch the chart screenshot using the current user and cache if set
|
||||
img = ChartScreenshot.get_from_cache_key(thumbnail_cache, digest)
|
||||
if img:
|
||||
if img := ChartScreenshot.get_from_cache_key(thumbnail_cache, digest):
|
||||
return Response(
|
||||
FileWrapper(img), mimetype="image/png", direct_passthrough=True
|
||||
)
|
||||
|
@ -783,7 +782,6 @@ class ChartRestApi(BaseSupersetModelRestApi):
|
|||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
token = request.args.get("token")
|
||||
requested_ids = kwargs["rison"]
|
||||
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
root = f"chart_export_{timestamp}"
|
||||
|
@ -805,7 +803,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
|
|||
as_attachment=True,
|
||||
download_name=filename,
|
||||
)
|
||||
if token:
|
||||
if token := request.args.get("token"):
|
||||
response.set_cookie(token, "done", max_age=600)
|
||||
return response
|
||||
|
||||
|
|
|
@ -55,8 +55,7 @@ class BulkDeleteChartCommand(BaseCommand):
|
|||
if not self._models or len(self._models) != len(self._model_ids):
|
||||
raise ChartNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
reports = ReportScheduleDAO.find_by_chart_ids(self._model_ids)
|
||||
if reports:
|
||||
if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids):
|
||||
report_names = [report.name for report in reports]
|
||||
raise ChartBulkDeleteFailedReportsExistError(
|
||||
_("There are associated alerts or reports: %s" % ",".join(report_names))
|
||||
|
|
|
@ -64,8 +64,7 @@ class DeleteChartCommand(BaseCommand):
|
|||
if not self._model:
|
||||
raise ChartNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
reports = ReportScheduleDAO.find_by_chart_id(self._model_id)
|
||||
if reports:
|
||||
if reports := ReportScheduleDAO.find_by_chart_id(self._model_id):
|
||||
report_names = [report.name for report in reports]
|
||||
raise ChartDeleteFailedReportsExistError(
|
||||
_("There are associated alerts or reports: %s" % ",".join(report_names))
|
||||
|
|
|
@ -221,8 +221,7 @@ def get_query_results(
|
|||
:raises QueryObjectValidationError: if an unsupported result type is requested
|
||||
:return: JSON serializable result payload
|
||||
"""
|
||||
result_func = _result_type_functions.get(result_type)
|
||||
if result_func:
|
||||
if result_func := _result_type_functions.get(result_type):
|
||||
return result_func(query_context, query_obj, force_cached)
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid result type: %(result_type)s", result_type=result_type)
|
||||
|
|
|
@ -125,10 +125,9 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
|
|||
for column in datasource.columns
|
||||
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
|
||||
}
|
||||
granularity = query_object.granularity
|
||||
x_axis = form_data and form_data.get("x_axis")
|
||||
|
||||
if granularity:
|
||||
if granularity := query_object.granularity:
|
||||
filter_to_remove = None
|
||||
if x_axis and x_axis in temporal_columns:
|
||||
filter_to_remove = x_axis
|
||||
|
|
|
@ -500,8 +500,7 @@ class QueryContextProcessor:
|
|||
return return_value
|
||||
|
||||
def get_cache_timeout(self) -> int:
|
||||
cache_timeout_rv = self._query_context.get_cache_timeout()
|
||||
if cache_timeout_rv:
|
||||
if cache_timeout_rv := self._query_context.get_cache_timeout():
|
||||
return cache_timeout_rv
|
||||
if (
|
||||
data_cache_timeout := config["DATA_CACHE_CONFIG"].get(
|
||||
|
|
|
@ -148,8 +148,7 @@ class QueryCacheManager:
|
|||
if not key or not _cache[region] or force_query:
|
||||
return query_cache
|
||||
|
||||
cache_value = _cache[region].get(key)
|
||||
if cache_value:
|
||||
if cache_value := _cache[region].get(key):
|
||||
logger.debug("Cache key: %s", key)
|
||||
stats_logger.incr("loading_from_cache")
|
||||
try:
|
||||
|
|
|
@ -993,11 +993,10 @@ class SqlaTable(
|
|||
schema=self.schema,
|
||||
template_processor=template_processor,
|
||||
)
|
||||
col_in_metadata = self.get_column(expression)
|
||||
time_grain = col.get("timeGrain")
|
||||
has_timegrain = col.get("columnType") == "BASE_AXIS" and time_grain
|
||||
is_dttm = False
|
||||
if col_in_metadata:
|
||||
if col_in_metadata := self.get_column(expression):
|
||||
sqla_column = col_in_metadata.get_sqla_col(
|
||||
template_processor=template_processor
|
||||
)
|
||||
|
|
|
@ -56,8 +56,7 @@ class BulkDeleteDashboardCommand(BaseCommand):
|
|||
if not self._models or len(self._models) != len(self._model_ids):
|
||||
raise DashboardNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
reports = ReportScheduleDAO.find_by_dashboard_ids(self._model_ids)
|
||||
if reports:
|
||||
if reports := ReportScheduleDAO.find_by_dashboard_ids(self._model_ids):
|
||||
report_names = [report.name for report in reports]
|
||||
raise DashboardBulkDeleteFailedReportsExistError(
|
||||
_("There are associated alerts or reports: %s" % ",".join(report_names))
|
||||
|
|
|
@ -57,8 +57,7 @@ class DeleteDashboardCommand(BaseCommand):
|
|||
if not self._model:
|
||||
raise DashboardNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
reports = ReportScheduleDAO.find_by_dashboard_id(self._model_id)
|
||||
if reports:
|
||||
if reports := ReportScheduleDAO.find_by_dashboard_id(self._model_id):
|
||||
report_names = [report.name for report in reports]
|
||||
raise DashboardDeleteFailedReportsExistError(
|
||||
_("There are associated alerts or reports: %s" % ",".join(report_names))
|
||||
|
|
|
@ -200,11 +200,10 @@ class DashboardDAO(BaseDAO):
|
|||
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
|
||||
commit: bool = False,
|
||||
) -> Dashboard:
|
||||
positions = data.get("positions")
|
||||
new_filter_scopes = {}
|
||||
md = dashboard.params_dict
|
||||
|
||||
if positions is not None:
|
||||
if (positions := data.get("positions")) is not None:
|
||||
# find slices in the position data
|
||||
slice_ids = [
|
||||
value.get("meta", {}).get("chartId")
|
||||
|
|
|
@ -1036,7 +1036,6 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
token = request.args.get("token")
|
||||
requested_ids = kwargs["rison"]
|
||||
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
root = f"database_export_{timestamp}"
|
||||
|
@ -1060,7 +1059,7 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
|||
as_attachment=True,
|
||||
download_name=filename,
|
||||
)
|
||||
if token:
|
||||
if token := request.args.get("token"):
|
||||
response.set_cookie(token, "done", max_age=600)
|
||||
return response
|
||||
|
||||
|
|
|
@ -55,9 +55,8 @@ class DeleteDatabaseCommand(BaseCommand):
|
|||
if not self._model:
|
||||
raise DatabaseNotFoundError()
|
||||
# Check there are no associated ReportSchedules
|
||||
reports = ReportScheduleDAO.find_by_database_id(self._model_id)
|
||||
|
||||
if reports:
|
||||
if reports := ReportScheduleDAO.find_by_database_id(self._model_id):
|
||||
report_names = [report.name for report in reports]
|
||||
raise DatabaseDeleteFailedReportsExistError(
|
||||
_("There are associated alerts or reports: %s" % ",".join(report_names))
|
||||
|
|
|
@ -228,6 +228,5 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
|||
raise DatabaseTestConnectionUnexpectedError(errors) from ex
|
||||
|
||||
def validate(self) -> None:
|
||||
database_name = self._properties.get("database_name")
|
||||
if database_name is not None:
|
||||
if (database_name := self._properties.get("database_name")) is not None:
|
||||
self._model = DatabaseDAO.get_database_by_name(database_name)
|
||||
|
|
|
@ -128,6 +128,5 @@ class ValidateDatabaseParametersCommand(BaseCommand):
|
|||
)
|
||||
|
||||
def validate(self) -> None:
|
||||
database_id = self._properties.get("id")
|
||||
if database_id is not None:
|
||||
if (database_id := self._properties.get("id")) is not None:
|
||||
self._model = DatabaseDAO.find_by_id(database_id)
|
||||
|
|
|
@ -977,8 +977,7 @@ class DatasetRestApi(BaseSupersetModelRestApi):
|
|||
return self.response(400, message=ex.messages)
|
||||
table_name = body["table_name"]
|
||||
database_id = body["database_id"]
|
||||
table = DatasetDAO.get_table_by_name(database_id, table_name)
|
||||
if table:
|
||||
if table := DatasetDAO.get_table_by_name(database_id, table_name):
|
||||
return self.response(200, result={"table_id": table.id})
|
||||
|
||||
body["database"] = database_id
|
||||
|
|
|
@ -62,8 +62,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
|
|||
if native_type.upper() in type_map:
|
||||
return type_map[native_type.upper()]
|
||||
|
||||
match = VARCHAR.match(native_type)
|
||||
if match:
|
||||
if match := VARCHAR.match(native_type):
|
||||
size = int(match.group(1))
|
||||
return String(size)
|
||||
|
||||
|
|
|
@ -114,13 +114,11 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
|
|||
exceptions.append(DatasetEndpointUnsafeValidationError())
|
||||
|
||||
# Validate columns
|
||||
columns = self._properties.get("columns")
|
||||
if columns:
|
||||
if columns := self._properties.get("columns"):
|
||||
self._validate_columns(columns, exceptions)
|
||||
|
||||
# Validate metrics
|
||||
metrics = self._properties.get("metrics")
|
||||
if metrics:
|
||||
if metrics := self._properties.get("metrics"):
|
||||
self._validate_metrics(metrics, exceptions)
|
||||
|
||||
if exceptions:
|
||||
|
|
|
@ -1704,8 +1704,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
|||
:param source: Type coming from the database table or cursor description
|
||||
:return: ColumnSpec object
|
||||
"""
|
||||
col_types = cls.get_column_types(native_type)
|
||||
if col_types:
|
||||
if col_types := cls.get_column_types(native_type):
|
||||
column_type, generic_type = col_types
|
||||
is_dttm = generic_type == GenericDataType.TEMPORAL
|
||||
return ColumnSpec(
|
||||
|
@ -1996,9 +1995,8 @@ class BasicParametersMixin:
|
|||
required = {"host", "port", "username", "database"}
|
||||
parameters = properties.get("parameters", {})
|
||||
present = {key for key in parameters if parameters.get(key, ())}
|
||||
missing = sorted(required - present)
|
||||
|
||||
if missing:
|
||||
if missing := sorted(required - present):
|
||||
errors.append(
|
||||
SupersetError(
|
||||
message=f'One or more parameters are missing: {", ".join(missing)}',
|
||||
|
|
|
@ -384,9 +384,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
|
|||
}
|
||||
|
||||
# Add credentials if they are set on the SQLAlchemy dialect.
|
||||
creds = engine.dialect.credentials_info
|
||||
|
||||
if creds:
|
||||
if creds := engine.dialect.credentials_info:
|
||||
to_gbq_kwargs[
|
||||
"credentials"
|
||||
] = service_account.Credentials.from_service_account_info(creds)
|
||||
|
|
|
@ -285,9 +285,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
|
|||
parameters["http_path"] = connect_args.get("http_path")
|
||||
|
||||
present = {key for key in parameters if parameters.get(key, ())}
|
||||
missing = sorted(required - present)
|
||||
|
||||
if missing:
|
||||
if missing := sorted(required - present):
|
||||
errors.append(
|
||||
SupersetError(
|
||||
message=f'One or more parameters are missing: {", ".join(missing)}',
|
||||
|
|
|
@ -1213,8 +1213,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
) -> Dict[str, Any]:
|
||||
metadata = {}
|
||||
|
||||
indexes = database.get_indexes(table_name, schema_name)
|
||||
if indexes:
|
||||
if indexes := database.get_indexes(table_name, schema_name):
|
||||
col_names, latest_parts = cls.latest_partition(
|
||||
table_name, schema_name, database, show_first=True
|
||||
)
|
||||
|
@ -1278,8 +1277,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
|
|||
@classmethod
|
||||
def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
|
||||
"""Updates progress information"""
|
||||
tracking_url = cls.get_tracking_url(cursor)
|
||||
if tracking_url:
|
||||
if tracking_url := cls.get_tracking_url(cursor):
|
||||
query.tracking_url = tracking_url
|
||||
session.commit()
|
||||
|
||||
|
|
|
@ -312,9 +312,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
|||
}
|
||||
parameters = properties.get("parameters", {})
|
||||
present = {key for key in parameters if parameters.get(key, ())}
|
||||
missing = sorted(required - present)
|
||||
|
||||
if missing:
|
||||
if missing := sorted(required - present):
|
||||
errors.append(
|
||||
SupersetError(
|
||||
message=f'One or more parameters are missing: {", ".join(missing)}',
|
||||
|
|
|
@ -57,8 +57,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
) -> Dict[str, Any]:
|
||||
metadata = {}
|
||||
|
||||
indexes = database.get_indexes(table_name, schema_name)
|
||||
if indexes:
|
||||
if indexes := database.get_indexes(table_name, schema_name):
|
||||
col_names, latest_parts = cls.latest_partition(
|
||||
table_name, schema_name, database, show_first=True
|
||||
)
|
||||
|
@ -150,8 +149,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
|
|||
|
||||
@classmethod
|
||||
def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
|
||||
tracking_url = cls.get_tracking_url(cursor)
|
||||
if tracking_url:
|
||||
if tracking_url := cls.get_tracking_url(cursor):
|
||||
query.tracking_url = tracking_url
|
||||
|
||||
# Adds the executed query id to the extra payload so the query can be cancelled
|
||||
|
|
|
@ -211,8 +211,7 @@ class SupersetError:
|
|||
Mutates the extra params with user facing error codes that map to backend
|
||||
errors.
|
||||
"""
|
||||
issue_codes = ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type)
|
||||
if issue_codes:
|
||||
if issue_codes := ERROR_TYPES_TO_ISSUE_CODES_MAPPING.get(self.error_type):
|
||||
self.extra = self.extra or {}
|
||||
self.extra.update(
|
||||
{
|
||||
|
|
|
@ -453,8 +453,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
|||
|
||||
# Hook that provides administrators a handle on the Flask APP
|
||||
# after initialization
|
||||
flask_app_mutator = self.config["FLASK_APP_MUTATOR"]
|
||||
if flask_app_mutator:
|
||||
if flask_app_mutator := self.config["FLASK_APP_MUTATOR"]:
|
||||
flask_app_mutator(self.superset_app)
|
||||
|
||||
if feature_flag_manager.is_feature_enabled("TAGGING_SYSTEM"):
|
||||
|
|
|
@ -103,8 +103,7 @@ def run_migrations_online() -> None:
|
|||
kwargs = {}
|
||||
if engine.name in ("sqlite", "mysql"):
|
||||
kwargs = {"transaction_per_migration": True, "transactional_ddl": True}
|
||||
configure_args = current_app.extensions["migrate"].configure_args
|
||||
if configure_args:
|
||||
if configure_args := current_app.extensions["migrate"].configure_args:
|
||||
kwargs.update(configure_args)
|
||||
|
||||
context.configure(
|
||||
|
@ -112,7 +111,7 @@ def run_migrations_online() -> None:
|
|||
target_metadata=target_metadata,
|
||||
# compare_type=True,
|
||||
process_revision_directives=process_revision_directives,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
@ -130,8 +130,7 @@ class MigrateViz:
|
|||
# only backup params
|
||||
slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak})
|
||||
|
||||
query_context = try_load_json(slc.query_context)
|
||||
if "form_data" in query_context:
|
||||
if "form_data" in (query_context := try_load_json(slc.query_context)):
|
||||
query_context["form_data"] = clz.data
|
||||
slc.query_context = json.dumps(query_context)
|
||||
return slc
|
||||
|
@ -139,8 +138,7 @@ class MigrateViz:
|
|||
@classmethod
|
||||
def downgrade_slice(cls, slc: Slice) -> Slice:
|
||||
form_data = try_load_json(slc.params)
|
||||
form_data_bak = form_data.get(FORM_DATA_BAK_FIELD_NAME, {})
|
||||
if "viz_type" in form_data_bak:
|
||||
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
|
||||
slc.params = json.dumps(form_data_bak)
|
||||
slc.viz_type = form_data_bak.get("viz_type")
|
||||
query_context = try_load_json(slc.query_context)
|
||||
|
|
|
@ -40,8 +40,7 @@ class MigrateAreaChart(MigrateViz):
|
|||
if self.data.get("contribution"):
|
||||
self.data["contributionMode"] = "row"
|
||||
|
||||
stacked = self.data.get("stacked_style")
|
||||
if stacked:
|
||||
if stacked := self.data.get("stacked_style"):
|
||||
stacked_map = {
|
||||
"expand": "Expand",
|
||||
"stack": "Stack",
|
||||
|
@ -49,7 +48,6 @@ class MigrateAreaChart(MigrateViz):
|
|||
self.data["show_extra_controls"] = True
|
||||
self.data["stack"] = stacked_map.get(stacked)
|
||||
|
||||
x_axis_label = self.data.get("x_axis_label")
|
||||
if x_axis_label:
|
||||
if x_axis_label := self.data.get("x_axis_label"):
|
||||
self.data["x_axis_title"] = x_axis_label
|
||||
self.data["x_axis_title_margin"] = 30
|
||||
|
|
|
@ -193,13 +193,12 @@ def get_chart_holder(position):
|
|||
size_y = position["size_y"]
|
||||
slice_id = position["slice_id"]
|
||||
slice_name = position.get("slice_name")
|
||||
code = position.get("code")
|
||||
|
||||
width = max(GRID_MIN_COLUMN_COUNT, int(round(size_x / GRID_RATIO)))
|
||||
height = max(
|
||||
GRID_MIN_ROW_UNITS, int(round(((size_y / GRID_RATIO) * 100) / ROW_HEIGHT))
|
||||
)
|
||||
if code is not None:
|
||||
if (code := position.get("code")) is not None:
|
||||
markdown_content = " " # white-space markdown
|
||||
if len(code):
|
||||
markdown_content = code
|
||||
|
|
|
@ -80,8 +80,7 @@ def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
|
|||
changed_filters, changed_filter_sets = 0, 0
|
||||
# upgrade native select filter metadata
|
||||
# upgrade native select filter metadata
|
||||
native_filters = dashboard.get("native_filter_configuration")
|
||||
if native_filters:
|
||||
if native_filters := dashboard.get("native_filter_configuration"):
|
||||
changed_filters += upgrade_filters(native_filters)
|
||||
|
||||
# upgrade filter sets
|
||||
|
@ -123,8 +122,7 @@ def upgrade():
|
|||
def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
|
||||
changed_filters, changed_filter_sets = 0, 0
|
||||
# upgrade native select filter metadata
|
||||
native_filters = dashboard.get("native_filter_configuration")
|
||||
if native_filters:
|
||||
if native_filters := dashboard.get("native_filter_configuration"):
|
||||
changed_filters += downgrade_filters(native_filters)
|
||||
|
||||
# upgrade filter sets
|
||||
|
|
|
@ -40,9 +40,8 @@ def upgrade():
|
|||
insp = engine.reflection.Inspector.from_engine(bind)
|
||||
|
||||
# Drop the uniqueness constraint if it exists.
|
||||
constraint = generic_find_uq_constraint_name("tables", {"table_name"}, insp)
|
||||
|
||||
if constraint:
|
||||
if constraint := generic_find_uq_constraint_name("tables", {"table_name"}, insp):
|
||||
with op.batch_alter_table("tables", naming_convention=conv) as batch_op:
|
||||
batch_op.drop_constraint(constraint, type_="unique")
|
||||
|
||||
|
|
|
@ -277,9 +277,8 @@ class Database(
|
|||
# When returning the parameters we should use the masked SQLAlchemy URI and the
|
||||
# masked ``encrypted_extra`` to prevent exposing sensitive credentials.
|
||||
masked_uri = make_url_safe(self.sqlalchemy_uri)
|
||||
masked_encrypted_extra = self.masked_encrypted_extra
|
||||
encrypted_config = {}
|
||||
if masked_encrypted_extra is not None:
|
||||
if (masked_encrypted_extra := self.masked_encrypted_extra) is not None:
|
||||
try:
|
||||
encrypted_config = json.loads(masked_encrypted_extra)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
|
|
|
@ -880,8 +880,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
|
|||
"""Apply config's SQL_QUERY_MUTATOR
|
||||
|
||||
Typically adds comments to the query with context"""
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
if sql_query_mutator:
|
||||
if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=utils.get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
|
|
|
@ -378,8 +378,7 @@ class Slice( # pylint: disable=too-many-public-methods
|
|||
|
||||
def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) -> None:
|
||||
src_class = target.cls_model
|
||||
id_ = target.datasource_id
|
||||
if id_:
|
||||
if id_ := target.datasource_id:
|
||||
ds = db.session.query(src_class).filter_by(id=int(id_)).first()
|
||||
if ds:
|
||||
target.perm = ds.perm
|
||||
|
|
|
@ -262,7 +262,6 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
|
|||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
token = request.args.get("token")
|
||||
requested_ids = kwargs["rison"]
|
||||
timestamp = datetime.now().strftime("%Y%m%dT%H%M%S")
|
||||
root = f"saved_query_export_{timestamp}"
|
||||
|
@ -286,7 +285,7 @@ class SavedQueryRestApi(BaseSupersetModelRestApi):
|
|||
as_attachment=True,
|
||||
download_name=filename,
|
||||
)
|
||||
if token:
|
||||
if token := request.args.get("token"):
|
||||
response.set_cookie(token, "done", max_age=600)
|
||||
return response
|
||||
|
||||
|
|
|
@ -176,8 +176,7 @@ class BaseReportState:
|
|||
)
|
||||
|
||||
# If we need to render dashboard in a specific state, use stateful permalink
|
||||
dashboard_state = self._report_schedule.extra.get("dashboard")
|
||||
if dashboard_state:
|
||||
if dashboard_state := self._report_schedule.extra.get("dashboard"):
|
||||
permalink_key = CreateDashboardPermalinkCommand(
|
||||
dashboard_id=str(self._report_schedule.dashboard.uuid),
|
||||
state=dashboard_state,
|
||||
|
|
|
@ -225,12 +225,10 @@ class SupersetResultSet:
|
|||
|
||||
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
|
||||
"""Given a pyarrow data type, Returns a generic database type"""
|
||||
set_type = self._type_dict.get(col_name)
|
||||
if set_type:
|
||||
if set_type := self._type_dict.get(col_name):
|
||||
return set_type
|
||||
|
||||
mapped_type = self.convert_pa_dtype(pa_dtype)
|
||||
if mapped_type:
|
||||
if mapped_type := self.convert_pa_dtype(pa_dtype):
|
||||
return mapped_type
|
||||
|
||||
return None
|
||||
|
|
|
@ -589,8 +589,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
return {s.name for s in view_menu_names}
|
||||
|
||||
# Properly treat anonymous user
|
||||
public_role = self.get_public_role()
|
||||
if public_role:
|
||||
if public_role := self.get_public_role():
|
||||
# filter by public role
|
||||
view_menu_names = (
|
||||
base_query.filter(self.role_model.id == public_role.id).filter(
|
||||
|
@ -639,8 +638,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
}
|
||||
|
||||
# datasource_access
|
||||
perms = self.user_view_menu_names("datasource_access")
|
||||
if perms:
|
||||
if perms := self.user_view_menu_names("datasource_access"):
|
||||
tables = (
|
||||
self.get_session.query(SqlaTable.schema)
|
||||
.filter(SqlaTable.database_id == database.id)
|
||||
|
@ -770,9 +768,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
== None,
|
||||
)
|
||||
)
|
||||
deleted_count = pvms.delete()
|
||||
sesh.commit()
|
||||
if deleted_count:
|
||||
if deleted_count := pvms.delete():
|
||||
logger.info("Deleted %i faulty permissions", deleted_count)
|
||||
|
||||
def sync_role_definitions(self) -> None:
|
||||
|
@ -1916,8 +1913,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
|||
:param dataset: The dataset to check against
|
||||
:return: A list of filters
|
||||
"""
|
||||
guest_user = self.get_current_guest_user_if_guest()
|
||||
if guest_user:
|
||||
if guest_user := self.get_current_guest_user_if_guest():
|
||||
return [
|
||||
rule
|
||||
for rule in guest_user.rls
|
||||
|
|
|
@ -95,7 +95,6 @@ def handle_query_error(
|
|||
"""Local method handling error while processing the SQL"""
|
||||
payload = payload or {}
|
||||
msg = f"{prefix_message} {str(ex)}".strip()
|
||||
troubleshooting_link = config["TROUBLESHOOTING_LINK"]
|
||||
query.error_message = msg
|
||||
query.tmp_table_name = None
|
||||
query.status = QueryStatus.FAILED
|
||||
|
@ -119,7 +118,7 @@ def handle_query_error(
|
|||
|
||||
session.commit()
|
||||
payload.update({"status": query.status, "error": msg, "errors": errors_payload})
|
||||
if troubleshooting_link:
|
||||
if troubleshooting_link := config["TROUBLESHOOTING_LINK"]:
|
||||
payload["link"] = troubleshooting_link
|
||||
return payload
|
||||
|
||||
|
|
|
@ -54,8 +54,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
|
|||
sql = parsed_query.stripped()
|
||||
|
||||
# Hook to allow environment-specific mutation (usually comments) to the SQL
|
||||
sql_query_mutator = config["SQL_QUERY_MUTATOR"]
|
||||
if sql_query_mutator:
|
||||
if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
|
||||
sql = sql_query_mutator(
|
||||
sql,
|
||||
user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0.
|
||||
|
|
|
@ -1800,8 +1800,7 @@ def get_time_filter_status(
|
|||
}
|
||||
applied: List[Dict[str, str]] = []
|
||||
rejected: List[Dict[str, str]] = []
|
||||
time_column = applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL)
|
||||
if time_column:
|
||||
if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL):
|
||||
if time_column in temporal_columns:
|
||||
applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL})
|
||||
else:
|
||||
|
|
|
@ -121,8 +121,7 @@ class BaseScreenshot:
|
|||
@staticmethod
|
||||
def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]:
|
||||
logger.info("Attempting to get from cache: %s", cache_key)
|
||||
payload = cache.get(cache_key)
|
||||
if payload:
|
||||
if payload := cache.get(cache_key):
|
||||
return BytesIO(payload)
|
||||
logger.info("Failed at getting from cache: %s", cache_key)
|
||||
return None
|
||||
|
|
|
@ -79,8 +79,7 @@ class Api(BaseSupersetView):
|
|||
params: slice_id: integer
|
||||
"""
|
||||
form_data = {}
|
||||
slice_id = request.args.get("slice_id")
|
||||
if slice_id:
|
||||
if slice_id := request.args.get("slice_id"):
|
||||
slc = db.session.query(Slice).filter_by(id=slice_id).one_or_none()
|
||||
if slc:
|
||||
form_data = slc.form_data.copy()
|
||||
|
|
|
@ -380,8 +380,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
|
|||
filter_field = cast(RelatedFieldFilter, filter_field)
|
||||
search_columns = [filter_field.field_name] if filter_field else None
|
||||
filters = datamodel.get_filters(search_columns)
|
||||
base_filters = self.base_related_field_filters.get(column_name)
|
||||
if base_filters:
|
||||
if base_filters := self.base_related_field_filters.get(column_name):
|
||||
filters.add_filter_list(base_filters)
|
||||
if value and filter_field:
|
||||
filters.add_filter(
|
||||
|
@ -588,8 +587,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
|
|||
return self.response_404()
|
||||
page, page_size = self._sanitize_page_args(page, page_size)
|
||||
# handle ordering
|
||||
order_field = self.order_rel_fields.get(column_name)
|
||||
if order_field:
|
||||
if order_field := self.order_rel_fields.get(column_name):
|
||||
order_column, order_direction = order_field
|
||||
else:
|
||||
order_column, order_direction = "", ""
|
||||
|
|
|
@ -765,8 +765,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""
|
||||
redirect_url = request.url.replace("/superset/explore", "/explore")
|
||||
form_data_key = None
|
||||
request_form_data = request.args.get("form_data")
|
||||
if request_form_data:
|
||||
if request_form_data := request.args.get("form_data"):
|
||||
parsed_form_data = loads_request_json(request_form_data)
|
||||
slice_id = parsed_form_data.get(
|
||||
"slice_id", int(request.args.get("slice_id", 0))
|
||||
|
@ -1498,8 +1497,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@deprecated(new_target="/api/v1/log/recent_activity/<user_id>/")
|
||||
def recent_activity(self, user_id: int) -> FlaskResponse:
|
||||
"""Recent activity (actions) for a given user"""
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
|
||||
limit = request.args.get("limit")
|
||||
|
@ -1543,8 +1541,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@expose("/fave_dashboards/<int:user_id>/", methods=("GET",))
|
||||
@deprecated(new_target="api/v1/dashboard/favorite_status/")
|
||||
def fave_dashboards(self, user_id: int) -> FlaskResponse:
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
qry = (
|
||||
db.session.query(Dashboard, FavStar.dttm)
|
||||
|
@ -1580,8 +1577,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@expose("/created_dashboards/<int:user_id>/", methods=("GET",))
|
||||
@deprecated(new_target="api/v1/dashboard/")
|
||||
def created_dashboards(self, user_id: int) -> FlaskResponse:
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
qry = (
|
||||
db.session.query(Dashboard)
|
||||
|
@ -1615,8 +1611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""List of slices a user owns, created, modified or faved"""
|
||||
if not user_id:
|
||||
user_id = cast(int, get_user_id())
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
|
||||
owner_ids_query = (
|
||||
|
@ -1669,8 +1664,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""List of slices created by this user"""
|
||||
if not user_id:
|
||||
user_id = cast(int, get_user_id())
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
qry = (
|
||||
db.session.query(Slice)
|
||||
|
@ -1701,8 +1695,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
"""Favorite slices for a user"""
|
||||
if user_id is None:
|
||||
user_id = cast(int, get_user_id())
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
qry = (
|
||||
db.session.query(Slice, FavStar.dttm)
|
||||
|
@ -1965,8 +1958,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
return json_error_response(_("permalink state not found"), status=404)
|
||||
dashboard_id, state = value["dashboardId"], value.get("state", {})
|
||||
url = f"/superset/dashboard/{dashboard_id}?permalink_key={key}"
|
||||
url_params = state.get("urlParams")
|
||||
if url_params:
|
||||
if url_params := state.get("urlParams"):
|
||||
params = parse.urlencode(url_params)
|
||||
url = f"{url}&{params}"
|
||||
hash_ = state.get("anchor", state.get("hash"))
|
||||
|
@ -2125,8 +2117,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
mydb = db.session.query(Database).get(database_id)
|
||||
|
||||
sql = json.loads(request.form.get("sql", '""'))
|
||||
template_params = json.loads(request.form.get("templateParams") or "{}")
|
||||
if template_params:
|
||||
if template_params := json.loads(request.form.get("templateParams") or "{}"):
|
||||
template_processor = get_template_processor(mydb)
|
||||
sql = template_processor.process_template(sql, **template_params)
|
||||
|
||||
|
@ -2393,8 +2384,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
@expose("/sql_json/", methods=("POST",))
|
||||
@deprecated(new_target="/api/v1/sqllab/execute/")
|
||||
def sql_json(self) -> FlaskResponse:
|
||||
errors = SqlJsonPayloadSchema().validate(request.json)
|
||||
if errors:
|
||||
if errors := SqlJsonPayloadSchema().validate(request.json):
|
||||
return json_error_response(status=400, payload=errors)
|
||||
|
||||
try:
|
||||
|
@ -2621,10 +2611,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
search_user_id = get_user_id()
|
||||
database_id = request.args.get("database_id")
|
||||
search_text = request.args.get("search_text")
|
||||
status = request.args.get("status")
|
||||
# From and To time stamp should be Epoch timestamp in seconds
|
||||
from_time = request.args.get("from")
|
||||
to_time = request.args.get("to")
|
||||
|
||||
query = db.session.query(Query)
|
||||
if search_user_id:
|
||||
|
@ -2635,7 +2622,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
# Filter on db Id
|
||||
query = query.filter(Query.database_id == database_id)
|
||||
|
||||
if status:
|
||||
if status := request.args.get("status"):
|
||||
# Filter on status
|
||||
query = query.filter(Query.status == status)
|
||||
|
||||
|
@ -2643,10 +2630,10 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
# Filter on search text
|
||||
query = query.filter(Query.sql.like(f"%{search_text}%"))
|
||||
|
||||
if from_time:
|
||||
if from_time := request.args.get("from"):
|
||||
query = query.filter(Query.start_time > int(from_time))
|
||||
|
||||
if to_time:
|
||||
if to_time := request.args.get("to"):
|
||||
query = query.filter(Query.start_time < int(to_time))
|
||||
|
||||
query_limit = config["QUERY_SEARCH_LIMIT"]
|
||||
|
@ -2709,8 +2696,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
user_id = -1 if not user else user.id
|
||||
# Prevent unauthorized access to other user's profiles,
|
||||
# unless configured to do so on with ENABLE_BROAD_ACTIVITY_ACCESS
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
|
||||
payload = {
|
||||
|
@ -2789,8 +2775,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
|
|||
**self._get_sqllab_tabs(get_user_id()),
|
||||
}
|
||||
|
||||
form_data = request.form.get("form_data")
|
||||
if form_data:
|
||||
if form_data := request.form.get("form_data"):
|
||||
try:
|
||||
payload["requested_query"] = json.loads(form_data)
|
||||
except json.JSONDecodeError:
|
||||
|
|
|
@ -51,7 +51,6 @@ def sqlalchemy_uri_validator(
|
|||
def schema_allows_file_upload(database: Database, schema: Optional[str]) -> bool:
|
||||
if not database.allow_file_upload:
|
||||
return False
|
||||
schemas = database.get_schema_access_for_file_upload()
|
||||
if schemas:
|
||||
if schemas := database.get_schema_access_for_file_upload():
|
||||
return schema in schemas
|
||||
return security_manager.can_access_database(database)
|
||||
|
|
|
@ -125,8 +125,7 @@ class LogRestApi(LogMixin, BaseSupersetModelRestApi):
|
|||
500:
|
||||
$ref: '#/components/responses/500'
|
||||
"""
|
||||
error_obj = self.get_user_activity_access_error(user_id)
|
||||
if error_obj:
|
||||
if error_obj := self.get_user_activity_access_error(user_id):
|
||||
return error_obj
|
||||
|
||||
args = kwargs["rison"]
|
||||
|
|
|
@ -261,9 +261,8 @@ def get_datasource_info(
|
|||
:raises SupersetException: If the datasource no longer exists
|
||||
"""
|
||||
|
||||
datasource = form_data.get("datasource", "")
|
||||
|
||||
if "__" in datasource:
|
||||
# pylint: disable=superfluous-parens
|
||||
if "__" in (datasource := form_data.get("datasource", "")):
|
||||
datasource_id, datasource_type = datasource.split("__")
|
||||
# The case where the datasource has been deleted
|
||||
if datasource_id == "None":
|
||||
|
@ -462,7 +461,7 @@ def check_datasource_perms(
|
|||
_self: Any,
|
||||
datasource_type: Optional[str] = None,
|
||||
datasource_id: Optional[int] = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Check if user can access a cached response from explore_json.
|
||||
|
|
|
@ -365,10 +365,11 @@ class BaseViz: # pylint: disable=too-many-public-methods
|
|||
metrics = self.all_metrics or []
|
||||
|
||||
groupby = self.dedup_columns(self.groupby, self.form_data.get("columns"))
|
||||
groupby_labels = get_column_names(groupby)
|
||||
|
||||
is_timeseries = self.is_timeseries
|
||||
if DTTM_ALIAS in groupby_labels:
|
||||
|
||||
# pylint: disable=superfluous-parens
|
||||
if DTTM_ALIAS in (groupby_labels := get_column_names(groupby)):
|
||||
del groupby[groupby_labels.index(DTTM_ALIAS)]
|
||||
is_timeseries = True
|
||||
|
||||
|
@ -959,8 +960,7 @@ class PivotTableViz(BaseViz):
|
|||
|
||||
if len(deduped_cols) < (len(groupby) + len(columns)):
|
||||
raise QueryObjectValidationError(_("Group By' and 'Columns' can't overlap"))
|
||||
sort_by = self.form_data.get("timeseries_limit_metric")
|
||||
if sort_by:
|
||||
if sort_by := self.form_data.get("timeseries_limit_metric"):
|
||||
sort_by_label = utils.get_metric_name(sort_by)
|
||||
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
|
||||
query_obj["metrics"].append(sort_by)
|
||||
|
@ -1077,8 +1077,7 @@ class TreemapViz(BaseViz):
|
|||
@deprecated(deprecated_in="3.0")
|
||||
def query_obj(self) -> QueryObjectDict:
|
||||
query_obj = super().query_obj()
|
||||
sort_by = self.form_data.get("timeseries_limit_metric")
|
||||
if sort_by:
|
||||
if sort_by := self.form_data.get("timeseries_limit_metric"):
|
||||
sort_by_label = utils.get_metric_name(sort_by)
|
||||
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
|
||||
query_obj["metrics"].append(sort_by)
|
||||
|
@ -1880,8 +1879,7 @@ class DistributionBarViz(BaseViz):
|
|||
if not self.form_data.get("groupby"):
|
||||
raise QueryObjectValidationError(_("Pick at least one field for [Series]"))
|
||||
|
||||
sort_by = self.form_data.get("timeseries_limit_metric")
|
||||
if sort_by:
|
||||
if sort_by := self.form_data.get("timeseries_limit_metric"):
|
||||
sort_by_label = utils.get_metric_name(sort_by)
|
||||
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
|
||||
query_obj["metrics"].append(sort_by)
|
||||
|
@ -2310,8 +2308,7 @@ class ParallelCoordinatesViz(BaseViz):
|
|||
def query_obj(self) -> QueryObjectDict:
|
||||
query_obj = super().query_obj()
|
||||
query_obj["groupby"] = [self.form_data.get("series")]
|
||||
sort_by = self.form_data.get("timeseries_limit_metric")
|
||||
if sort_by:
|
||||
if sort_by := self.form_data.get("timeseries_limit_metric"):
|
||||
sort_by_label = utils.get_metric_name(sort_by)
|
||||
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
|
||||
query_obj["metrics"].append(sort_by)
|
||||
|
@ -2679,8 +2676,7 @@ class BaseDeckGLViz(BaseViz):
|
|||
if self.form_data.get("adhoc_filters") is None:
|
||||
self.form_data["adhoc_filters"] = []
|
||||
|
||||
line_column = self.form_data.get("line_column")
|
||||
if line_column:
|
||||
if line_column := self.form_data.get("line_column"):
|
||||
spatial_columns.add(line_column)
|
||||
|
||||
for column in sorted(spatial_columns):
|
||||
|
@ -2706,13 +2702,12 @@ class BaseDeckGLViz(BaseViz):
|
|||
|
||||
if self.form_data.get("js_columns"):
|
||||
group_by += self.form_data.get("js_columns") or []
|
||||
metrics = self.get_metrics()
|
||||
# Ensure this value is sorted so that it does not
|
||||
# cause the cache key generation (which hashes the
|
||||
# query object) to generate different keys for values
|
||||
# that should be considered the same.
|
||||
group_by = sorted(set(group_by))
|
||||
if metrics:
|
||||
if metrics := self.get_metrics():
|
||||
query_obj["groupby"] = group_by
|
||||
query_obj["metrics"] = metrics
|
||||
query_obj["columns"] = []
|
||||
|
@ -3097,8 +3092,7 @@ class PairedTTestViz(BaseViz):
|
|||
@deprecated(deprecated_in="3.0")
|
||||
def query_obj(self) -> QueryObjectDict:
|
||||
query_obj = super().query_obj()
|
||||
sort_by = self.form_data.get("timeseries_limit_metric")
|
||||
if sort_by:
|
||||
if sort_by := self.form_data.get("timeseries_limit_metric"):
|
||||
sort_by_label = utils.get_metric_name(sort_by)
|
||||
if sort_by_label not in utils.get_metric_names(query_obj["metrics"]):
|
||||
query_obj["metrics"].append(sort_by)
|
||||
|
|
|
@ -67,8 +67,7 @@ class PandasDataLoader(DataLoader):
|
|||
return inspect(self._db_engine).default_schema_name
|
||||
|
||||
def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]:
|
||||
metadata_table = table.table_metadata
|
||||
if metadata_table:
|
||||
if metadata_table := table.table_metadata:
|
||||
types = metadata_table.types
|
||||
if types:
|
||||
return types
|
||||
|
|
|
@ -136,7 +136,6 @@ def upload_csv(
|
|||
dtype: Union[str, None] = None,
|
||||
):
|
||||
csv_upload_db_id = get_upload_db().id
|
||||
schema = utils.get_example_default_schema()
|
||||
form_data = {
|
||||
"csv_file": open(filename, "rb"),
|
||||
"delimiter": ",",
|
||||
|
@ -146,7 +145,7 @@ def upload_csv(
|
|||
"index_label": "test_label",
|
||||
"overwrite_duplicate": False,
|
||||
}
|
||||
if schema:
|
||||
if schema := utils.get_example_default_schema():
|
||||
form_data["schema"] = schema
|
||||
if extra:
|
||||
form_data.update(extra)
|
||||
|
@ -159,7 +158,6 @@ def upload_excel(
|
|||
filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
|
||||
):
|
||||
excel_upload_db_id = get_upload_db().id
|
||||
schema = utils.get_example_default_schema()
|
||||
form_data = {
|
||||
"excel_file": open(filename, "rb"),
|
||||
"name": table_name,
|
||||
|
@ -169,7 +167,7 @@ def upload_excel(
|
|||
"index_label": "test_label",
|
||||
"mangle_dupe_cols": False,
|
||||
}
|
||||
if schema:
|
||||
if schema := utils.get_example_default_schema():
|
||||
form_data["schema"] = schema
|
||||
if extra:
|
||||
form_data.update(extra)
|
||||
|
@ -180,7 +178,6 @@ def upload_columnar(
|
|||
filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
|
||||
):
|
||||
columnar_upload_db_id = get_upload_db().id
|
||||
schema = utils.get_example_default_schema()
|
||||
form_data = {
|
||||
"columnar_file": open(filename, "rb"),
|
||||
"name": table_name,
|
||||
|
@ -188,7 +185,7 @@ def upload_columnar(
|
|||
"if_exists": "fail",
|
||||
"index_label": "test_label",
|
||||
}
|
||||
if schema:
|
||||
if schema := utils.get_example_default_schema():
|
||||
form_data["schema"] = schema
|
||||
if extra:
|
||||
form_data.update(extra)
|
||||
|
|
|
@ -641,14 +641,13 @@ class TestDatasetApi(SupersetTestCase):
|
|||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
schema = get_example_default_schema()
|
||||
energy_usage_ds = self.get_energy_usage_dataset()
|
||||
self.login(username="admin")
|
||||
table_data = {
|
||||
"database": energy_usage_ds.database_id,
|
||||
"table_name": energy_usage_ds.table_name,
|
||||
}
|
||||
if schema:
|
||||
if schema := get_example_default_schema():
|
||||
table_data["schema"] = schema
|
||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
||||
assert rv.status_code == 422
|
||||
|
@ -665,7 +664,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
schema = get_example_default_schema()
|
||||
energy_usage_ds = self.get_energy_usage_dataset()
|
||||
self.login(username="admin")
|
||||
table_data = {
|
||||
|
@ -673,7 +671,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"table_name": energy_usage_ds.table_name,
|
||||
"sql": "select * from energy_usage",
|
||||
}
|
||||
if schema:
|
||||
if schema := get_example_default_schema():
|
||||
table_data["schema"] = schema
|
||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
||||
assert rv.status_code == 422
|
||||
|
@ -690,7 +688,6 @@ class TestDatasetApi(SupersetTestCase):
|
|||
if backend() == "sqlite":
|
||||
return
|
||||
|
||||
schema = get_example_default_schema()
|
||||
energy_usage_ds = self.get_energy_usage_dataset()
|
||||
self.login(username="alpha")
|
||||
admin = self.get_user("admin")
|
||||
|
@ -701,7 +698,7 @@ class TestDatasetApi(SupersetTestCase):
|
|||
"sql": "select * from energy_usage",
|
||||
"owners": [admin.id],
|
||||
}
|
||||
if schema:
|
||||
if schema := get_example_default_schema():
|
||||
table_data["schema"] = schema
|
||||
rv = self.post_assert_metric("/api/v1/dataset/", table_data, "post")
|
||||
assert rv.status_code == 201
|
||||
|
|
Loading…
Reference in New Issue