diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 3f524b3658..07544d66d2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -15,14 +15,28 @@
# limitations under the License.
#
repos:
+ - repo: https://github.com/MarcoGorelli/auto-walrus
+ rev: v0.2.2
+ hooks:
+ - id: auto-walrus
+ - repo: https://github.com/asottile/pyupgrade
+ rev: v3.4.0
+ hooks:
+ - id: pyupgrade
+ args:
+ - --py39-plus
+ - repo: https://github.com/hadialqattan/pycln
+ rev: v2.1.2
+ hooks:
+ - id: pycln
+ args:
+ - --disable-all-dunder-policy
+ - --exclude=superset/config.py
+ - --extend-exclude=tests/integration_tests/superset_test_config.*.py
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- - repo: https://github.com/MarcoGorelli/auto-walrus
- rev: v0.2.2
- hooks:
- - id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:
diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py
index 68a54e10be..d1ba06a620 100644
--- a/RELEASING/changelog.py
+++ b/RELEASING/changelog.py
@@ -17,8 +17,9 @@ import csv as lib_csv
import os
import re
import sys
+from collections.abc import Iterator
from dataclasses import dataclass
-from typing import Any, Dict, Iterator, List, Optional, Union
+from typing import Any, Optional, Union
import click
from click.core import Context
@@ -67,15 +68,15 @@ class GitChangeLog:
def __init__(
self,
version: str,
- logs: List[GitLog],
+ logs: list[GitLog],
access_token: Optional[str] = None,
risk: Optional[bool] = False,
) -> None:
self._version = version
self._logs = logs
- self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
- self._github_login_cache: Dict[str, Optional[str]] = {}
- self._github_prs: Dict[int, Any] = {}
+ self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
+ self._github_login_cache: dict[str, Optional[str]] = {}
+ self._github_prs: dict[int, Any] = {}
self._wait = 10
github_token = access_token or os.environ.get("GITHUB_TOKEN")
self._github = Github(github_token)
@@ -126,7 +127,7 @@ class GitChangeLog:
"superset/migrations/versions/" in file.filename for file in commit.files
)
- def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
+ def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
pr_number = git_log.pr_number
if pr_number:
detail = self._pr_logs_with_details.get(pr_number)
@@ -156,7 +157,7 @@ class GitChangeLog:
return detail
- def _is_risk_pull_request(self, labels: List[Any]) -> bool:
+ def _is_risk_pull_request(self, labels: list[Any]) -> bool:
for label in labels:
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
if risk_label is not None:
@@ -174,8 +175,8 @@ class GitChangeLog:
def _parse_change_log(
self,
- changelog: Dict[str, str],
- pr_info: Dict[str, str],
+ changelog: dict[str, str],
+ pr_info: dict[str, str],
github_login: str,
) -> None:
formatted_pr = (
@@ -227,7 +228,7 @@ class GitChangeLog:
result += f"**{key}** {changelog[key]}\n"
return result
- def __iter__(self) -> Iterator[Dict[str, Any]]:
+ def __iter__(self) -> Iterator[dict[str, Any]]:
for log in self._logs:
yield {
"pr_number": log.pr_number,
@@ -250,20 +251,20 @@ class GitLogs:
def __init__(self, git_ref: str) -> None:
self._git_ref = git_ref
- self._logs: List[GitLog] = []
+ self._logs: list[GitLog] = []
@property
def git_ref(self) -> str:
return self._git_ref
@property
- def logs(self) -> List[GitLog]:
+ def logs(self) -> list[GitLog]:
return self._logs
def fetch(self) -> None:
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]
- def diff(self, git_logs: "GitLogs") -> List[GitLog]:
+ def diff(self, git_logs: "GitLogs") -> list[GitLog]:
return [log for log in git_logs.logs if log not in self._logs]
def __repr__(self) -> str:
@@ -284,7 +285,7 @@ class GitLogs:
print(f"Could not checkout {git_ref}")
sys.exit(1)
- def _git_logs(self) -> List[str]:
+ def _git_logs(self) -> list[str]:
# let's get current git ref so we can revert it back
current_git_ref = self._git_get_current_head()
self._git_checkout(self._git_ref)
diff --git a/RELEASING/generate_email.py b/RELEASING/generate_email.py
index 92536670cd..29142557c0 100755
--- a/RELEASING/generate_email.py
+++ b/RELEASING/generate_email.py
@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-from typing import Any, Dict, List
+from typing import Any
from click.core import Context
@@ -34,7 +34,7 @@ PROJECT_MODULE = "superset"
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"
-def string_comma_to_list(message: str) -> List[str]:
+def string_comma_to_list(message: str) -> list[str]:
if not message:
return []
return [element.strip() for element in message.split(",")]
@@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str:
return template.render(kwargs)
-class BaseParameters(object):
+class BaseParameters:
def __init__(
self,
version: str,
@@ -60,7 +60,7 @@ class BaseParameters(object):
) -> None:
self.version = version
self.version_rc = version_rc
- self.template_arguments: Dict[str, Any] = {}
+ self.template_arguments: dict[str, Any] = {}
def __repr__(self) -> str:
return f"Apache Credentials: {self.version}/{self.version_rc}"
diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py
index 6ea9abf63c..199e79f66e 100644
--- a/docker/pythonpath_dev/superset_config.py
+++ b/docker/pythonpath_dev/superset_config.py
@@ -22,7 +22,6 @@
#
import logging
import os
-from datetime import timedelta
from typing import Optional
from cachelib.file import FileSystemCache
@@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
error_msg = "The environment variable {} was missing, abort...".format(
var_name
)
- raise EnvironmentError(error_msg)
+ raise OSError(error_msg)
DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
@@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT")
DATABASE_DB = get_env_variable("DATABASE_DB")
# The SQLAlchemy connection string.
-SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
+SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
DATABASE_DIALECT,
DATABASE_USER,
DATABASE_PASSWORD,
@@ -80,7 +79,7 @@ CACHE_CONFIG = {
DATA_CACHE_CONFIG = CACHE_CONFIG
-class CeleryConfig(object):
+class CeleryConfig:
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
imports = ("superset.sql_lab",)
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py
index 83c06456a1..466fab6f13 100644
--- a/scripts/benchmark_migration.py
+++ b/scripts/benchmark_migration.py
@@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
from inspect import getsource
from pathlib import Path
from types import ModuleType
-from typing import Any, Dict, List, Set, Type
+from typing import Any
import click
from flask import current_app
@@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
- raise Exception(
- "No module spec found in location: `{path}`".format(path=str(filepath))
- )
+ raise Exception(f"No module spec found in location: `{str(filepath)}`")
-def extract_modified_tables(module: ModuleType) -> Set[str]:
+def extract_modified_tables(module: ModuleType) -> set[str]:
"""
Extract the tables being modified by a migration script.
@@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
actually traversing the AST.
"""
- tables: Set[str] = set()
+ tables: set[str] = set()
for function in {"upgrade", "downgrade"}:
source = getsource(getattr(module, function))
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
@@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
return tables
-def find_models(module: ModuleType) -> List[Type[Model]]:
+def find_models(module: ModuleType) -> list[type[Model]]:
"""
Find all models in a migration script.
"""
- models: List[Type[Model]] = []
+ models: list[type[Model]] = []
tables = extract_modified_tables(module)
# add models defined explicitly in the migration script
@@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
sorter: TopologicalSorter[Any] = TopologicalSorter()
for model in models:
inspector = inspect(model)
- dependent_tables: List[str] = []
+ dependent_tables: list[str] = []
for column in inspector.columns.values():
for foreign_key in column.foreign_keys:
if foreign_key.column.table.name != model.__tablename__:
@@ -174,7 +172,7 @@ def main(
print("\nIdentifying models used in the migration:")
models = find_models(module)
- model_rows: Dict[Type[Model], int] = {}
+ model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
@@ -182,7 +180,7 @@ def main(
session.close()
print("Benchmarking migration")
- results: Dict[str, float] = {}
+ results: dict[str, float] = {}
start = time.time()
upgrade(revision=revision)
duration = time.time() - start
@@ -190,14 +188,14 @@ def main(
print(f"Migration on current DB took: {duration:.2f} seconds")
min_entities = 10
- new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
+ new_models: dict[type[Model], list[Model]] = defaultdict(list)
while min_entities <= limit:
downgrade(revision=down_revision)
print(f"Running with at least {min_entities} entities of each model")
for model in models:
missing = min_entities - model_rows[model]
if missing > 0:
- entities: List[Model] = []
+ entities: list[Model] = []
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try:
diff --git a/scripts/cancel_github_workflows.py b/scripts/cancel_github_workflows.py
index 4d30d34adf..70744c2954 100755
--- a/scripts/cancel_github_workflows.py
+++ b/scripts/cancel_github_workflows.py
@@ -33,13 +33,13 @@ Example:
./cancel_github_workflows.py 1024 --include-last
"""
import os
-from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
+from collections.abc import Iterable, Iterator
+from typing import Any, Literal, Optional, Union
import click
import requests
from click.exceptions import ClickException
from dateutil import parser
-from typing_extensions import Literal
github_token = os.environ.get("GITHUB_TOKEN")
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
@@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
def request(
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
resp = requests.request(
method,
f"https://api.github.com/{endpoint.lstrip('/')}",
@@ -61,8 +61,8 @@ def request(
def list_runs(
repo: str,
- params: Optional[Dict[str, str]] = None,
-) -> Iterator[Dict[str, Any]]:
+ params: Optional[dict[str, str]] = None,
+) -> Iterator[dict[str, Any]]:
"""List all github workflow runs.
Returns:
An iterator that will iterate through all pages of matching runs."""
@@ -77,16 +77,15 @@ def list_runs(
params={**params, "per_page": 100, "page": page},
)
total_count = result["total_count"]
- for item in result["workflow_runs"]:
- yield item
+ yield from result["workflow_runs"]
page += 1
-def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
+def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
-def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
+def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
@@ -96,7 +95,7 @@ def get_runs(
user: Optional[str] = None,
statuses: Iterable[str] = ("queued", "in_progress"),
events: Iterable[str] = ("pull_request", "push"),
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
"""Get workflow runs associated with the given branch"""
return [
item
@@ -108,7 +107,7 @@ def get_runs(
]
-def print_commit(commit: Dict[str, Any], branch: str) -> None:
+def print_commit(commit: dict[str, Any], branch: str) -> None:
"""Print out commit message for verification"""
indented_message = " \n".join(commit["message"].split("\n"))
date_str = (
@@ -155,7 +154,7 @@ Date: {date_str}
def cancel_github_workflows(
branch_or_pull: Optional[str],
repo: str,
- event: List[str],
+ event: list[str],
include_last: bool,
include_running: bool,
) -> None:
diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py
index 5ca75e394c..0416f55806 100644
--- a/scripts/permissions_cleanup.py
+++ b/scripts/permissions_cleanup.py
@@ -24,7 +24,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("# of permission view menus is: {}".format(len(pvms)))
+ print(f"# of permission view menus is: {len(pvms)}")
pvms_dict = defaultdict(list)
for pvm in pvms:
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
@@ -43,7 +43,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("Stage 1: # of permission view menus is: {}".format(len(pvms)))
+ print(f"Stage 1: # of permission view menus is: {len(pvms)}")
# 2. Clean up None permissions or view menus
pvms = security_manager.get_session.query(
@@ -57,7 +57,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
- print("Stage 2: # of permission view menus is: {}".format(len(pvms)))
+ print(f"Stage 2: # of permission view menus is: {len(pvms)}")
# 3. Delete empty permission view menus from roles
roles = security_manager.get_session.query(security_manager.role_model).all()
diff --git a/setup.py b/setup.py
index 41f7e11e38..d8adea3285 100644
--- a/setup.py
+++ b/setup.py
@@ -14,21 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import io
import json
import os
import subprocess
-import sys
from setuptools import find_packages, setup
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")
-with open(PACKAGE_JSON, "r") as package_file:
+with open(PACKAGE_JSON) as package_file:
version_string = json.load(package_file)["version"]
-with io.open("README.md", "r", encoding="utf-8") as f:
+with open("README.md", encoding="utf-8") as f:
long_description = f.read()
diff --git a/superset/advanced_data_type/plugins/internet_address.py b/superset/advanced_data_type/plugins/internet_address.py
index 08a0925846..8ab20fe2d0 100644
--- a/superset/advanced_data_type/plugins/internet_address.py
+++ b/superset/advanced_data_type/plugins/internet_address.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import ipaddress
-from typing import Any, List
+from typing import Any
from sqlalchemy import Column
@@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:
# Make this return a single clause
def cidr_translate_filter_func(
- col: Column, operator: FilterOperator, values: List[Any]
+ col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator and
diff --git a/superset/advanced_data_type/plugins/internet_port.py b/superset/advanced_data_type/plugins/internet_port.py
index 60a594bfd9..8983e41422 100644
--- a/superset/advanced_data_type/plugins/internet_port.py
+++ b/superset/advanced_data_type/plugins/internet_port.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import itertools
-from typing import Any, Dict, List
+from typing import Any
from sqlalchemy import Column
@@ -26,7 +26,7 @@ from superset.advanced_data_type.types import (
)
from superset.utils.core import FilterOperator, FilterStringOperators
-port_conversion_dict: Dict[str, List[int]] = {
+port_conversion_dict: dict[str, list[int]] = {
"http": [80],
"ssh": [22],
"https": [443],
@@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo
def port_translate_filter_func(
- col: Column, operator: FilterOperator, values: List[Any]
+ col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator
diff --git a/superset/advanced_data_type/types.py b/superset/advanced_data_type/types.py
index 316922f339..e8d5de9143 100644
--- a/superset/advanced_data_type/types.py
+++ b/superset/advanced_data_type/types.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
-from typing import Any, Callable, List, Optional, TypedDict, Union
+from typing import Any, Callable, Optional, TypedDict, Union
from sqlalchemy import Column
from sqlalchemy.sql.expression import BinaryExpression
@@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict):
"""
advanced_data_type: str
- values: List[
+ values: list[
Union[FilterValues, None]
] # unparsed value (usually text when passed from text box)
@@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False):
"""
error_message: Optional[str]
- values: List[Any] # parsed value (can be any value)
+ values: list[Any] # parsed value (can be any value)
display_value: str # The string representation of the parsed values
- valid_filter_operators: List[FilterStringOperators]
+ valid_filter_operators: list[FilterStringOperators]
@dataclass
@@ -54,6 +54,6 @@ class AdvancedDataType:
verbose_name: str
description: str
- valid_data_types: List[str]
+ valid_data_types: list[str]
translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]
diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py
index 0a6a2767f0..70e0a1ad02 100644
--- a/superset/annotation_layers/annotations/api.py
+++ b/superset/annotation_layers/annotations/api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
@@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
@staticmethod
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
- layer_id: int, rison_parameters: Dict[str, Any]
+ layer_id: int, rison_parameters: dict[str, Any]
) -> None:
if "filters" not in rison_parameters:
rison_parameters["filters"] = []
diff --git a/superset/annotation_layers/annotations/commands/bulk_delete.py b/superset/annotation_layers/annotations/commands/bulk_delete.py
index 113725050f..dd47047788 100644
--- a/superset/annotation_layers/annotations/commands/bulk_delete.py
+++ b/superset/annotation_layers/annotations/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.annotation_layers.annotations.commands.exceptions import (
AnnotationBulkDeleteFailedError,
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Annotation]] = None
+ self._models: Optional[list[Annotation]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py
index 0974624561..986b564291 100644
--- a/superset/annotation_layers/annotations/commands/create.py
+++ b/superset/annotation_layers/annotations/commands/create.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
start_dttm: Optional[datetime] = self._properties.get("start_dttm")
end_dttm: Optional[datetime] = self._properties.get("end_dttm")
diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py
index b644ddc362..99ab209165 100644
--- a/superset/annotation_layers/annotations/commands/update.py
+++ b/superset/annotation_layers/annotations/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Annotation] = None
@@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
short_descr: str = self._properties.get("short_descr", "")
diff --git a/superset/annotation_layers/annotations/dao.py b/superset/annotation_layers/annotations/dao.py
index 0c8a9e47c5..da69e576e5 100644
--- a/superset/annotation_layers/annotations/dao.py
+++ b/superset/annotation_layers/annotations/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO):
model_cls = Annotation
@staticmethod
- def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(
diff --git a/superset/annotation_layers/commands/bulk_delete.py b/superset/annotation_layers/commands/bulk_delete.py
index b9bc17e82f..4910dc4275 100644
--- a/superset/annotation_layers/commands/bulk_delete.py
+++ b/superset/annotation_layers/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.annotation_layers.commands.exceptions import (
AnnotationLayerBulkDeleteFailedError,
@@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationLayerCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[AnnotationLayer]] = None
+ self._models: Optional[list[AnnotationLayer]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py
index 97431568a9..86b0cb3b85 100644
--- a/superset/annotation_layers/commands/create.py
+++ b/superset/annotation_layers/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationLayerCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
name = self._properties.get("name", "")
diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py
index 4a9cc31be5..67d869c005 100644
--- a/superset/annotation_layers/commands/update.py
+++ b/superset/annotation_layers/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationLayerCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None
@@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
name = self._properties.get("name", "")
self._model = AnnotationLayerDAO.find_by_id(self._model_id)
diff --git a/superset/annotation_layers/dao.py b/superset/annotation_layers/dao.py
index d9db4b582d..67efc19f88 100644
--- a/superset/annotation_layers/dao.py
+++ b/superset/annotation_layers/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional, Union
+from typing import Optional, Union
from sqlalchemy.exc import SQLAlchemyError
@@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO):
@staticmethod
def bulk_delete(
- models: Optional[List[AnnotationLayer]], commit: bool = True
+ models: Optional[list[AnnotationLayer]], commit: bool = True
) -> None:
item_ids = [model.id for model in models] if models else []
try:
@@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO):
raise DAODeleteFailedError() from ex
@staticmethod
- def has_annotations(model_id: Union[int, List[int]]) -> bool:
+ def has_annotations(model_id: Union[int, list[int]]) -> bool:
if isinstance(model_id, list):
return (
db.session.query(AnnotationLayer)
diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py
index c252f0be4c..ac801b7421 100644
--- a/superset/charts/commands/bulk_delete.py
+++ b/superset/charts/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from flask_babel import lazy_gettext as _
@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteChartCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Slice]] = None
+ self._models: Optional[list[Slice]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py
index 38076fb9cd..78706b3a66 100644
--- a/superset/charts/commands/create.py
+++ b/superset/charts/commands/create.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
datasource_type = self._properties["datasource_type"]
datasource_id = self._properties["datasource_id"]
dashboard_ids = self._properties.get("dashboards", [])
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/Populate datasource
try:
diff --git a/superset/charts/commands/export.py b/superset/charts/commands/export.py
index 9d445cb54e..22310ade99 100644
--- a/superset/charts/commands/export.py
+++ b/superset/charts/commands/export.py
@@ -18,7 +18,7 @@
import json
import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator
import yaml
@@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand):
not_found = ChartNotFoundError
@staticmethod
- def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+ def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.slice_name, model.id)
file_path = f"charts/{file_name}.yaml"
diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py
index afeb9c2820..fb5007a50c 100644
--- a/superset/charts/commands/importers/dispatcher.py
+++ b/superset/charts/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py
index ab88038aaa..132df21b08 100644
--- a/superset/charts/commands/importers/v1/__init__.py
+++ b/superset/charts/commands/importers/v1/__init__.py
@@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-import json
-from typing import Any, Dict, Set
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand):
dao = ChartDAO
model_name = "chart"
prefix = "charts/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"datasets/": ImportV1DatasetSchema(),
"databases/": ImportV1DatabaseSchema(),
@@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with charts
- dataset_uuids: Set[str] = set()
+ dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("charts/"):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
- datasets: Dict[str, SqlaTable] = {}
+ datasets: dict[str, SqlaTable] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py
index d4aeb17a1e..399e6c2243 100644
--- a/superset/charts/commands/importers/v1/utils.py
+++ b/superset/charts/commands/importers/v1/utils.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Any, Dict
+from typing import Any
from flask import g
from sqlalchemy.orm import Session
@@ -28,7 +28,7 @@ from superset.models.slice import Slice
def import_chart(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py
index f5fc2616a5..a4265d0835 100644
--- a/superset/charts/commands/update.py
+++ b/superset/charts/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@@ -42,14 +42,14 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__)
-def is_query_context_update(properties: Dict[str, Any]) -> bool:
+def is_query_context_update(properties: dict[str, Any]) -> bool:
return set(properties) == {"query_context", "query_context_generation"} and bool(
properties.get("query_context_generation")
)
class UpdateChartCommand(UpdateMixin, BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Slice] = None
@@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
return chart
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
dashboard_ids = self._properties.get("dashboards")
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate if datasource_id is provided datasource_type is required
datasource_id = self._properties.get("datasource_id")
diff --git a/superset/charts/dao.py b/superset/charts/dao.py
index 7102e6ad23..9c6b2c26ef 100644
--- a/superset/charts/dao.py
+++ b/superset/charts/dao.py
@@ -17,7 +17,7 @@
# pylint: disable=arguments-renamed
import logging
from datetime import datetime
-from typing import List, Optional, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
from sqlalchemy.exc import SQLAlchemyError
@@ -39,7 +39,7 @@ class ChartDAO(BaseDAO):
base_filter = ChartFilter
@staticmethod
- def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@@ -71,7 +71,7 @@ class ChartDAO(BaseDAO):
db.session.commit()
@staticmethod
- def favorited_ids(charts: List[Slice]) -> List[FavStar]:
+ def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]
return [
star.obj_id
diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py
index 9c620dcf5d..552044ebfa 100644
--- a/superset/charts/data/api.py
+++ b/superset/charts/data/api.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import json
import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING
import simplejson
from flask import current_app, g, make_response, request, Response
@@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi):
return self._get_data_response(command, True)
def _run_async(
- self, form_data: Dict[str, Any], command: ChartDataCommand
+ self, form_data: dict[str, Any], command: ChartDataCommand
) -> Response:
"""
Execute command as an async query.
@@ -344,9 +344,9 @@ class ChartDataRestApi(ChartRestApi):
def _send_chart_response(
self,
- result: Dict[Any, Any],
- form_data: Optional[Dict[str, Any]] = None,
- datasource: Optional[Union[BaseDatasource, Query]] = None,
+ result: dict[Any, Any],
+ form_data: dict[str, Any] | None = None,
+ datasource: BaseDatasource | Query | None = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format
@@ -408,8 +408,8 @@ class ChartDataRestApi(ChartRestApi):
self,
command: ChartDataCommand,
force_cached: bool = False,
- form_data: Optional[Dict[str, Any]] = None,
- datasource: Optional[Union[BaseDatasource, Query]] = None,
+ form_data: dict[str, Any] | None = None,
+ datasource: BaseDatasource | Query | None = None,
) -> Response:
try:
result = command.run(force_cached=force_cached)
@@ -421,12 +421,12 @@ class ChartDataRestApi(ChartRestApi):
return self._send_chart_response(result, form_data, datasource)
# pylint: disable=invalid-name, no-self-use
- def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
+ def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
return QueryContextCacheLoader.load(cache_key)
# pylint: disable=no-self-use
def _create_query_context_from_form(
- self, form_data: Dict[str, Any]
+ self, form_data: dict[str, Any]
) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)
diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py
index c4e25f742b..fb6e3f3dbf 100644
--- a/superset/charts/data/commands/create_async_job_command.py
+++ b/superset/charts/data/commands/create_async_job_command.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask import Request
@@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
- def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
+ def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata
diff --git a/superset/charts/data/commands/get_data_command.py b/superset/charts/data/commands/get_data_command.py
index 819693607b..a84870a1dd 100644
--- a/superset/charts/data/commands/get_data_command.py
+++ b/superset/charts/data/commands/get_data_command.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_babel import lazy_gettext as _
@@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand):
def __init__(self, query_context: QueryContext):
self._query_context = query_context
- def run(self, **kwargs: Any) -> Dict[str, Any]:
+ def run(self, **kwargs: Any) -> dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
cache_query_context = kwargs.get("cache", False)
diff --git a/superset/charts/data/query_context_cache_loader.py b/superset/charts/data/query_context_cache_loader.py
index b5ff3bdae8..97fa733a3e 100644
--- a/superset/charts/data/query_context_cache_loader.py
+++ b/superset/charts/data/query_context_cache_loader.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from superset import cache
from superset.charts.commands.exceptions import ChartDataCacheLoadError
@@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError
class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
@staticmethod
- def load(cache_key: str) -> Dict[str, Any]:
+ def load(cache_key: str) -> dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")
diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py
index 1165769fc8..a6b64c08d6 100644
--- a/superset/charts/post_processing.py
+++ b/superset/charts/post_processing.py
@@ -27,7 +27,7 @@ for these chart types.
"""
from io import StringIO
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import Any, Optional, TYPE_CHECKING, Union
import pandas as pd
from flask_babel import gettext as __
@@ -45,14 +45,14 @@ if TYPE_CHECKING:
from superset.models.sql_lab import Query
-def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
+def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]:
"""
Sort columns when combining metrics.
MultiIndex labels have the metric name as the last element in the
tuple. We want to sort these according to the list of passed metrics.
"""
- parts: List[Any] = list(label)
+ parts: list[Any] = list(label)
metric = parts[-1]
parts[-1] = metrics.index(metric)
return tuple(parts)
@@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...
def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
df: pd.DataFrame,
- rows: List[str],
- columns: List[str],
- metrics: List[str],
+ rows: list[str],
+ columns: list[str],
+ metrics: list[str],
aggfunc: str = "Sum",
transpose_pivot: bool = False,
combine_metrics: bool = False,
@@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str:
"""
List unique values in a series.
"""
- return ", ".join(set(str(v) for v in pd.Series.unique(series)))
+ return ", ".join({str(v) for v in pd.Series.unique(series)})
pivot_v2_aggfunc_map = {
@@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = {
def pivot_table_v2(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@@ -249,7 +249,7 @@ def pivot_table_v2(
def pivot_table(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@@ -285,7 +285,7 @@ def pivot_table(
def table(
df: pd.DataFrame,
- form_data: Dict[str, Any],
+ form_data: dict[str, Any],
datasource: Optional[ # pylint: disable=unused-argument
Union["BaseDatasource", "Query"]
] = None,
@@ -315,10 +315,10 @@ post_processors = {
def apply_post_process(
- result: Dict[Any, Any],
- form_data: Optional[Dict[str, Any]] = None,
+ result: dict[Any, Any],
+ form_data: Optional[dict[str, Any]] = None,
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
-) -> Dict[Any, Any]:
+) -> dict[Any, Any]:
form_data = form_data or {}
viz_type = form_data.get("viz_type")
diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py
index 44252ef06f..373600cd08 100644
--- a/superset/charts/schemas.py
+++ b/superset/charts/schemas.py
@@ -18,7 +18,7 @@
from __future__ import annotations
import inspect
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
@@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema):
class ChartDataQueryContextSchema(Schema):
- query_context_factory: Optional[QueryContextFactory] = None
+ query_context_factory: QueryContextFactory | None = None
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
custom_cache_timeout = fields.Integer(
@@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema):
# pylint: disable=unused-argument
@post_load
- def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
+ def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext:
query_context = self.get_query_context_factory().create(**data)
return query_context
diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py
index c7689569c2..86f6fe9b67 100755
--- a/superset/cli/importexport.py
+++ b/superset/cli/importexport.py
@@ -18,7 +18,7 @@ import logging
import sys
from datetime import datetime
from pathlib import Path
-from typing import List, Optional
+from typing import Optional
from zipfile import is_zipfile, ZipFile
import click
@@ -309,7 +309,7 @@ else:
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
path_object = Path(path)
- files: List[Path] = []
+ files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
@@ -363,7 +363,7 @@ else:
sync_metrics = "metrics" in sync_array
path_object = Path(path)
- files: List[Path] = []
+ files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
diff --git a/superset/cli/main.py b/superset/cli/main.py
index 006f8eb5c9..536617cadd 100755
--- a/superset/cli/main.py
+++ b/superset/cli/main.py
@@ -18,7 +18,7 @@
import importlib
import logging
import pkgutil
-from typing import Any, Dict
+from typing import Any
import click
from colorama import Fore, Style
@@ -40,7 +40,7 @@ def superset() -> None:
"""This is a management script for the Superset application."""
@app.shell_context_processor
- def make_shell_context() -> Dict[str, Any]:
+ def make_shell_context() -> dict[str, Any]:
return dict(app=app, db=db)
@@ -79,5 +79,5 @@ def version(verbose: bool) -> None:
)
print(Fore.BLUE + "-=" * 15)
if verbose:
- print("[DB] : " + "{}".format(db.engine))
+ print("[DB] : " + f"{db.engine}")
print(Style.RESET_ALL)
diff --git a/superset/cli/native_filters.py b/superset/cli/native_filters.py
index 63cc185e8e..a25724d38d 100644
--- a/superset/cli/native_filters.py
+++ b/superset/cli/native_filters.py
@@ -17,7 +17,6 @@
import json
from copy import deepcopy
from textwrap import dedent
-from typing import Set, Tuple
import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
@@ -102,7 +101,7 @@ def native_filters() -> None:
)
def upgrade(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Upgrade legacy filter-box charts to native dashboard filters.
@@ -251,7 +250,7 @@ def upgrade(
)
def downgrade(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Downgrade native dashboard filters to legacy filter-box charts (where applicable).
@@ -347,7 +346,7 @@ def downgrade(
)
def cleanup(
all_: bool, # pylint: disable=unused-argument
- dashboard_ids: Tuple[int, ...],
+ dashboard_ids: tuple[int, ...],
) -> None:
"""
Cleanup obsolete legacy filter-box charts and interim metadata.
@@ -355,7 +354,7 @@ def cleanup(
Note this operation is irreversible.
"""
- slice_ids: Set[int] = set()
+ slice_ids: set[int] = set()
# Cleanup the dashboard which contains legacy fields used for downgrading.
for dashboard in (
diff --git a/superset/cli/thumbnails.py b/superset/cli/thumbnails.py
index 276d9981c1..325fab6853 100755
--- a/superset/cli/thumbnails.py
+++ b/superset/cli/thumbnails.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Type, Union
+from typing import Union
import click
from celery.utils.abstract import CallableTask
@@ -75,7 +75,7 @@ def compute_thumbnails(
def compute_generic_thumbnail(
friendly_type: str,
- model_cls: Union[Type[Dashboard], Type[Slice]],
+ model_cls: Union[type[Dashboard], type[Slice]],
model_id: int,
compute_func: CallableTask,
) -> None:
diff --git a/superset/commands/base.py b/superset/commands/base.py
index 42d5956312..caca50755d 100644
--- a/superset/commands/base.py
+++ b/superset/commands/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
-from typing import Any, List, Optional
+from typing import Any, Optional
from flask_appbuilder.security.sqla.models import User
@@ -45,7 +45,7 @@ class BaseCommand(ABC):
class CreateMixin: # pylint: disable=too-few-public-methods
@staticmethod
- def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+ def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners, defaulting to the current user if `owner_ids` is
undefined or empty. If current user is missing in `owner_ids`, current user
@@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods
class UpdateMixin: # pylint: disable=too-few-public-methods
@staticmethod
- def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
+ def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners. If current user is missing in `owner_ids`, current user
is added unless belonging to the Admin role.
diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py
index db9d1b6c63..4398d740c5 100644
--- a/superset/commands/exceptions.py
+++ b/superset/commands/exceptions.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
@@ -59,7 +59,7 @@ class CommandInvalidError(CommandException):
def __init__(
self,
message: str = "",
- exceptions: Optional[List[ValidationError]] = None,
+ exceptions: Optional[list[ValidationError]] = None,
) -> None:
self._exceptions = exceptions or []
super().__init__(message)
@@ -67,14 +67,14 @@ class CommandInvalidError(CommandException):
def append(self, exception: ValidationError) -> None:
self._exceptions.append(exception)
- def extend(self, exceptions: List[ValidationError]) -> None:
+ def extend(self, exceptions: list[ValidationError]) -> None:
self._exceptions.extend(exceptions)
- def get_list_classnames(self) -> List[str]:
+ def get_list_classnames(self) -> list[str]:
return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))
- def normalized_messages(self) -> Dict[Any, Any]:
- errors: Dict[Any, Any] = {}
+ def normalized_messages(self) -> dict[Any, Any]:
+ errors: dict[Any, Any] = {}
for exception in self._exceptions:
errors.update(exception.normalized_messages())
return errors
diff --git a/superset/commands/export/assets.py b/superset/commands/export/assets.py
index 9f088af428..1bd2cf6d61 100644
--- a/superset/commands/export/assets.py
+++ b/superset/commands/export/assets.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Iterator
from datetime import datetime, timezone
-from typing import Iterator, Tuple
import yaml
@@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand):
Command that exports all databases, datasets, charts, dashboards and saved queries.
"""
- def run(self) -> Iterator[Tuple[str, str]]:
+ def run(self) -> Iterator[tuple[str, str]]:
metadata = {
"version": EXPORT_VERSION,
"type": "assets",
diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py
index 4edafaa746..3f21f29281 100644
--- a/superset/commands/export/models.py
+++ b/superset/commands/export/models.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Iterator
from datetime import datetime, timezone
-from typing import Iterator, List, Tuple, Type
import yaml
from flask_appbuilder import Model
@@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml"
class ExportModelsCommand(BaseCommand):
- dao: Type[BaseDAO] = BaseDAO
- not_found: Type[CommandException] = CommandException
+ dao: type[BaseDAO] = BaseDAO
+ not_found: type[CommandException] = CommandException
- def __init__(self, model_ids: List[int], export_related: bool = True):
+ def __init__(self, model_ids: list[int], export_related: bool = True):
self.model_ids = model_ids
self.export_related = export_related
# this will be set when calling validate()
- self._models: List[Model] = []
+ self._models: list[Model] = []
@staticmethod
- def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]:
+ def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]:
raise NotImplementedError("Subclasses MUST implement _export")
- def run(self) -> Iterator[Tuple[str, str]]:
+ def run(self) -> Iterator[tuple[str, str]]:
self.validate()
metadata = {
diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py
index a67828bdb2..09830bf3cf 100644
--- a/superset/commands/importers/v1/__init__.py
+++ b/superset/commands/importers/v1/__init__.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Set
+from typing import Any, Optional
from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
@@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand):
dao = BaseDAO
model_name = "model"
prefix = ""
- schemas: Dict[str, Schema] = {}
+ schemas: dict[str, Schema] = {}
import_error = CommandException
# pylint: disable=unused-argument
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
- self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
- self.ssh_tunnel_passwords: Dict[str, str] = (
+ self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
- self.ssh_tunnel_private_keys: Dict[str, str] = (
+ self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
- self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self.overwrite: bool = kwargs.get("overwrite", False)
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
raise NotImplementedError("Subclasses MUST implement _import")
@classmethod
- def _get_uuids(cls) -> Set[str]:
+ def _get_uuids(cls) -> set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
def run(self) -> None:
@@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand):
raise self.import_error() from ex
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
- metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+ metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
@@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand):
)
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
- self, exceptions: List[ValidationError]
+ self, exceptions: list[ValidationError]
) -> None:
"""check if the object exists and shouldn't be overwritten"""
if not self.overwrite:
diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py
index ce8b46c2a0..1ab2e486cf 100644
--- a/superset/commands/importers/v1/assets.py
+++ b/superset/commands/importers/v1/assets.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from marshmallow import Schema
from marshmallow.exceptions import ValidationError
@@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand):
and will overwrite everything.
"""
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand):
}
# pylint: disable=unused-argument
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
- self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
- self.ssh_tunnel_passwords: Dict[str, str] = (
+ self.passwords: dict[str, str] = kwargs.get("passwords") or {}
+ self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
- self.ssh_tunnel_private_keys: Dict[str, str] = (
+ self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
- self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
+ self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
@staticmethod
- def _import(session: Session, configs: Dict[str, Any]) -> None:
+ def _import(session: Session, configs: dict[str, Any]) -> None:
# import databases first
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=True)
@@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand):
import_saved_query(session, config, overwrite=True)
# import datasets
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
@@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand):
}
# import charts
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("charts/"):
config.update(dataset_info[config["dataset_uuid"]])
@@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand):
dashboard = import_dashboard(session, config, overwrite=True)
# set ref in the dashboard_slices table
- dashboard_chart_ids: List[Dict[str, int]] = []
+ dashboard_chart_ids: list[dict[str, int]] = []
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
break
@@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand):
raise ImportFailedError() from ex
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
- metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
+ metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py
index 35efdb1393..4c20e93ff7 100644
--- a/superset/commands/importers/v1/examples.py
+++ b/superset/commands/importers/v1/examples.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand):
dao = BaseDAO
model_name = "model"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
import_error = CommandException
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
super().__init__(contents, *args, **kwargs)
self.force_data = kwargs.get("force_data", False)
@@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand):
raise self.import_error() from ex
@classmethod
- def _get_uuids(cls) -> Set[str]:
+ def _get_uuids(cls) -> set[str]:
# pylint: disable=protected-access
return (
ImportDatabasesCommand._get_uuids()
@@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand):
@staticmethod
def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
session: Session,
- configs: Dict[str, Any],
+ configs: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
) -> None:
# import databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(
@@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand):
# database was created before its UUID was frozen, so it has a random UUID.
# We need to determine its ID so we can point the dataset to it.
examples_db = get_example_database()
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
# find the ID of the corresponding database
@@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
# import charts
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand):
).fetchall()
# import dashboards
- dashboard_chart_ids: List[Tuple[int, int]] = []
+ dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
try:
diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py
index c8fb97c53d..8ca008b3e2 100644
--- a/superset/commands/importers/v1/utils.py
+++ b/superset/commands/importers/v1/utils.py
@@ -15,7 +15,7 @@
import logging
from pathlib import Path, PurePosixPath
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from zipfile import ZipFile
import yaml
@@ -46,7 +46,7 @@ class MetadataSchema(Schema):
timestamp = fields.DateTime()
-def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
+def load_yaml(file_name: str, content: str) -> dict[str, Any]:
"""Try to load a YAML file"""
try:
return yaml.safe_load(content)
@@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
raise ValidationError({file_name: "Not a valid YAML file"}) from ex
-def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
+def load_metadata(contents: dict[str, str]) -> dict[str, str]:
"""Apply validation and load a metadata file"""
if METADATA_FILE_NAME not in contents:
# if the contents have no METADATA_FILE_NAME this is probably
@@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
def validate_metadata_type(
- metadata: Optional[Dict[str, str]],
+ metadata: Optional[dict[str, str]],
type_: str,
- exceptions: List[ValidationError],
+ exceptions: list[ValidationError],
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
@@ -96,35 +96,35 @@ def validate_metadata_type(
# pylint: disable=too-many-locals,too-many-arguments
def load_configs(
- contents: Dict[str, str],
- schemas: Dict[str, Schema],
- passwords: Dict[str, str],
- exceptions: List[ValidationError],
- ssh_tunnel_passwords: Dict[str, str],
- ssh_tunnel_private_keys: Dict[str, str],
- ssh_tunnel_priv_key_passwords: Dict[str, str],
-) -> Dict[str, Any]:
- configs: Dict[str, Any] = {}
+ contents: dict[str, str],
+ schemas: dict[str, Schema],
+ passwords: dict[str, str],
+ exceptions: list[ValidationError],
+ ssh_tunnel_passwords: dict[str, str],
+ ssh_tunnel_private_keys: dict[str, str],
+ ssh_tunnel_priv_key_passwords: dict[str, str],
+) -> dict[str, Any]:
+ configs: dict[str, Any] = {}
# load existing databases so we can apply the password validation
- db_passwords: Dict[str, str] = {
+ db_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all()
}
# load existing ssh_tunnels so we can apply the password validation
- db_ssh_tunnel_passwords: Dict[str, str] = {
+ db_ssh_tunnel_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
}
# load existing ssh_tunnels so we can apply the private_key validation
- db_ssh_tunnel_private_keys: Dict[str, str] = {
+ db_ssh_tunnel_private_keys: dict[str, str] = {
str(uuid): private_key
for uuid, private_key in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key
).all()
}
# load existing ssh_tunnels so we can apply the private_key_password validation
- db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
+ db_ssh_tunnel_priv_key_passws: dict[str, str] = {
str(uuid): private_key_password
for uuid, private_key_password in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key_password
@@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool:
return True
-def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]:
+def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
return {
remove_root(file_name): bundle.read(file_name).decode()
for file_name in bundle.namelist()
diff --git a/superset/commands/utils.py b/superset/commands/utils.py
index ad58bb4050..7bb13984f8 100644
--- a/superset/commands/utils.py
+++ b/superset/commands/utils.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import List, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from flask import g
from flask_appbuilder.security.sqla.models import Role, User
@@ -37,9 +37,9 @@ if TYPE_CHECKING:
def populate_owners(
- owner_ids: Optional[List[int]],
+ owner_ids: list[int] | None,
default_to_user: bool,
-) -> List[User]:
+) -> list[User]:
"""
Helper function for commands, will fetch all users from owners id's
@@ -63,13 +63,13 @@ def populate_owners(
return owners
-def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
+def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
"""
Helper function for commands, will fetch all roles from roles id's
:raises RolesNotFoundValidationError: If a role in the input list is not found
:param role_ids: A List of roles by id's
"""
- roles: List[Role] = []
+ roles: list[Role] = []
if role_ids:
roles = security_manager.find_roles_by_id(role_ids)
if len(roles) != len(role_ids):
diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py
index 659a640159..65c0c43c11 100644
--- a/superset/common/chart_data.py
+++ b/superset/common/chart_data.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
-from typing import Set
class ChartDataResultFormat(str, Enum):
@@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum):
XLSX = "xlsx"
@classmethod
- def table_like(cls) -> Set["ChartDataResultFormat"]:
+ def table_like(cls) -> set["ChartDataResultFormat"]:
return {cls.CSV} | {cls.XLSX}
diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py
index f6f5a5cd62..22c778b77b 100644
--- a/superset/common/query_actions.py
+++ b/superset/common/query_actions.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import copy
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
from flask_babel import _
@@ -49,7 +49,7 @@ def _get_datasource(
def _get_columns(
query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@@ -65,7 +65,7 @@ def _get_columns(
def _get_timegrains(
query_context: QueryContext, query_obj: QueryObject, _: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@@ -83,7 +83,7 @@ def _get_query(
query_context: QueryContext,
query_obj: QueryObject,
_: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result = {"language": datasource.query_language}
try:
@@ -96,8 +96,8 @@ def _get_query(
def _get_full(
query_context: QueryContext,
query_obj: QueryObject,
- force_cached: Optional[bool] = False,
-) -> Dict[str, Any]:
+ force_cached: bool | None = False,
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
@@ -141,7 +141,7 @@ def _get_full(
def _get_samples(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
@@ -162,7 +162,7 @@ def _get_samples(
def _get_drill_detail(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
# todo(yongjie): Remove this function,
# when determining whether samples should be applied to the time filter.
datasource = _get_datasource(query_context, query_obj)
@@ -183,13 +183,13 @@ def _get_drill_detail(
def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
payload = _get_full(query_context, query_obj, force_cached)
return payload
-_result_type_functions: Dict[
- ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]]
+_result_type_functions: dict[
+ ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
] = {
ChartDataResultType.COLUMNS: _get_columns,
ChartDataResultType.TIMEGRAINS: _get_timegrains,
@@ -210,7 +210,7 @@ def get_query_results(
query_context: QueryContext,
query_obj: QueryObject,
force_cached: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Return result payload for a chart data request.
diff --git a/superset/common/query_context.py b/superset/common/query_context.py
index 78eb8800c4..1a8d3c518b 100644
--- a/superset/common/query_context.py
+++ b/superset/common/query_context.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING
import pandas as pd
@@ -47,15 +47,15 @@ class QueryContext:
enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
- slice_: Optional[Slice] = None
- queries: List[QueryObject]
- form_data: Optional[Dict[str, Any]]
+ slice_: Slice | None = None
+ queries: list[QueryObject]
+ form_data: dict[str, Any] | None
result_type: ChartDataResultType
result_format: ChartDataResultFormat
force: bool
- custom_cache_timeout: Optional[int]
+ custom_cache_timeout: int | None
- cache_values: Dict[str, Any]
+ cache_values: dict[str, Any]
_processor: QueryContextProcessor
@@ -65,14 +65,14 @@ class QueryContext:
self,
*,
datasource: BaseDatasource,
- queries: List[QueryObject],
- slice_: Optional[Slice],
- form_data: Optional[Dict[str, Any]],
+ queries: list[QueryObject],
+ slice_: Slice | None,
+ form_data: dict[str, Any] | None,
result_type: ChartDataResultType,
result_format: ChartDataResultFormat,
force: bool = False,
- custom_cache_timeout: Optional[int] = None,
- cache_values: Dict[str, Any],
+ custom_cache_timeout: int | None = None,
+ cache_values: dict[str, Any],
) -> None:
self.datasource = datasource
self.slice_ = slice_
@@ -88,18 +88,18 @@ class QueryContext:
def get_data(
self,
df: pd.DataFrame,
- ) -> Union[str, List[Dict[str, Any]]]:
+ ) -> str | list[dict[str, Any]]:
return self._processor.get_data(df)
def get_payload(
self,
- cache_query_context: Optional[bool] = False,
+ cache_query_context: bool | None = False,
force_cached: bool = False,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
return self._processor.get_payload(cache_query_context, force_cached)
- def get_cache_timeout(self) -> Optional[int]:
+ def get_cache_timeout(self) -> int | None:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.slice_ and self.slice_.cache_timeout is not None:
@@ -110,14 +110,14 @@ class QueryContext:
return self.datasource.database.cache_timeout
return None
- def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
return self._processor.query_cache_key(query_obj, **kwargs)
def get_df_payload(
self,
query_obj: QueryObject,
- force_cached: Optional[bool] = False,
- ) -> Dict[str, Any]:
+ force_cached: bool | None = False,
+ ) -> dict[str, Any]:
return self._processor.get_df_payload(
query_obj=query_obj,
force_cached=force_cached,
diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py
index 84c0415722..62018def8d 100644
--- a/superset/common/query_context_factory.py
+++ b/superset/common/query_context_factory.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from superset import app, db
from superset.charts.dao import ChartDAO
@@ -48,12 +48,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
self,
*,
datasource: DatasourceDict,
- queries: List[Dict[str, Any]],
- form_data: Optional[Dict[str, Any]] = None,
- result_type: Optional[ChartDataResultType] = None,
- result_format: Optional[ChartDataResultFormat] = None,
+ queries: list[dict[str, Any]],
+ form_data: dict[str, Any] | None = None,
+ result_type: ChartDataResultType | None = None,
+ result_format: ChartDataResultFormat | None = None,
force: bool = False,
- custom_cache_timeout: Optional[int] = None,
+ custom_cache_timeout: int | None = None,
) -> QueryContext:
datasource_model_instance = None
if datasource:
@@ -101,13 +101,13 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
datasource_id=int(datasource["id"]),
)
- def _get_slice(self, slice_id: Any) -> Optional[Slice]:
+ def _get_slice(self, slice_id: Any) -> Slice | None:
return ChartDAO.find_by_id(slice_id)
def _process_query_object(
self,
datasource: BaseDatasource,
- form_data: Optional[Dict[str, Any]],
+ form_data: dict[str, Any] | None,
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
@@ -117,7 +117,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _apply_granularity(
self,
query_object: QueryObject,
- form_data: Optional[Dict[str, Any]],
+ form_data: dict[str, Any] | None,
datasource: BaseDatasource,
) -> None:
temporal_columns = {
diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index 85a2b5d97a..ecb8db4246 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import copy
import logging
import re
-from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, ClassVar, TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -77,8 +77,8 @@ logger = logging.getLogger(__name__)
class CachedTimeOffset(TypedDict):
df: pd.DataFrame
- queries: List[str]
- cache_keys: List[Optional[str]]
+ queries: list[str]
+ cache_keys: list[str | None]
class QueryContextProcessor:
@@ -102,8 +102,8 @@ class QueryContextProcessor:
enforce_numerical_metrics: ClassVar[bool] = True
def get_df_payload(
- self, query_obj: QueryObject, force_cached: Optional[bool] = False
- ) -> Dict[str, Any]:
+ self, query_obj: QueryObject, force_cached: bool | None = False
+ ) -> dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.query_cache_key(query_obj)
timeout = self.get_cache_timeout()
@@ -181,7 +181,7 @@ class QueryContextProcessor:
"label_map": label_map,
}
- def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
+ def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
"""
Returns a QueryObject cache key for objects in self.queries
"""
@@ -248,8 +248,8 @@ class QueryContextProcessor:
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
# todo: should support "python_date_format" and "get_column" in each datasource
def _get_timestamp_format(
- source: BaseDatasource, column: Optional[str]
- ) -> Optional[str]:
+ source: BaseDatasource, column: str | None
+ ) -> str | None:
column_obj = source.get_column(column)
if (
column_obj
@@ -315,9 +315,9 @@ class QueryContextProcessor:
query_context = self._query_context
# ensure query_object is immutable
query_object_clone = copy.copy(query_object)
- queries: List[str] = []
- cache_keys: List[Optional[str]] = []
- rv_dfs: List[pd.DataFrame] = [df]
+ queries: list[str] = []
+ cache_keys: list[str | None] = []
+ rv_dfs: list[pd.DataFrame] = [df]
time_offsets = query_object.time_offsets
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
@@ -449,7 +449,7 @@ class QueryContextProcessor:
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
- def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
+ def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
if self._query_context.result_format in ChartDataResultFormat.table_like():
include_index = not isinstance(df.index, pd.RangeIndex)
columns = list(df.columns)
@@ -470,9 +470,9 @@ class QueryContextProcessor:
def get_payload(
self,
- cache_query_context: Optional[bool] = False,
+ cache_query_context: bool | None = False,
force_cached: bool = False,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
# Get all the payloads from the QueryObjects
@@ -522,13 +522,13 @@ class QueryContextProcessor:
return generate_cache_key(cache_dict, key_prefix)
- def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
+ def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]:
"""
:param query_context:
:param query_obj:
:return:
"""
- annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
+ annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj)
for annotation_layer in [
layer
for layer in query_obj.annotation_layers
@@ -541,7 +541,7 @@ class QueryContextProcessor:
return annotation_data
@staticmethod
- def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
+ def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
annotation_data = {}
annotation_layers = [
layer
@@ -576,8 +576,8 @@ class QueryContextProcessor:
@staticmethod
def get_viz_annotation_data(
- annotation_layer: Dict[str, Any], force: bool
- ) -> Dict[str, Any]:
+ annotation_layer: dict[str, Any], force: bool
+ ) -> dict[str, Any]:
chart = ChartDAO.find_by_id(annotation_layer["value"])
if not chart:
raise QueryObjectValidationError(_("The chart does not exist"))
diff --git a/superset/common/query_object.py b/superset/common/query_object.py
index 802a1eed5b..dc02b774e5 100644
--- a/superset/common/query_object.py
+++ b/superset/common/query_object.py
@@ -21,7 +21,7 @@ import json
import logging
from datetime import datetime
from pprint import pformat
-from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
+from typing import Any, NamedTuple, TYPE_CHECKING
from flask import g
from flask_babel import gettext as _
@@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
and druid. The query objects are constructed on the client.
"""
- annotation_layers: List[Dict[str, Any]]
- applied_time_extras: Dict[str, str]
+ annotation_layers: list[dict[str, Any]]
+ applied_time_extras: dict[str, str]
apply_fetch_values_predicate: bool
- columns: List[Column]
- datasource: Optional[BaseDatasource]
- extras: Dict[str, Any]
- filter: List[QueryObjectFilterClause]
- from_dttm: Optional[datetime]
- granularity: Optional[str]
- inner_from_dttm: Optional[datetime]
- inner_to_dttm: Optional[datetime]
+ columns: list[Column]
+ datasource: BaseDatasource | None
+ extras: dict[str, Any]
+ filter: list[QueryObjectFilterClause]
+ from_dttm: datetime | None
+ granularity: str | None
+ inner_from_dttm: datetime | None
+ inner_to_dttm: datetime | None
is_rowcount: bool
is_timeseries: bool
- metrics: Optional[List[Metric]]
+ metrics: list[Metric] | None
order_desc: bool
- orderby: List[OrderBy]
- post_processing: List[Dict[str, Any]]
- result_type: Optional[ChartDataResultType]
- row_limit: Optional[int]
+ orderby: list[OrderBy]
+ post_processing: list[dict[str, Any]]
+ result_type: ChartDataResultType | None
+ row_limit: int | None
row_offset: int
- series_columns: List[Column]
+ series_columns: list[Column]
series_limit: int
- series_limit_metric: Optional[Metric]
- time_offsets: List[str]
- time_shift: Optional[str]
- time_range: Optional[str]
- to_dttm: Optional[datetime]
+ series_limit_metric: Metric | None
+ time_offsets: list[str]
+ time_shift: str | None
+ time_range: str | None
+ to_dttm: datetime | None
def __init__( # pylint: disable=too-many-locals
self,
*,
- annotation_layers: Optional[List[Dict[str, Any]]] = None,
- applied_time_extras: Optional[Dict[str, str]] = None,
+ annotation_layers: list[dict[str, Any]] | None = None,
+ applied_time_extras: dict[str, str] | None = None,
apply_fetch_values_predicate: bool = False,
- columns: Optional[List[Column]] = None,
- datasource: Optional[BaseDatasource] = None,
- extras: Optional[Dict[str, Any]] = None,
- filters: Optional[List[QueryObjectFilterClause]] = None,
- granularity: Optional[str] = None,
+ columns: list[Column] | None = None,
+ datasource: BaseDatasource | None = None,
+ extras: dict[str, Any] | None = None,
+ filters: list[QueryObjectFilterClause] | None = None,
+ granularity: str | None = None,
is_rowcount: bool = False,
- is_timeseries: Optional[bool] = None,
- metrics: Optional[List[Metric]] = None,
+ is_timeseries: bool | None = None,
+ metrics: list[Metric] | None = None,
order_desc: bool = True,
- orderby: Optional[List[OrderBy]] = None,
- post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
- row_limit: Optional[int],
- row_offset: Optional[int] = None,
- series_columns: Optional[List[Column]] = None,
+ orderby: list[OrderBy] | None = None,
+ post_processing: list[dict[str, Any] | None] | None = None,
+ row_limit: int | None,
+ row_offset: int | None = None,
+ series_columns: list[Column] | None = None,
series_limit: int = 0,
- series_limit_metric: Optional[Metric] = None,
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
+ series_limit_metric: Metric | None = None,
+ time_range: str | None = None,
+ time_shift: str | None = None,
**kwargs: Any,
):
self._set_annotation_layers(annotation_layers)
@@ -166,7 +166,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self._move_deprecated_extra_fields(kwargs)
def _set_annotation_layers(
- self, annotation_layers: Optional[List[Dict[str, Any]]]
+ self, annotation_layers: list[dict[str, Any]] | None
) -> None:
self.annotation_layers = [
layer
@@ -175,14 +175,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
if layer["annotationType"] != "FORMULA"
]
- def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
+ def _set_is_timeseries(self, is_timeseries: bool | None) -> None:
# is_timeseries is True if time column is in either columns or groupby
# (both are dimensions)
self.is_timeseries = (
is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
)
- def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
+ def _set_metrics(self, metrics: list[Metric] | None = None) -> None:
# Support metric reference/definition in the format of
# 1. 'metric_name' - name of predefined metric
# 2. { label: 'label_name' } - legacy format for a predefined metric
@@ -195,16 +195,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
]
def _set_post_processing(
- self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
+ self, post_processing: list[dict[str, Any] | None] | None
) -> None:
post_processing = post_processing or []
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
def _init_series_columns(
self,
- series_columns: Optional[List[Column]],
- metrics: Optional[List[Metric]],
- is_timeseries: Optional[bool],
+ series_columns: list[Column] | None,
+ metrics: list[Metric] | None,
+ is_timeseries: bool | None,
) -> None:
if series_columns:
self.series_columns = series_columns
@@ -213,7 +213,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
else:
self.series_columns = []
- def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
+ def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
if field.old_name in kwargs:
@@ -233,7 +233,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
setattr(self, field.new_name, value)
- def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
+ def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None:
# move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.old_name in kwargs:
@@ -256,19 +256,19 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.extras[field.new_name] = value
@property
- def metric_names(self) -> List[str]:
+ def metric_names(self) -> list[str]:
"""Return metrics names (labels), coerce adhoc metrics to strings."""
return get_metric_names(self.metrics or [])
@property
- def column_names(self) -> List[str]:
+ def column_names(self) -> list[str]:
"""Return column names (labels). Gives priority to groupbys if both groupbys
and metrics are non-empty, otherwise returns column labels."""
return get_column_names(self.columns)
def validate(
- self, raise_exceptions: Optional[bool] = True
- ) -> Optional[QueryObjectValidationError]:
+ self, raise_exceptions: bool | None = True
+ ) -> QueryObjectValidationError | None:
"""Validate query object"""
try:
self._validate_there_are_no_missing_series()
@@ -314,7 +314,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
query_object_dict = {
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
"columns": self.columns,
diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py
index 88cc7ca1b4..5676dc9eda 100644
--- a/superset/common/query_object_factory.py
+++ b/superset/common/query_object_factory.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
@@ -31,13 +31,13 @@ if TYPE_CHECKING:
class QueryObjectFactory: # pylint: disable=too-few-public-methods
- _config: Dict[str, Any]
+ _config: dict[str, Any]
_datasource_dao: DatasourceDAO
_session_maker: sessionmaker
def __init__(
self,
- app_configurations: Dict[str, Any],
+ app_configurations: dict[str, Any],
_datasource_dao: DatasourceDAO,
session_maker: sessionmaker,
):
@@ -48,11 +48,11 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
- datasource: Optional[DatasourceDict] = None,
- extras: Optional[Dict[str, Any]] = None,
- row_limit: Optional[int] = None,
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
+ datasource: DatasourceDict | None = None,
+ extras: dict[str, Any] | None = None,
+ row_limit: int | None = None,
+ time_range: str | None = None,
+ time_shift: str | None = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
@@ -84,13 +84,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def _process_extras( # pylint: disable=no-self-use
self,
- extras: Optional[Dict[str, Any]],
- ) -> Dict[str, Any]:
+ extras: dict[str, Any] | None,
+ ) -> dict[str, Any]:
extras = extras or {}
return extras
def _process_row_limit(
- self, row_limit: Optional[int], result_type: ChartDataResultType
+ self, row_limit: int | None, result_type: ChartDataResultType
) -> int:
default_row_limit = (
self._config["SAMPLES_ROW_LIMIT"]
diff --git a/superset/common/tags.py b/superset/common/tags.py
index 706192913a..6066d0eec7 100644
--- a/superset/common/tags.py
+++ b/superset/common/tags.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List
+from typing import Any
from sqlalchemy import MetaData
from sqlalchemy.exc import IntegrityError
@@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes
def add_types_to_charts(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@@ -57,7 +57,7 @@ def add_types_to_charts(
def add_types_to_dashboards(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@@ -89,7 +89,7 @@ def add_types_to_dashboards(
def add_types_to_saved_queries(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@@ -121,7 +121,7 @@ def add_types_to_saved_queries(
def add_types_to_datasets(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]
@@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None:
def add_owners_to_charts(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@@ -273,7 +273,7 @@ def add_owners_to_charts(
def add_owners_to_dashboards(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@@ -309,7 +309,7 @@ def add_owners_to_dashboards(
def add_owners_to_saved_queries(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@@ -345,7 +345,7 @@ def add_owners_to_saved_queries(
def add_owners_to_datasets(
- metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
+ metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]
diff --git a/superset/common/utils/dataframe_utils.py b/superset/common/utils/dataframe_utils.py
index 4dd62e3b5d..a3421f6431 100644
--- a/superset/common/utils/dataframe_utils.py
+++ b/superset/common/utils/dataframe_utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import datetime
-from typing import Any, List, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
@@ -29,7 +29,7 @@ if TYPE_CHECKING:
def left_join_df(
left_df: pd.DataFrame,
right_df: pd.DataFrame,
- join_keys: List[str],
+ join_keys: list[str],
) -> pd.DataFrame:
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
df.reset_index(inplace=True)
diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py
index 6c1b268f46..a0fb65b20d 100644
--- a/superset/common/utils/query_cache_manager.py
+++ b/superset/common/utils/query_cache_manager.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any
from flask_caching import Cache
from pandas import DataFrame
@@ -37,7 +37,7 @@ config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
-_cache: Dict[CacheRegion, Cache] = {
+_cache: dict[CacheRegion, Cache] = {
CacheRegion.DEFAULT: cache_manager.cache,
CacheRegion.DATA: cache_manager.data_cache,
}
@@ -53,17 +53,17 @@ class QueryCacheManager:
self,
df: DataFrame = DataFrame(),
query: str = "",
- annotation_data: Optional[Dict[str, Any]] = None,
- applied_template_filters: Optional[List[str]] = None,
- applied_filter_columns: Optional[List[Column]] = None,
- rejected_filter_columns: Optional[List[Column]] = None,
- status: Optional[str] = None,
- error_message: Optional[str] = None,
+ annotation_data: dict[str, Any] | None = None,
+ applied_template_filters: list[str] | None = None,
+ applied_filter_columns: list[Column] | None = None,
+ rejected_filter_columns: list[Column] | None = None,
+ status: str | None = None,
+ error_message: str | None = None,
is_loaded: bool = False,
- stacktrace: Optional[str] = None,
- is_cached: Optional[bool] = None,
- cache_dttm: Optional[str] = None,
- cache_value: Optional[Dict[str, Any]] = None,
+ stacktrace: str | None = None,
+ is_cached: bool | None = None,
+ cache_dttm: str | None = None,
+ cache_value: dict[str, Any] | None = None,
) -> None:
self.df = df
self.query = query
@@ -85,10 +85,10 @@ class QueryCacheManager:
self,
key: str,
query_result: QueryResult,
- annotation_data: Optional[Dict[str, Any]] = None,
- force_query: Optional[bool] = False,
- timeout: Optional[int] = None,
- datasource_uid: Optional[str] = None,
+ annotation_data: dict[str, Any] | None = None,
+ force_query: bool | None = False,
+ timeout: int | None = None,
+ datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@@ -136,11 +136,11 @@ class QueryCacheManager:
@classmethod
def get(
cls,
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
- force_query: Optional[bool] = False,
- force_cached: Optional[bool] = False,
- ) -> "QueryCacheManager":
+ force_query: bool | None = False,
+ force_cached: bool | None = False,
+ ) -> QueryCacheManager:
"""
Initialize QueryCacheManager by query-cache key
"""
@@ -190,10 +190,10 @@ class QueryCacheManager:
@staticmethod
def set(
- key: Optional[str],
- value: Dict[str, Any],
- timeout: Optional[int] = None,
- datasource_uid: Optional[str] = None,
+ key: str | None,
+ value: dict[str, Any],
+ timeout: int | None = None,
+ datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@@ -204,7 +204,7 @@ class QueryCacheManager:
@staticmethod
def delete(
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
if key:
@@ -212,7 +212,7 @@ class QueryCacheManager:
@staticmethod
def has(
- key: Optional[str],
+ key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> bool:
return bool(_cache[region].get(key)) if key else False
diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py
index fa6a5244b2..5f9139c047 100644
--- a/superset/common/utils/time_range_utils.py
+++ b/superset/common/utils/time_range_utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from typing import Any, cast, Dict, Optional, Tuple
+from typing import Any, cast
from superset import app
from superset.common.query_object import QueryObject
@@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until
def get_since_until_from_time_range(
- time_range: Optional[str] = None,
- time_shift: Optional[str] = None,
- extras: Optional[Dict[str, Any]] = None,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+ time_range: str | None = None,
+ time_shift: str | None = None,
+ extras: dict[str, Any] | None = None,
+) -> tuple[datetime | None, datetime | None]:
return get_since_until(
relative_start=(extras or {}).get(
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
@@ -45,7 +45,7 @@ def get_since_until_from_time_range(
# pylint: disable=invalid-name
def get_since_until_from_query_object(
query_object: QueryObject,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+) -> tuple[datetime | None, datetime | None]:
"""
this function will return since and until by tuple if
1) the time_range is in the query object.
diff --git a/superset/config.py b/superset/config.py
index 7d9359d14f..434456386d 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -33,20 +33,7 @@ import sys
from collections import OrderedDict
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
-from typing import (
- Any,
- Callable,
- Dict,
- List,
- Literal,
- Optional,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- TypedDict,
- Union,
-)
+from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
import pkg_resources
from cachelib.base import BaseCache
@@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename(
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
-def _try_json_readversion(filepath: str) -> Optional[str]:
+def _try_json_readversion(filepath: str) -> str | None:
try:
- with open(filepath, "r") as f:
+ with open(filepath) as f:
return json.load(f).get("version")
except Exception: # pylint: disable=broad-except
return None
-def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
+def _try_json_readsha(filepath: str, length: int) -> str | None:
try:
- with open(filepath, "r") as f:
+ with open(filepath) as f:
return json.load(f).get("GIT_SHA")[:length]
except Exception: # pylint: disable=broad-except
return None
@@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False
PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}
# Configuration for scheduling queries from SQL Lab.
-SCHEDULED_QUERIES: Dict[str, Any] = {}
+SCHEDULED_QUERIES: dict[str, Any] = {}
# ------------------------------
# GLOBALS FOR APP Builder
@@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None
LOGO_TOOLTIP = ""
# Specify any text that should appear to the right of the logo
-LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = ""
+LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
# Enables SWAGGER UI for superset openapi spec
# ex: http://localhost:8080/swagger/v1
@@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB
# Grant public role the same set of permissions as for a selected builtin role.
# This is useful if one wants to enable anonymous users to view
# dashboards. Explicit grant on specific datasets is still required.
-PUBLIC_ROLE_LIKE: Optional[str] = None
+PUBLIC_ROLE_LIKE: str | None = None
# ---------------------------------------------------
# Babel config for translations
@@ -390,8 +377,8 @@ LANGUAGES = {}
class D3Format(TypedDict, total=False):
decimal: str
thousands: str
- grouping: List[int]
- currency: List[str]
+ grouping: list[int]
+ currency: list[str]
D3_FORMAT: D3Format = {}
@@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {}
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
-DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
+DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False, # deprecated
"DISABLE_DATASET_SOURCE_EDIT": False, # deprecated
@@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update(
)
# This is merely a default.
-FEATURE_FLAGS: Dict[str, bool] = {}
+FEATURE_FLAGS: dict[str, bool] = {}
# A function that receives a dict of all feature flags
# (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
@@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {}
# if hasattr(g, "user") and g.user.is_active:
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
# return feature_flags_dict
-GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
+GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None
# A function that receives a feature flag name and an optional default value.
# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
# evaluation of all feature flags when just evaluating a single one.
@@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] =
# Note that the default `get_feature_flags` will evaluate each feature with this
# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
# and IS_FEATURE_ENABLED_FUNC in conjunction.
-IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
+IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
# A function that expands/overrides the frontend `bootstrap_data.common` object.
# Can be used to implement custom frontend functionality,
# or dynamically change certain configs.
@@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
# Takes as a parameter the common bootstrap payload before transformations.
# Returns a dict containing data that should be added or overridden to the payload.
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
- [Dict[str, Any]], Dict[str, Any]
+ [dict[str, Any]], dict[str, Any]
] = lambda data: {} # default: empty dict
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
@@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
# }]
# This is merely a default
-EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# THEME_OVERRIDES is used for adding custom theme to superset
# example code for "My theme" custom scheme
@@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# }
# }
-THEME_OVERRIDES: Dict[str, Any] = {}
+THEME_OVERRIDES: dict[str, Any] = {}
# EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
# EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
@@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
# }]
# This is merely a default
-EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
+EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# ---------------------------------------------------
# Thumbnail config (behind feature flag)
@@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# `superset.tasks.types.ExecutorType` for a full list of executor options.
# To always use a fixed user account, use the following configuration:
# THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
-THUMBNAIL_SELENIUM_USER: Optional[str] = "admin"
+THUMBNAIL_SELENIUM_USER: str | None = "admin"
THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# By default, thumbnail digests are calculated based on various parameters in the
@@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
# user if the executor type is equal to `ExecutorType.CURRENT_USER`)
# and return the final digest string:
-THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[
+THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | (
Callable[[Dashboard, ExecutorType, str], str]
-] = None
-THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None
+) = None
+THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache",
@@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False
# CORS Options
ENABLE_CORS = False
-CORS_OPTIONS: Dict[Any, Any] = {}
+CORS_OPTIONS: dict[Any, Any] = {}
# Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
# Disabling this option is not recommended for security reasons. If you wish to allow
@@ -736,7 +723,7 @@ HTML_SANITIZATION = True
# }
# }
# Be careful when extending the default schema to avoid XSS attacks.
-HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {}
+HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {}
# Chrome allows up to 6 open connections per domain at a time. When there are more
# than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
@@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"}
# time grains in superset/db_engine_specs/base.py).
# For example: to disable 1 second time grain:
# TIME_GRAIN_DENYLIST = ['PT1S']
-TIME_GRAIN_DENYLIST: List[str] = []
+TIME_GRAIN_DENYLIST: list[str] = []
# Additional time grains to be supported using similar definitions as in
# superset/db_engine_specs/base.py.
# For example: To add a new 2 second time grain:
# TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
-TIME_GRAIN_ADDONS: Dict[str, str] = {}
+TIME_GRAIN_ADDONS: dict[str, str] = {}
# Implementation of additional time grains per engine.
# The column to be truncated is denoted `{col}` in the expression.
@@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {}
# 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
# }
# }
-TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
+TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {}
# ---------------------------------------------------
# List of viz_types not allowed in your environment
@@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
# VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
# ---------------------------------------------------
-VIZ_TYPE_DENYLIST: List[str] = []
+VIZ_TYPE_DENYLIST: list[str] = []
# --------------------------------------------------
# Modules, datasources and middleware to be registered
@@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
("superset.connectors.sqla.models", ["SqlaTable"]),
]
)
-ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
-ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
+ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {}
+ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html
@@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig # pylint: disable=invalid-name
# within the app
# OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
# override anything set within the app
-DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
-OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
-HTTP_HEADERS: Dict[str, Any] = {}
+DEFAULT_HTTP_HEADERS: dict[str, Any] = {}
+OVERRIDE_HTTP_HEADERS: dict[str, Any] = {}
+HTTP_HEADERS: dict[str, Any] = {}
# The db id here results in selecting this one as a default in SQL Lab
DEFAULT_DB_ID = None
@@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
# return out
#
# QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
-QUERY_COST_FORMATTERS_BY_ENGINE: Dict[
- str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
+QUERY_COST_FORMATTERS_BY_ENGINE: dict[
+ str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
] = {}
# Flag that controls if limit should be enforced on the CTA (create table as queries).
@@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False
# else:
# return f'tmp_{schema}'
# Function accepts database object, user object, schema name and sql that will be run.
-SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
+SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | (
Callable[[Database, models.User, str, str], str]
-] = None
+) = None
# If enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
-RESULTS_BACKEND: Optional[BaseCache] = None
+RESULTS_BACKEND: BaseCache | None = None
# Use PyArrow and MessagePack for async query results serialization,
# rather than JSON. This feature requires additional testing from the
@@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
database: Database,
user: models.User, # pylint: disable=unused-argument
- schema: Optional[str],
+ schema: str | None,
) -> str:
# Note the final empty path enforces a trailing slash.
return os.path.join(
@@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
-UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
+UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
# mypy doesn't catch that if case ensures list content being always str
-ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = (
+ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = (
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
@@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
# It's important to make sure that the objects exposed (as well as objects attached
# to those objets) are harmless. We recommend only exposing simple/pure functions that
# return native types.
-JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
+JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors (by engine) that gets merged into global
# template processors. The existing template processors get updated with this
@@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# dictionary. The customized addons don't necessarily need to use Jinja templating
# language. This allows you to define custom logic to process templates on a per-engine
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
-CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
+CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {}
# Roles that are controlled by the API / Superset and should not be changes
# by humans.
@@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
-BLUEPRINTS: List[Blueprint] = []
+BLUEPRINTS: list[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL
@@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url
# customize the polling time of each engine
-DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {}
+DB_POLL_INTERVAL_SECONDS: dict[str, int] = {}
# Interval between consecutive polls when using Presto Engine
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
@@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
# "another_auth_method": auth_method,
# },
# }
-ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
+ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None
@@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Owners, filters for created_by, etc.
# The users can also be excluded by overriding the get_exclude_users_from_lists method
# in security manager
-EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None
+EXCLUDE_USERS_FROM_LISTS: list[str] | None = None
# For database connections, this dictionary will remove engines from the available
# list/dropdown if you do not want these dbs to show as available.
# The available list is generated by driver installed, and some engines have multiple
# drivers.
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
-DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {}
+DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {}
# This auth provider is used by background (offline) tasks that need to access
# protected resources. Can be overridden by end users in order to support
@@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
# ExecutorType.OWNER,
# ExecutorType.SELENIUM,
# ]
-ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER]
+ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER]
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
@@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
EMAIL_REPORTS_CTA = "Explore in Superset"
# Slack API token for the superset reports, either string or callable
-SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None
+SLACK_API_TOKEN: Callable[[], str] | str | None = None
SLACK_PROXY = None
# The webdriver to use for generating reports. Use one of the following
@@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = {
WEBDRIVER_AUTH_FUNC = None
# Any config options to be passed as-is to the webdriver
-WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"}
+WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"}
# Additional args to be passed as arguments to the config object
# Note: If using Chrome, you'll want to add the "--marionette" arg.
@@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = {
# displayed prominently in the "Add Database" dialog. You should
# use the "engine_name" attribute of the corresponding DB engine spec
# in `superset/db_engine_specs/`.
-PREFERRED_DATABASES: List[str] = [
+PREFERRED_DATABASES: list[str] = [
"PostgreSQL",
"Presto",
"MySQL",
@@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = {
#
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
-SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax"
+SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax"
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
SESSION_PROTECTION = "strong"
@@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
-SSL_CERT_PATH: Optional[str] = None
+SSL_CERT_PATH: str | None = None
# SQLA table mutator, every time we fetch the metadata for a certain table
# (superset.connectors.sqla.models.SqlaTable), we call this hook
@@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
-GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[
+GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
Literal["None", "Lax", "Strict"]
-] = None
+) = None
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
@@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256"
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
# Guest token audience for the embedded superset, either string or callable
-GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
+GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
# be memoized to aid with performance, i.e.,
@@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
# cache_manager.cache.delete_memoized(func)
# cache_manager.cache.set(name, code, timeout=0)
#
-DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None
+DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None
# Do not show user info or profile in the menu
MENU_HIDE_USER_INFO = False
@@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False
ENABLE_BROAD_ACTIVITY_ACCESS = True
# the advanced data type key should correspond to that set in the column metadata
-ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
+ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = {
"internet_address": internet_address,
"port": internet_port,
}
@@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
# "Xyz",
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
# )
-WELCOME_PAGE_LAST_TAB: Union[
- Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]]
-] = "all"
+WELCOME_PAGE_LAST_TAB: (
+ Literal["examples", "all"] | tuple[str, list[dict[str, Any]]]
+) = "all"
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)
diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py
index 2cb0d54c51..d43d078639 100644
--- a/superset/connectors/base/models.py
+++ b/superset/connectors/base/models.py
@@ -16,21 +16,12 @@
# under the License.
from __future__ import annotations
+import builtins
import json
+from collections.abc import Hashable
from datetime import datetime
from enum import Enum
-from typing import (
- Any,
- Dict,
- Hashable,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, TYPE_CHECKING
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __
@@ -89,23 +80,23 @@ class BaseDatasource(
# ---------------------------------------------------------------
# class attributes to define when deriving BaseDatasource
# ---------------------------------------------------------------
- __tablename__: Optional[str] = None # {connector_name}_datasource
- baselink: Optional[str] = None # url portion pointing to ModelView endpoint
+ __tablename__: str | None = None # {connector_name}_datasource
+ baselink: str | None = None # url portion pointing to ModelView endpoint
@property
- def column_class(self) -> Type["BaseColumn"]:
+ def column_class(self) -> type[BaseColumn]:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
- def metric_class(self) -> Type["BaseMetric"]:
+ def metric_class(self) -> type[BaseMetric]:
# link to derivative of BaseMetric
raise NotImplementedError()
- owner_class: Optional[User] = None
+ owner_class: User | None = None
# Used to do code highlighting when displaying the query in the UI
- query_language: Optional[str] = None
+ query_language: str | None = None
# Only some datasources support Row Level Security
is_rls_supported: bool = False
@@ -131,9 +122,9 @@ class BaseDatasource(
is_managed_externally = Column(Boolean, nullable=False, default=False)
external_url = Column(Text, nullable=True)
- sql: Optional[str] = None
- owners: List[User]
- update_from_object_fields: List[str]
+ sql: str | None = None
+ owners: list[User]
+ update_from_object_fields: list[str]
extra_import_fields = ["is_managed_externally", "external_url"]
@@ -142,7 +133,7 @@ class BaseDatasource(
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@property
- def owners_data(self) -> List[Dict[str, Any]]:
+ def owners_data(self) -> list[dict[str, Any]]:
return [
{
"first_name": o.first_name,
@@ -167,8 +158,8 @@ class BaseDatasource(
),
)
- columns: List["BaseColumn"] = []
- metrics: List["BaseMetric"] = []
+ columns: list[BaseColumn] = []
+ metrics: list[BaseMetric] = []
@property
def type(self) -> str:
@@ -180,11 +171,11 @@ class BaseDatasource(
return f"{self.id}__{self.type}"
@property
- def column_names(self) -> List[str]:
+ def column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
- def columns_types(self) -> Dict[str, str]:
+ def columns_types(self) -> dict[str, str]:
return {c.column_name: c.type for c in self.columns}
@property
@@ -196,26 +187,26 @@ class BaseDatasource(
raise NotImplementedError()
@property
- def connection(self) -> Optional[str]:
+ def connection(self) -> str | None:
"""String representing the context of the Datasource"""
return None
@property
- def schema(self) -> Optional[str]:
+ def schema(self) -> str | None:
"""String representing the schema of the Datasource (if it applies)"""
return None
@property
- def filterable_column_names(self) -> List[str]:
+ def filterable_column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns if c.filterable])
@property
- def dttm_cols(self) -> List[str]:
+ def dttm_cols(self) -> list[str]:
return []
@property
def url(self) -> str:
- return "/{}/edit/{}".format(self.baselink, self.id)
+ return f"/{self.baselink}/edit/{self.id}"
@property
def explore_url(self) -> str:
@@ -224,10 +215,10 @@ class BaseDatasource(
return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"
@property
- def column_formats(self) -> Dict[str, Optional[str]]:
+ def column_formats(self) -> dict[str, str | None]:
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
- def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
+ def add_missing_metrics(self, metrics: list[BaseMetric]) -> None:
existing_metrics = {m.metric_name for m in self.metrics}
for metric in metrics:
if metric.metric_name not in existing_metrics:
@@ -235,7 +226,7 @@ class BaseDatasource(
self.metrics.append(metric)
@property
- def short_data(self) -> Dict[str, Any]:
+ def short_data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
"edit_url": self.url,
@@ -249,11 +240,11 @@ class BaseDatasource(
}
@property
- def select_star(self) -> Optional[str]:
+ def select_star(self) -> str | None:
pass
@property
- def order_by_choices(self) -> List[Tuple[str, str]]:
+ def order_by_choices(self) -> list[tuple[str, str]]:
choices = []
# self.column_names return sorted column_names
for column_name in self.column_names:
@@ -267,7 +258,7 @@ class BaseDatasource(
return choices
@property
- def verbose_map(self) -> Dict[str, str]:
+ def verbose_map(self) -> dict[str, str]:
verb_map = {"__timestamp": "Time"}
verb_map.update(
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
@@ -278,7 +269,7 @@ class BaseDatasource(
return verb_map
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
# simple fields
@@ -313,8 +304,8 @@ class BaseDatasource(
}
def data_for_slices( # pylint: disable=too-many-locals
- self, slices: List[Slice]
- ) -> Dict[str, Any]:
+ self, slices: list[Slice]
+ ) -> dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
@@ -381,8 +372,8 @@ class BaseDatasource(
if metric["metric_name"] in metric_names
]
- filtered_columns: List[Column] = []
- column_types: Set[GenericDataType] = set()
+ filtered_columns: list[Column] = []
+ column_types: set[GenericDataType] = set()
for column in data["columns"]:
generic_type = column.get("type_generic")
if generic_type is not None:
@@ -413,18 +404,18 @@ class BaseDatasource(
@staticmethod
def filter_values_handler( # pylint: disable=too-many-arguments
- values: Optional[FilterValues],
+ values: FilterValues | None,
operator: str,
target_generic_type: GenericDataType,
- target_native_type: Optional[str] = None,
+ target_native_type: str | None = None,
is_list_target: bool = False,
- db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
- db_extra: Optional[Dict[str, Any]] = None,
- ) -> Optional[FilterValues]:
+ db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
+ db_extra: dict[str, Any] | None = None,
+ ) -> FilterValues | None:
if values is None:
return None
- def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
+ def handle_single_value(value: FilterValue | None) -> FilterValue | None:
if operator == utils.FilterOperator.TEMPORAL_RANGE:
return value
if (
@@ -464,7 +455,7 @@ class BaseDatasource(
values = values[0] if values else None
return values
- def external_metadata(self) -> List[Dict[str, str]]:
+ def external_metadata(self) -> list[dict[str, str]]:
"""Returns column information from the external system"""
raise NotImplementedError()
@@ -483,7 +474,7 @@ class BaseDatasource(
"""
raise NotImplementedError()
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@@ -494,7 +485,7 @@ class BaseDatasource(
def default_query(qry: Query) -> Query:
return qry
- def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
+ def get_column(self, column_name: str | None) -> BaseColumn | None:
if not column_name:
return None
for col in self.columns:
@@ -504,11 +495,11 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
- object_list: List[Any],
- fkmany: List[Column],
- fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
+ object_list: list[Any],
+ fkmany: list[Column],
+ fkmany_class: builtins.type[BaseColumn | BaseMetric],
key_attr: str,
- ) -> List[Column]:
+ ) -> list[Column]:
"""Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code"""
@@ -541,7 +532,7 @@ class BaseDatasource(
fkmany += new_fks
return fkmany
- def update_from_object(self, obj: Dict[str, Any]) -> None:
+ def update_from_object(self, obj: dict[str, Any]) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@@ -578,7 +569,7 @@ class BaseDatasource(
def get_extra_cache_keys( # pylint: disable=no-self-use
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
- ) -> List[Hashable]:
+ ) -> list[Hashable]:
"""If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
@@ -607,14 +598,14 @@ class BaseDatasource(
@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
- ) -> Optional["BaseDatasource"]:
+ ) -> BaseDatasource | None:
raise NotImplementedError()
class BaseColumn(AuditMixinNullable, ImportExportMixin):
"""Interface for column"""
- __tablename__: Optional[str] = None # {connector_name}_column
+ __tablename__: str | None = None # {connector_name}_column
id = Column(Integer, primary_key=True)
column_name = Column(String(255), nullable=False)
@@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
is_dttm = None
# [optional] Set this to support import/export functionality
- export_fields: List[Any] = []
+ export_fields: list[Any] = []
def __repr__(self) -> str:
return str(self.column_name)
@@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))
@property
- def type_generic(self) -> Optional[utils.GenericDataType]:
+ def type_generic(self) -> utils.GenericDataType | None:
if self.is_string:
return utils.GenericDataType.STRING
if self.is_boolean:
@@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""Interface for Metrics"""
- __tablename__: Optional[str] = None # {connector_name}_metric
+ __tablename__: str | None = None # {connector_name}_metric
id = Column(Integer, primary_key=True)
metric_name = Column(String(255), nullable=False)
@@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""
@property
- def perm(self) -> Optional[str]:
+ def perm(self) -> str | None:
raise NotImplementedError()
@property
@@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"metric_name",
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 8833d6f6cb..41a9c89757 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -22,21 +22,10 @@ import json
import logging
import re
from collections import defaultdict
+from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Hashable,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- Union,
-)
+from typing import Any, Callable, cast
import dateutil.parser
import numpy as np
@@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
@dataclass
class MetadataResult:
- added: List[str] = field(default_factory=list)
- removed: List[str] = field(default_factory=list)
- modified: List[str] = field(default_factory=list)
+ added: list[str] = field(default_factory=list)
+ removed: list[str] = field(default_factory=list)
+ modified: list[str] = field(default_factory=list)
class AnnotationDatasource(BaseDatasource):
@@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource):
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
raise NotImplementedError()
@@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
- table: Mapped["SqlaTable"] = relationship(
+ table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="columns",
)
@@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.type_generic == GenericDataType.TEMPORAL
@property
- def db_engine_spec(self) -> Type[BaseEngineSpec]:
+ def db_engine_spec(self) -> type[BaseEngineSpec]:
return self.table.db_engine_spec
@property
- def db_extra(self) -> Dict[str, Any]:
+ def db_extra(self) -> dict[str, Any]:
return self.table.database.get_extra()
@property
- def type_generic(self) -> Optional[utils.GenericDataType]:
+ def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL
@@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_sqla_col(
self,
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
@@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_timestamp_expression(
self,
- time_grain: Optional[str],
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
- ) -> Union[TimestampExpression, Label]:
+ time_grain: str | None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
+ ) -> TimestampExpression | Label:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
@@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.table.make_sqla_column_compatible(time_expr, label)
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
- table: Mapped["SqlaTable"] = relationship(
+ table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="metrics",
)
@@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
def get_sqla_col(
self,
- label: Optional[str] = None,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ label: str | None = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.metric_name
expression = self.expression
@@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
- def perm(self) -> Optional[str]:
+ def perm(self) -> str | None:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.table.full_name
@@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
else None
)
- def get_perm(self) -> Optional[str]:
+ def get_perm(self) -> str | None:
return self.perm
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
attrs = (
"is_certified",
"certified_by",
@@ -473,11 +462,11 @@ sqlatable_user = Table(
def _process_sql_expression(
- expression: Optional[str],
+ expression: str | None,
database_id: int,
schema: str,
- template_processor: Optional[BaseTemplateProcessor] = None,
-) -> Optional[str]:
+ template_processor: BaseTemplateProcessor | None = None,
+) -> str | None:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
@@ -501,12 +490,12 @@ class SqlaTable(
type = "table"
query_language = "sql"
is_rls_supported = True
- columns: Mapped[List[TableColumn]] = relationship(
+ columns: Mapped[list[TableColumn]] = relationship(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
)
- metrics: Mapped[List[SqlMetric]] = relationship(
+ metrics: Mapped[list[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
@@ -577,11 +566,11 @@ class SqlaTable(
return self.name
@property
- def db_extra(self) -> Dict[str, Any]:
+ def db_extra(self) -> dict[str, Any]:
return self.database.get_extra()
@staticmethod
- def _apply_cte(sql: str, cte: Optional[str]) -> str:
+ def _apply_cte(sql: str, cte: str | None) -> str:
"""
Append a CTE before the SELECT statement if defined
@@ -594,7 +583,7 @@ class SqlaTable(
return sql
@property
- def db_engine_spec(self) -> Type[BaseEngineSpec]:
+ def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
return self.database.db_engine_spec
@property
@@ -637,9 +626,9 @@ class SqlaTable(
cls,
session: Session,
datasource_name: str,
- schema: Optional[str],
+ schema: str | None,
database_name: str,
- ) -> Optional[SqlaTable]:
+ ) -> SqlaTable | None:
schema = schema or None
query = (
session.query(cls)
@@ -660,7 +649,7 @@ class SqlaTable(
anchor = f'{name}'
return Markup(anchor)
- def get_schema_perm(self) -> Optional[str]:
+ def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
@@ -685,18 +674,18 @@ class SqlaTable(
)
@property
- def dttm_cols(self) -> List[str]:
+ def dttm_cols(self) -> list[str]:
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
- def num_cols(self) -> List[str]:
+ def num_cols(self) -> list[str]:
return [c.column_name for c in self.columns if c.is_numeric]
@property
- def any_dttm_col(self) -> Optional[str]:
+ def any_dttm_col(self) -> str | None:
cols = self.dttm_cols
return cols[0] if cols else None
@@ -713,7 +702,7 @@ class SqlaTable(
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
- def external_metadata(self) -> List[Dict[str, str]]:
+ def external_metadata(self) -> list[dict[str, str]]:
# todo(yongjie): create a physical table column type in a separate PR
if self.sql:
return get_virtual_table_metadata(dataset=self) # type: ignore
@@ -724,14 +713,14 @@ class SqlaTable(
)
@property
- def time_column_grains(self) -> Dict[str, Any]:
+ def time_column_grains(self) -> dict[str, Any]:
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()],
}
@property
- def select_star(self) -> Optional[str]:
+ def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
@@ -739,20 +728,20 @@ class SqlaTable(
)
@property
- def health_check_message(self) -> Optional[str]:
+ def health_check_message(self) -> str | None:
check = config["DATASET_HEALTH_CHECK"]
return check(self) if check else None
@property
- def granularity_sqla(self) -> List[Tuple[Any, Any]]:
+ def granularity_sqla(self) -> list[tuple[Any, Any]]:
return utils.choicify(self.dttm_cols)
@property
- def time_grain_sqla(self) -> List[Tuple[Any, Any]]:
+ def time_grain_sqla(self) -> list[tuple[Any, Any]]:
return [(g.duration, g.name) for g in self.database.grains() or []]
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
data_ = super().data
if self.type == "table":
data_["granularity_sqla"] = self.granularity_sqla
@@ -767,7 +756,7 @@ class SqlaTable(
return data_
@property
- def extra_dict(self) -> Dict[str, Any]:
+ def extra_dict(self) -> dict[str, Any]:
try:
return json.loads(self.extra)
except (TypeError, json.JSONDecodeError):
@@ -775,7 +764,7 @@ class SqlaTable(
def get_fetch_values_predicate(
self,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> TextClause:
fetch_values_predicate = self.fetch_values_predicate
if template_processor:
@@ -792,7 +781,7 @@ class SqlaTable(
)
) from ex
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@@ -869,8 +858,8 @@ class SqlaTable(
return tbl
def get_from_clause(
- self, template_processor: Optional[BaseTemplateProcessor] = None
- ) -> Tuple[Union[TableClause, Alias], Optional[str]]:
+ self, template_processor: BaseTemplateProcessor | None = None
+ ) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
@@ -899,7 +888,7 @@ class SqlaTable(
return from_clause, cte
def get_rendered_sql(
- self, template_processor: Optional[BaseTemplateProcessor] = None
+ self, template_processor: BaseTemplateProcessor | None = None
) -> str:
"""
Render sql with template engine (Jinja).
@@ -928,8 +917,8 @@ class SqlaTable(
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
- columns_by_name: Dict[str, TableColumn],
- template_processor: Optional[BaseTemplateProcessor] = None,
+ columns_by_name: dict[str, TableColumn],
+ template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
@@ -946,7 +935,7 @@ class SqlaTable(
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
metric_column = metric.get("column") or {}
column_name = cast(str, metric_column.get("column_name"))
- table_column: Optional[TableColumn] = columns_by_name.get(column_name)
+ table_column: TableColumn | None = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col(
template_processor=template_processor
@@ -971,7 +960,7 @@ class SqlaTable(
self,
col: AdhocColumn,
force_type_check: bool = False,
- template_processor: Optional[BaseTemplateProcessor] = None,
+ template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc column into a sqlalchemy column.
@@ -1021,7 +1010,7 @@ class SqlaTable(
return self.make_sqla_column_compatible(sqla_column, label)
def make_sqla_column_compatible(
- self, sqla_col: ColumnElement, label: Optional[str] = None
+ self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
@@ -1038,7 +1027,7 @@ class SqlaTable(
return sqla_col
def make_orderby_compatible(
- self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
+ self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
"""
If needed, make sure aliases for selected columns are not used in
@@ -1069,7 +1058,7 @@ class SqlaTable(
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
- ) -> List[TextClause]:
+ ) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
@@ -1078,8 +1067,8 @@ class SqlaTable(
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
- all_filters: List[TextClause] = []
- filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
+ all_filters: list[TextClause] = []
+ filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
@@ -1114,9 +1103,9 @@ class SqlaTable(
def _get_series_orderby(
self,
series_limit_metric: Metric,
- metrics_by_name: Dict[str, SqlMetric],
- columns_by_name: Dict[str, TableColumn],
- template_processor: Optional[BaseTemplateProcessor] = None,
+ metrics_by_name: dict[str, SqlMetric],
+ columns_by_name: dict[str, TableColumn],
+ template_processor: BaseTemplateProcessor | None = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
@@ -1138,8 +1127,8 @@ class SqlaTable(
self,
row: pd.Series,
dimension: str,
- columns_by_name: Dict[str, TableColumn],
- ) -> Union[str, int, float, bool, Text]:
+ columns_by_name: dict[str, TableColumn],
+ ) -> str | int | float | bool | Text:
"""
Convert a prequery result type to its equivalent Python type.
@@ -1159,7 +1148,7 @@ class SqlaTable(
value = value.item()
column_ = columns_by_name[dimension]
- db_extra: Dict[str, Any] = self.database.get_extra()
+ db_extra: dict[str, Any] = self.database.get_extra()
if column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
@@ -1174,9 +1163,9 @@ class SqlaTable(
def _get_top_groups(
self,
df: pd.DataFrame,
- dimensions: List[str],
- groupby_exprs: Dict[str, Any],
- columns_by_name: Dict[str, TableColumn],
+ dimensions: list[str],
+ groupby_exprs: dict[str, Any],
+ columns_by_name: dict[str, TableColumn],
) -> ColumnElement:
groups = []
for _unused, row in df.iterrows():
@@ -1201,7 +1190,7 @@ class SqlaTable(
errors = None
error_message = None
- def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
+ def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
"""
Some engines change the case or generate bespoke column names, either by
default or due to lack of support for aliasing. This function ensures that
@@ -1283,7 +1272,7 @@ class SqlaTable(
else self.columns
)
- old_columns_by_name: Dict[str, TableColumn] = {
+ old_columns_by_name: dict[str, TableColumn] = {
col.column_name: col for col in old_columns
}
results = MetadataResult(
@@ -1341,8 +1330,8 @@ class SqlaTable(
session: Session,
database: Database,
datasource_name: str,
- schema: Optional[str] = None,
- ) -> List[SqlaTable]:
+ schema: str | None = None,
+ ) -> list[SqlaTable]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@@ -1357,9 +1346,9 @@ class SqlaTable(
cls,
session: Session,
database: Database,
- permissions: Set[str],
- schema_perms: Set[str],
- ) -> List[SqlaTable]:
+ permissions: set[str],
+ schema_perms: set[str],
+ ) -> list[SqlaTable]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
@@ -1389,7 +1378,7 @@ class SqlaTable(
)
@classmethod
- def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
+ def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@@ -1409,7 +1398,7 @@ class SqlaTable(
:param query_obj: query object to analyze
:return: True if there are call(s) to an `ExtraCache` method, False otherwise
"""
- templatable_statements: List[str] = []
+ templatable_statements: list[str] = []
if self.sql:
templatable_statements.append(self.sql)
if self.fetch_values_predicate:
@@ -1428,7 +1417,7 @@ class SqlaTable(
return True
return False
- def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
+ def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
"""
The cache key of a SqlaTable needs to consider any keys added by the parent
class and any keys added via `ExtraCache`.
@@ -1489,7 +1478,7 @@ class SqlaTable(
@staticmethod
def update_column( # pylint: disable=unused-argument
- mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
+ mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
) -> None:
"""
:param mapper: Unused.
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 698311dab6..d41c0555d3 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -17,19 +17,9 @@
from __future__ import annotations
import logging
+from collections.abc import Iterable, Iterator
from functools import lru_cache
-from typing import (
- Any,
- Callable,
- Dict,
- Iterable,
- Iterator,
- List,
- Optional,
- Type,
- TYPE_CHECKING,
- TypeVar,
-)
+from typing import Any, Callable, TYPE_CHECKING, TypeVar
from uuid import UUID
from flask_babel import lazy_gettext as _
@@ -58,8 +48,8 @@ if TYPE_CHECKING:
def get_physical_table_metadata(
database: Database,
table_name: str,
- schema_name: Optional[str] = None,
-) -> List[Dict[str, Any]]:
+ schema_name: str | None = None,
+) -> list[dict[str, Any]]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()
@@ -103,7 +93,7 @@ def get_physical_table_metadata(
return cols
-def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
+def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
"""Use SQLparser to get virtual dataset metadata"""
if not dataset.sql:
raise SupersetGenericDBErrorException(
@@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
def get_columns_description(
database: Database,
query: str,
-) -> List[ResultSetColumnType]:
+) -> list[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with database.get_raw_connection() as conn:
@@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str:
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
-def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
+def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]:
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
@@ -181,9 +171,9 @@ logger = logging.getLogger(__name__)
def find_cached_objects_in_session(
session: Session,
- cls: Type[DeclarativeModel],
- ids: Optional[Iterable[int]] = None,
- uuids: Optional[Iterable[UUID]] = None,
+ cls: type[DeclarativeModel],
+ ids: Iterable[int] | None = None,
+ uuids: Iterable[UUID] | None = None,
) -> Iterator[DeclarativeModel]:
"""Find known ORM instances in cached SQLA session states.
diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py
index 0989a545fd..9116b9636e 100644
--- a/superset/connectors/sqla/views.py
+++ b/superset/connectors/sqla/views.py
@@ -447,7 +447,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
resp = super().edit(pk)
if isinstance(resp, str):
return resp
- return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk))
+ return redirect(f"/explore/?datasource_type=table&datasource_id={pk}")
@expose("/list/")
@has_access
diff --git a/superset/css_templates/commands/bulk_delete.py b/superset/css_templates/commands/bulk_delete.py
index 93564208c4..57612d9048 100644
--- a/superset/css_templates/commands/bulk_delete.py
+++ b/superset/css_templates/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.commands.base import BaseCommand
from superset.css_templates.commands.exceptions import (
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteCssTemplateCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[CssTemplate]] = None
+ self._models: Optional[list[CssTemplate]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/css_templates/dao.py b/superset/css_templates/dao.py
index 1862fb7aaf..bc1a796269 100644
--- a/superset/css_templates/dao.py
+++ b/superset/css_templates/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate
@staticmethod
- def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(
diff --git a/superset/dao/base.py b/superset/dao/base.py
index d3675a0e17..539dbab2d5 100644
--- a/superset/dao/base.py
+++ b/superset/dao/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
-from typing import Any, Dict, List, Optional, Type, Union
+from typing import Any, Optional, Union
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@@ -37,7 +37,7 @@ class BaseDAO:
Base DAO, implement base CRUD sqlalchemy operations
"""
- model_cls: Optional[Type[Model]] = None
+ model_cls: Optional[type[Model]] = None
"""
Child classes need to state the Model class so they don't need to implement basic
create, update and delete methods
@@ -75,10 +75,10 @@ class BaseDAO:
@classmethod
def find_by_ids(
cls,
- model_ids: Union[List[str], List[int]],
+ model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
- ) -> List[Model]:
+ ) -> list[Model]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
@@ -95,7 +95,7 @@ class BaseDAO:
return query.all()
@classmethod
- def find_all(cls) -> List[Model]:
+ def find_all(cls) -> list[Model]:
"""
Get all that fit the `base_filter`
"""
@@ -121,7 +121,7 @@ class BaseDAO:
return query.filter_by(**filter_by).one_or_none()
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
@@ -163,7 +163,7 @@ class BaseDAO:
@classmethod
def update(
- cls, model: Model, properties: Dict[str, Any], commit: bool = True
+ cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
"""
Generic update a model
@@ -196,7 +196,7 @@ class BaseDAO:
return model
@classmethod
- def bulk_delete(cls, models: List[Model], commit: bool = True) -> None:
+ def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)
diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py
index 13541cd946..385f1fbc6d 100644
--- a/superset/dashboards/commands/bulk_delete.py
+++ b/superset/dashboards/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from flask_babel import lazy_gettext as _
@@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteDashboardCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Dashboard]] = None
+ self._models: Optional[list[Dashboard]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py
index 0ad8ddee7c..58acc379ba 100644
--- a/superset/dashboards/commands/create.py
+++ b/superset/dashboards/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
- role_ids: Optional[List[int]] = self._properties.get("roles")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
+ role_ids: Optional[list[int]] = self._properties.get("roles")
slug: str = self._properties.get("slug", "")
# Validate slug uniqueness
diff --git a/superset/dashboards/commands/export.py b/superset/dashboards/commands/export.py
index 886b84ffa6..2e70e29bb0 100644
--- a/superset/dashboards/commands/export.py
+++ b/superset/dashboards/commands/export.py
@@ -20,7 +20,8 @@ import json
import logging
import random
import string
-from typing import Any, Dict, Iterator, Optional, Set, Tuple
+from typing import Any, Optional
+from collections.abc import Iterator
import yaml
@@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str:
)
-def get_default_position(title: str) -> Dict[str, Any]:
+def get_default_position(title: str) -> dict[str, Any]:
return {
"DASHBOARD_VERSION_KEY": "v2",
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
@@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]:
}
-def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]:
+def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]:
chart_hashes = [f"CHART-{suffix()}" for _ in charts]
# if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid
@@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Dashboard, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.dashboard_title, model.id)
file_path = f"dashboards/{file_name}.yaml"
diff --git a/superset/dashboards/commands/importers/dispatcher.py b/superset/dashboards/commands/importers/dispatcher.py
index dd0121f3e3..d5323b4fe4 100644
--- a/superset/dashboards/commands/importers/dispatcher.py
+++ b/superset/dashboards/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py
index e49c931896..012dbbc5c9 100644
--- a/superset/dashboards/commands/importers/v0.py
+++ b/superset/dashboards/commands/importers/v0.py
@@ -19,7 +19,7 @@ import logging
import time
from copy import copy
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
@@ -83,7 +83,7 @@ def import_chart(
def import_dashboard(
# pylint: disable=too-many-locals,too-many-statements
dashboard_to_import: Dashboard,
- dataset_id_mapping: Optional[Dict[int, int]] = None,
+ dataset_id_mapping: Optional[dict[int, int]] = None,
import_time: Optional[int] = None,
) -> int:
"""Imports the dashboard from the object to the database.
@@ -97,7 +97,7 @@ def import_dashboard(
"""
def alter_positions(
- dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
+ dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int]
) -> None:
"""Updates slice_ids in the position json.
@@ -166,7 +166,7 @@ def import_dashboard(
dashboard_to_import.slug = None
old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
- old_to_new_slc_id_dict: Dict[int, int] = {}
+ old_to_new_slc_id_dict: dict[int, int] = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
new_filter_scopes = {}
@@ -268,7 +268,7 @@ def import_dashboard(
return dashboard_to_import.id # type: ignore
-def decode_dashboards(o: Dict[str, Any]) -> Any:
+def decode_dashboards(o: dict[str, Any]) -> Any:
"""
Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation.
@@ -302,7 +302,7 @@ def import_dashboards(
data = json.loads(content, object_hook=decode_dashboards)
if not data:
raise DashboardImportException(_("No data in file"))
- dataset_id_mapping: Dict[int, int] = {}
+ dataset_id_mapping: dict[int, int] = {}
for table in data["datasources"]:
new_dataset_id = import_dataset(table, database_id, import_time=import_time)
params = json.loads(table.params)
@@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
- self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any
+ self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any
):
self.contents = contents
self.database_id = database_id
diff --git a/superset/dashboards/commands/importers/v1/__init__.py b/superset/dashboards/commands/importers/v1/__init__.py
index 5d83a580bd..597adba6d9 100644
--- a/superset/dashboards/commands/importers/v1/__init__.py
+++ b/superset/dashboards/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
dao = DashboardDAO
model_name = "dashboard"
prefix = "dashboards/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand):
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover charts and datasets associated with dashboards
- chart_uuids: Set[str] = set()
- dataset_uuids: Set[str] = set()
+ chart_uuids: set[str] = set()
+ dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
chart_uuids.update(find_chart_uuids(config["position"]))
@@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
- dataset_info: Dict[str, Dict[str, Any]] = {}
+ dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
@@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
}
# import charts with the correct parent ref
- chart_ids: Dict[str, int] = {}
+ chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
).fetchall()
# import dashboards
- dashboard_chart_ids: List[Tuple[int, int]] = []
+ dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
diff --git a/superset/dashboards/commands/importers/v1/utils.py b/superset/dashboards/commands/importers/v1/utils.py
index 9f0ffc36a1..1deb44949a 100644
--- a/superset/dashboards/commands/importers/v1/utils.py
+++ b/superset/dashboards/commands/importers/v1/utils.py
@@ -17,7 +17,7 @@
import json
import logging
-from typing import Any, Dict, Set
+from typing import Any
from flask import g
from sqlalchemy.orm import Session
@@ -32,12 +32,12 @@ logger = logging.getLogger(__name__)
JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}
-def find_chart_uuids(position: Dict[str, Any]) -> Set[str]:
+def find_chart_uuids(position: dict[str, Any]) -> set[str]:
return set(build_uuid_to_id_map(position))
-def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
- uuids: Set[str] = set()
+def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]:
+ uuids: set[str] = set()
for native_filter in metadata.get("native_filter_configuration", []):
targets = native_filter.get("targets", [])
for target in targets:
@@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
return uuids
-def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
+def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]:
return {
child["meta"]["uuid"]: child["meta"]["chartId"]
for child in position.values()
@@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
def update_id_refs( # pylint: disable=too-many-locals
- config: Dict[str, Any],
- chart_ids: Dict[str, int],
- dataset_info: Dict[str, Dict[str, Any]],
-) -> Dict[str, Any]:
+ config: dict[str, Any],
+ chart_ids: dict[str, int],
+ dataset_info: dict[str, dict[str, Any]],
+) -> dict[str, Any]:
"""Update dashboard metadata to use new IDs"""
fixed = config.copy()
@@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals
def import_dashboard(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Dashboard:
diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py
index 11833a64be..fefa65e3f6 100644
--- a/superset/dashboards/commands/update.py
+++ b/superset/dashboards/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(UpdateMixin, BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Dashboard] = None
@@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owners_ids: Optional[List[int]] = self._properties.get("owners")
- roles_ids: Optional[List[int]] = self._properties.get("roles")
+ exceptions: list[ValidationError] = []
+ owners_ids: Optional[list[int]] = self._properties.get("owners")
+ roles_ids: Optional[list[int]] = self._properties.get("roles")
slug: Optional[str] = self._properties.get("slug")
# Validate/populate model exists
diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py
index 5355d602be..d88fb431b7 100644
--- a/superset/dashboards/dao.py
+++ b/superset/dashboards/dao.py
@@ -17,7 +17,7 @@
import json
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
@@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
- def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
+ def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]:
dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
return dashboard.datasets_trimmed_for_slices()
@staticmethod
- def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]:
+ def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]:
return DashboardDAO.get_by_id_or_slug(id_or_slug).slices
@staticmethod
@@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO):
return model
@staticmethod
- def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO):
@staticmethod
def set_dash_metadata( # pylint: disable=too-many-locals
dashboard: Dashboard,
- data: Dict[Any, Any],
- old_to_new_slice_ids: Optional[Dict[int, int]] = None,
+ data: dict[Any, Any],
+ old_to_new_slice_ids: Optional[dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
new_filter_scopes = {}
@@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO):
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
- slc_id_dict: Dict[int, int] = {}
+ slc_id_dict: dict[int, int] = {}
if old_to_new_slice_ids:
slc_id_dict = {
old: new
@@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
- def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
+ def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
@@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO):
@classmethod
def copy_dashboard(
- cls, original_dash: Dashboard, data: Dict[str, Any]
+ cls, original_dash: Dashboard, data: dict[str, Any]
) -> Dashboard:
dash = Dashboard()
dash.owners = [g.user] if g.user else []
@@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO):
dash.css = data.get("css")
metadata = json.loads(data["json_metadata"])
- old_to_new_slice_ids: Dict[int, int] = {}
+ old_to_new_slice_ids: dict[int, int] = {}
if data.get("duplicate_slices"):
# Duplicating slices as well, mapping old ids to new ones
for slc in original_dash.slices:
diff --git a/superset/dashboards/filter_sets/commands/create.py b/superset/dashboards/filter_sets/commands/create.py
index de1d70daf7..63c4534786 100644
--- a/superset/dashboards/filter_sets/commands/create.py
+++ b/superset/dashboards/filter_sets/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class CreateFilterSetCommand(BaseFilterSetCommand):
# pylint: disable=C0103
- def __init__(self, dashboard_id: int, data: Dict[str, Any]):
+ def __init__(self, dashboard_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._properties = data.copy()
diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py
index 07d59f93ae..722672d668 100644
--- a/superset/dashboards/filter_sets/commands/update.py
+++ b/superset/dashboards/filter_sets/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class UpdateFilterSetCommand(BaseFilterSetCommand):
- def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
+ def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._filter_set_id = filter_set_id
self._properties = data.copy()
diff --git a/superset/dashboards/filter_sets/dao.py b/superset/dashboards/filter_sets/dao.py
index 949aa6d3fd..5f2b0ba418 100644
--- a/superset/dashboards/filter_sets/dao.py
+++ b/superset/dashboards/filter_sets/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError
@@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO):
model_cls = FilterSet
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()
diff --git a/superset/dashboards/filter_sets/schemas.py b/superset/dashboards/filter_sets/schemas.py
index c1a13b424e..2309eea99f 100644
--- a/superset/dashboards/filter_sets/schemas.py
+++ b/superset/dashboards/filter_sets/schemas.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, cast, Dict, Mapping
+from collections.abc import Mapping
+from typing import Any, cast
from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow.validate import Length, OneOf
@@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema):
@post_load
def validate(
self, data: Mapping[Any, Any], *, many: Any, partial: Any
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
raise ValidationError("owner_id is mandatory when owner_type is User")
- return cast(Dict[str, Any], data)
+ return cast(dict[str, Any], data)
class FilterSetPutSchema(FilterSetSchema):
@@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema):
@post_load
def validate( # pylint: disable=unused-argument
self, data: Mapping[Any, Any], *, many: Any, partial: Any
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
if JSON_METADATA_FIELD in data:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
- return cast(Dict[str, Any], data)
+ return cast(dict[str, Any], data)
-def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None:
+def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None:
if first_field in data and second_field not in data:
raise ValidationError(
- "{} must be included alongside {}".format(first_field, second_field)
+ f"{first_field} must be included alongside {second_field}"
)
diff --git a/superset/dashboards/filter_state/api.py b/superset/dashboards/filter_state/api.py
index 7a771d6b54..a1b855ca9e 100644
--- a/superset/dashboards/filter_state/api.py
+++ b/superset/dashboards/filter_state/api.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Type
from flask import Response
from flask_appbuilder.api import expose, protect, safe
@@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi):
resource_name = "dashboard"
openapi_spec_tag = "Dashboard Filter State"
- def get_create_command(self) -> Type[CreateFilterStateCommand]:
+ def get_create_command(self) -> type[CreateFilterStateCommand]:
return CreateFilterStateCommand
- def get_update_command(self) -> Type[UpdateFilterStateCommand]:
+ def get_update_command(self) -> type[UpdateFilterStateCommand]:
return UpdateFilterStateCommand
- def get_get_command(self) -> Type[GetFilterStateCommand]:
+ def get_get_command(self) -> type[GetFilterStateCommand]:
return GetFilterStateCommand
- def get_delete_command(self) -> Type[DeleteFilterStateCommand]:
+ def get_delete_command(self) -> type[DeleteFilterStateCommand]:
return DeleteFilterStateCommand
@expose("//filter_state", methods=("POST",))
diff --git a/superset/dashboards/permalink/types.py b/superset/dashboards/permalink/types.py
index 91c5a9620c..4961d2a17b 100644
--- a/superset/dashboards/permalink/types.py
+++ b/superset/dashboards/permalink/types.py
@@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Optional, TypedDict
class DashboardPermalinkState(TypedDict):
- dataMask: Optional[Dict[str, Any]]
- activeTabs: Optional[List[str]]
+ dataMask: Optional[dict[str, Any]]
+ activeTabs: Optional[list[str]]
anchor: Optional[str]
- urlParams: Optional[List[Tuple[str, str]]]
+ urlParams: Optional[list[tuple[str, str]]]
class DashboardPermalinkValue(TypedDict):
diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py
index ab93e4130f..846ed39e82 100644
--- a/superset/dashboards/schemas.py
+++ b/superset/dashboards/schemas.py
@@ -16,7 +16,7 @@
# under the License.
import json
import re
-from typing import Any, Dict, Union
+from typing import Any, Union
from marshmallow import fields, post_load, pre_load, Schema
from marshmallow.validate import Length, ValidationError
@@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema):
@pre_load
def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use
self,
- data: Dict[str, Any],
+ data: dict[str, Any],
**kwargs: Any,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Remove ``show_native_filters`` from the JSON metadata.
@@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema):
class BaseDashboardSchema(Schema):
# pylint: disable=no-self-use,unused-argument
@post_load
- def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
if data.get("slug"):
data["slug"] = data["slug"].strip()
data["slug"] = data["slug"].replace(" ", "-")
diff --git a/superset/databases/api.py b/superset/databases/api.py
index 77f9596182..c214065a27 100644
--- a/superset/databases/api.py
+++ b/superset/databases/api.py
@@ -19,7 +19,7 @@ import json
import logging
from datetime import datetime
from io import BytesIO
-from typing import Any, cast, Dict, List, Optional
+from typing import Any, cast, Optional
from zipfile import is_zipfile, ZipFile
from flask import request, Response, send_file
@@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
- preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
+ preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", [])
available_databases = []
for engine_spec, drivers in get_available_engine_specs().items():
if not drivers:
continue
- payload: Dict[str, Any] = {
+ payload: dict[str, Any] = {
"name": engine_spec.engine_name,
"engine": engine_spec.engine,
"available_drivers": sorted(drivers),
diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py
index 16d27835b3..e3fd667130 100644
--- a/superset/databases/commands/create.py
+++ b/superset/databases/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
@@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"]
class CreateDatabaseCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand):
return database
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")
if not sqlalchemy_uri:
diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py
index e1f8fc2a25..889cb86c8f 100644
--- a/superset/databases/commands/export.py
+++ b/superset/databases/commands/export.py
@@ -18,7 +18,8 @@
import json
import logging
-from typing import Any, Dict, Iterator, Tuple
+from typing import Any
+from collections.abc import Iterator
import yaml
@@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
-def parse_extra(extra_payload: str) -> Dict[str, Any]:
+def parse_extra(extra_payload: str) -> dict[str, Any]:
try:
extra = json.loads(extra_payload)
except json.decoder.JSONDecodeError:
@@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Database, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
file_path = f"databases/{db_file_name}.yaml"
diff --git a/superset/databases/commands/importers/dispatcher.py b/superset/databases/commands/importers/dispatcher.py
index 88d38bf13b..70031b09e4 100644
--- a/superset/databases/commands/importers/dispatcher.py
+++ b/superset/databases/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py
index 239bd0977f..ba119beaaa 100644
--- a/superset/databases/commands/importers/v1/__init__.py
+++ b/superset/databases/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand):
dao = DatabaseDAO
model_name = "database"
prefix = "databases/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
@@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# first import databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)
diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py
index c0c0ee60d9..8881f78a9c 100644
--- a/superset/databases/commands/importers/v1/utils.py
+++ b/superset/databases/commands/importers/v1/utils.py
@@ -16,7 +16,7 @@
# under the License.
import json
-from typing import Any, Dict
+from typing import Any
from sqlalchemy.orm import Session
@@ -28,7 +28,7 @@ from superset.models.core import Database
def import_database(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Database:
diff --git a/superset/databases/commands/tables.py b/superset/databases/commands/tables.py
index 48e9227dea..b7dbb4d461 100644
--- a/superset/databases/commands/tables.py
+++ b/superset/databases/commands/tables.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, cast, Dict
+from typing import Any, cast
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
@@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand):
self._schema_name = schema_name
self._force = force
- def run(self) -> Dict[str, Any]:
+ def run(self) -> dict[str, Any]:
self.validate()
try:
tables = security_manager.get_datasources_accessible_by_user(
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index 9809641d5c..2680c5e8c1 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -17,7 +17,7 @@
import logging
import sqlite3
from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask import current_app as app
from flask_babel import gettext as _
@@ -64,7 +64,7 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None
diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py
index 746f7a8152..f12706fa1d 100644
--- a/superset/databases/commands/update.py
+++ b/superset/databases/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
@@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
raise DatabaseConnectionFailedError() from ex
# Update database schema permissions
- new_schemas: List[str] = []
+ new_schemas: list[str] = []
for schema in schemas:
old_view_menu_name = security_manager.get_schema_perm(
@@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand):
chart.schema_perm = new_view_menu_name
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index 2a624e32c7..d97ad33af9 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -16,7 +16,7 @@
# under the License.
import json
from contextlib import closing
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_babel import gettext as __
@@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand):
- def __init__(self, properties: Dict[str, Any]):
+ def __init__(self, properties: dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None
diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py
index 346d684a0d..40d88af745 100644
--- a/superset/databases/commands/validate_sql.py
+++ b/superset/databases/commands/validate_sql.py
@@ -16,7 +16,7 @@
# under the License.
import logging
import re
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
from flask import current_app
from flask_babel import gettext as __
@@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
class ValidateSQLCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
- self._validator: Optional[Type[BaseSQLValidator]] = None
+ self._validator: Optional[type[BaseSQLValidator]] = None
- def run(self) -> List[Dict[str, Any]]:
+ def run(self) -> list[dict[str, Any]]:
"""
Validates a SQL statement
@@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand):
if not validators_by_engine or spec.engine not in validators_by_engine:
raise NoValidatorConfigFoundError(
SupersetError(
- message=__(
- "no SQL validator is configured for {}".format(spec.engine)
- ),
+ message=__(f"no SQL validator is configured for {spec.engine}"),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),
diff --git a/superset/databases/dao.py b/superset/databases/dao.py
index c82f0db574..9ce3b5e73e 100644
--- a/superset/databases/dao.py
+++ b/superset/databases/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO):
def update(
cls,
model: Database,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> Database:
"""
@@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO):
)
@classmethod
- def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
+ def get_related_objects(cls, database_id: int) -> dict[str, Any]:
database: Any = cls.find_by_id(database_id)
datasets = database.tables
dataset_ids = [dataset.id for dataset in datasets]
diff --git a/superset/databases/filters.py b/superset/databases/filters.py
index 86564e8f15..2ca77b77d1 100644
--- a/superset/databases/filters.py
+++ b/superset/databases/filters.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Set
+from typing import Any
from flask import g
from flask_babel import lazy_gettext as _
@@ -30,7 +30,7 @@ from superset.views.base import BaseFilter
def can_access_databases(
view_menu_name: str,
-) -> Set[str]:
+) -> set[str]:
return {
security_manager.unpack_database_and_schema(vm).database
for vm in security_manager.user_view_menu_names(view_menu_name)
diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py
index 00e8c3ca53..01a00e8b80 100644
--- a/superset/databases/schemas.py
+++ b/superset/databases/schemas.py
@@ -19,7 +19,7 @@
import inspect
import json
-from typing import Any, Dict, List
+from typing import Any
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
@pre_load
def build_sqlalchemy_uri(
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Build SQLAlchemy URI from separate parameters.
@@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
def rename_encrypted_extra(
self: Schema,
- data: Dict[str, Any],
+ data: dict[str, Any],
**kwargs: Any,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Rename ``encrypted_extra`` to ``masked_encrypted_extra``.
@@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema):
class ImportV1DatabaseExtraSchema(Schema):
@pre_load
def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Fixes for ``schemas_allowed_for_csv_upload``.
"""
@@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema):
class ImportV1DatabaseSchema(Schema):
@pre_load
def fix_allow_csv_upload(
- self, data: Dict[str, Any], **kwargs: Any
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any
+ ) -> dict[str, Any]:
"""
Fix for ``allow_csv_upload`` .
"""
@@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema):
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema
- def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
+ def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None:
"""If sqlalchemy_uri has a masked password, password is required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
@@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema):
@validates_schema
def validate_ssh_tunnel_credentials(
- self, data: Dict[str, Any], **kwargs: Any
+ self, data: dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
@@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema):
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
- exception_messages: List[str] = []
+ exception_messages: list[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
@@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict):
pass
-def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore
+def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
ret = {}
if isinstance(field, EncryptedField):
if self.openapi_version.major > 2:
diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py
index 45e5af5f44..9c41b83392 100644
--- a/superset/databases/ssh_tunnel/commands/create.py
+++ b/superset/databases/ssh_tunnel/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
- def __init__(self, database_id: int, data: Dict[str, Any]):
+ def __init__(self, database_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
@@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand):
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py
index 42925d1caa..37fd4a94b9 100644
--- a/superset/databases/ssh_tunnel/commands/update.py
+++ b/superset/databases/ssh_tunnel/commands/update.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class UpdateSSHTunnelCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[SSHTunnel] = None
diff --git a/superset/databases/ssh_tunnel/dao.py b/superset/databases/ssh_tunnel/dao.py
index 89562fc05d..731f9183b3 100644
--- a/superset/databases/ssh_tunnel/dao.py
+++ b/superset/databases/ssh_tunnel/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from superset.dao.base import BaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
@@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO):
def update(
cls,
model: SSHTunnel,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SSHTunnel:
"""
diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py
index 3384679cb7..d9462a63db 100644
--- a/superset/databases/ssh_tunnel/models.py
+++ b/superset/databases/ssh_tunnel/models.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
import sqlalchemy as sa
from flask import current_app
@@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
]
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
output = {
"id": self.id,
"server_address": self.server_address,
diff --git a/superset/databases/utils.py b/superset/databases/utils.py
index 9229bb8cba..74943f4747 100644
--- a/superset/databases/utils.py
+++ b/superset/databases/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
from sqlalchemy.engine.url import make_url, URL
@@ -25,7 +25,7 @@ def get_foreign_keys_metadata(
database: Any,
table_name: str,
schema_name: Optional[str],
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
foreign_keys = database.get_foreign_keys(table_name, schema_name)
for fk in foreign_keys:
fk["column_names"] = fk.pop("constrained_columns")
@@ -35,14 +35,14 @@ def get_foreign_keys_metadata(
def get_indexes_metadata(
database: Any, table_name: str, schema_name: Optional[str]
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
indexes = database.get_indexes(table_name, schema_name)
for idx in indexes:
idx["type"] = "index"
return indexes
-def get_col_type(col: Dict[Any, Any]) -> str:
+def get_col_type(col: dict[Any, Any]) -> str:
try:
dtype = f"{col['type']}"
except Exception: # pylint: disable=broad-except
@@ -53,7 +53,7 @@ def get_col_type(col: Dict[Any, Any]) -> str:
def get_table_metadata(
database: Any, table_name: str, schema_name: Optional[str]
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""
Get table metadata information, including type, pk, fks.
This function raises SQLAlchemyError when a schema is not found.
@@ -73,7 +73,7 @@ def get_table_metadata(
foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name)
indexes = get_indexes_metadata(database, table_name, schema_name)
keys += foreign_keys + indexes
- payload_columns: List[Dict[str, Any]] = []
+ payload_columns: list[dict[str, Any]] = []
table_comment = database.get_table_comment(table_name, schema_name)
for col in columns:
dtype = get_col_type(col)
diff --git a/superset/dataframe.py b/superset/dataframe.py
index 8abeedc095..8083993294 100644
--- a/superset/dataframe.py
+++ b/superset/dataframe.py
@@ -17,7 +17,7 @@
""" Superset utilities for pandas.DataFrame.
"""
import logging
-from typing import Any, Dict, List
+from typing import Any
import pandas as pd
@@ -37,7 +37,7 @@ def _convert_big_integers(val: Any) -> Any:
return str(val) if isinstance(val, int) and abs(val) > JS_MAX_INTEGER else val
-def df_to_records(dframe: pd.DataFrame) -> List[Dict[str, Any]]:
+def df_to_records(dframe: pd.DataFrame) -> list[dict[str, Any]]:
"""
Convert a DataFrame to a set of records.
diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py
index 643ac784ec..fd13351809 100644
--- a/superset/datasets/commands/bulk_delete.py
+++ b/superset/datasets/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset import security_manager
from superset.commands.base import BaseCommand
@@ -34,9 +34,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteDatasetCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[SqlaTable]] = None
+ self._models: Optional[list[SqlaTable]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py
index 04f54339d0..1c864ad196 100644
--- a/superset/datasets/commands/create.py
+++ b/superset/datasets/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateDatasetCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@@ -55,12 +55,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand):
return dataset
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
database_id = self._properties["database"]
table_name = self._properties["table_name"]
schema = self._properties.get("schema", None)
sql = self._properties.get("sql", None)
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, schema, table_name):
diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py
index 5fc642cbe3..5a4a84fdf9 100644
--- a/superset/datasets/commands/duplicate.py
+++ b/superset/datasets/commands/duplicate.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from flask_appbuilder.models.sqla import Model
from flask_babel import gettext as __
@@ -43,7 +43,7 @@ logger = logging.getLogger(__name__)
class DuplicateDatasetCommand(CreateMixin, BaseCommand):
- def __init__(self, data: Dict[str, Any]) -> None:
+ def __init__(self, data: dict[str, Any]) -> None:
self._base_model: SqlaTable = SqlaTable()
self._properties = data.copy()
@@ -105,7 +105,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand):
return table
def validate(self) -> None:
- exceptions: List[ValidationError] = []
+ exceptions: list[ValidationError] = []
base_model_id = self._properties["base_model_id"]
duplicate_name = self._properties["table_name"]
diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py
index c6fe43c89d..8c02a23f29 100644
--- a/superset/datasets/commands/export.py
+++ b/superset/datasets/commands/export.py
@@ -18,7 +18,7 @@
import json
import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator
import yaml
@@ -43,7 +43,7 @@ class ExportDatasetsCommand(ExportModelsCommand):
@staticmethod
def _export(
model: SqlaTable, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(
model.database.database_name, model.database.id, skip_id=True
)
diff --git a/superset/datasets/commands/importers/dispatcher.py b/superset/datasets/commands/importers/dispatcher.py
index 74f1129d23..6be8635da2 100644
--- a/superset/datasets/commands/importers/dispatcher.py
+++ b/superset/datasets/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -43,7 +43,7 @@ class ImportDatasetsCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py
index f706ecf38b..c530be3c14 100644
--- a/superset/datasets/commands/importers/v0.py
+++ b/superset/datasets/commands/importers/v0.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
import yaml
from flask_appbuilder import Model
@@ -213,7 +213,7 @@ def import_simple_obj(
def import_from_dict(
- session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None
+ session: Session, data: dict[str, Any], sync: Optional[list[str]] = None
) -> None:
"""Imports databases from dictionary"""
if not sync:
@@ -238,12 +238,12 @@ class ImportDatasetsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
self,
- contents: Dict[str, str],
+ contents: dict[str, str],
*args: Any,
**kwargs: Any,
):
self.contents = contents
- self._configs: Dict[str, Any] = {}
+ self._configs: dict[str, Any] = {}
self.sync = []
if kwargs.get("sync_columns"):
diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py
index e73213319d..e753138ab8 100644
--- a/superset/datasets/commands/importers/v1/__init__.py
+++ b/superset/datasets/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Set
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -36,7 +36,7 @@ class ImportDatasetsCommand(ImportModelsCommand):
dao = DatasetDAO
model_name = "dataset"
prefix = "datasets/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
@@ -44,16 +44,16 @@ class ImportDatasetsCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with datasets
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index 52f46829b5..ae47fc411a 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -18,7 +18,7 @@ import gzip
import json
import logging
import re
-from typing import Any, Dict
+from typing import Any
from urllib import request
import pandas as pd
@@ -69,7 +69,7 @@ def get_sqla_type(native_type: str) -> VisitableType:
raise Exception(f"Unknown type: {native_type}")
-def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]:
+def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]:
return {
column.column_name: get_sqla_type(column.type)
for column in dataset.columns
@@ -101,7 +101,7 @@ def validate_data_uri(data_uri: str) -> None:
def import_dataset(
session: Session,
- config: Dict[str, Any],
+ config: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
ignore_permissions: bool = False,
diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py
index cc9f480a41..be9625709f 100644
--- a/superset/datasets/commands/update.py
+++ b/superset/datasets/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from collections import Counter
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
@@ -52,7 +52,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
def __init__(
self,
model_id: int,
- data: Dict[str, Any],
+ data: dict[str, Any],
override_columns: Optional[bool] = False,
):
self._model_id = model_id
@@ -76,8 +76,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
raise DatasetUpdateFailedError()
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/populate model exists
self._model = DatasetDAO.find_by_id(self._model_id)
if not self._model:
@@ -125,14 +125,14 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
raise DatasetInvalidError(exceptions=exceptions)
def _validate_columns(
- self, columns: List[Dict[str, Any]], exceptions: List[ValidationError]
+ self, columns: list[dict[str, Any]], exceptions: list[ValidationError]
) -> None:
# Validate duplicates on data
if self._get_duplicates(columns, "column_name"):
exceptions.append(DatasetColumnsDuplicateValidationError())
else:
# validate invalid id's
- columns_ids: List[int] = [
+ columns_ids: list[int] = [
column["id"] for column in columns if "id" in column
]
if not DatasetDAO.validate_columns_exist(self._model_id, columns_ids):
@@ -140,7 +140,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
# validate new column names uniqueness
if not self.override_columns:
- columns_names: List[str] = [
+ columns_names: list[str] = [
column["column_name"] for column in columns if "id" not in column
]
if not DatasetDAO.validate_columns_uniqueness(
@@ -149,26 +149,26 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand):
exceptions.append(DatasetColumnsExistsValidationError())
def _validate_metrics(
- self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError]
+ self, metrics: list[dict[str, Any]], exceptions: list[ValidationError]
) -> None:
if self._get_duplicates(metrics, "metric_name"):
exceptions.append(DatasetMetricsDuplicateValidationError())
else:
# validate invalid id's
- metrics_ids: List[int] = [
+ metrics_ids: list[int] = [
metric["id"] for metric in metrics if "id" in metric
]
if not DatasetDAO.validate_metrics_exist(self._model_id, metrics_ids):
exceptions.append(DatasetMetricsNotFoundValidationError())
# validate new metric names uniqueness
- metric_names: List[str] = [
+ metric_names: list[str] = [
metric["metric_name"] for metric in metrics if "id" not in metric
]
if not DatasetDAO.validate_metrics_uniqueness(self._model_id, metric_names):
exceptions.append(DatasetMetricsExistsValidationError())
@staticmethod
- def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]:
+ def _get_duplicates(data: list[dict[str, Any]], key: str) -> list[str]:
duplicates = [
name
for name, count in Counter([item[key] for item in data]).items()
diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py
index b158fce1fe..f4d46be109 100644
--- a/superset/datasets/dao.py
+++ b/superset/datasets/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -44,7 +44,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return None
@staticmethod
- def get_related_objects(database_id: int) -> Dict[str, Any]:
+ def get_related_objects(database_id: int) -> dict[str, Any]:
charts = (
db.session.query(Slice)
.filter(
@@ -108,7 +108,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return not db.session.query(dataset_query.exists()).scalar()
@staticmethod
- def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool:
+ def validate_columns_exist(dataset_id: int, columns_ids: list[int]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids)
@@ -117,7 +117,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(columns_ids) == len(dataset_query)
@staticmethod
- def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool:
+ def validate_columns_uniqueness(dataset_id: int, columns_names: list[str]) -> bool:
dataset_query = (
db.session.query(TableColumn.id).filter(
TableColumn.table_id == dataset_id,
@@ -127,7 +127,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(dataset_query) == 0
@staticmethod
- def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool:
+ def validate_metrics_exist(dataset_id: int, metrics_ids: list[int]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids)
@@ -136,7 +136,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return len(metrics_ids) == len(dataset_query)
@staticmethod
- def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool:
+ def validate_metrics_uniqueness(dataset_id: int, metrics_names: list[str]) -> bool:
dataset_query = (
db.session.query(SqlMetric.id).filter(
SqlMetric.table_id == dataset_id,
@@ -149,7 +149,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update(
cls,
model: SqlaTable,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> Optional[SqlaTable]:
"""
@@ -173,7 +173,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_columns(
cls,
model: SqlaTable,
- property_columns: List[Dict[str, Any]],
+ property_columns: list[dict[str, Any]],
commit: bool = True,
override_columns: bool = False,
) -> None:
@@ -239,7 +239,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_metrics(
cls,
model: SqlaTable,
- property_metrics: List[Dict[str, Any]],
+ property_metrics: list[dict[str, Any]],
commit: bool = True,
) -> None:
"""
@@ -304,14 +304,14 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_column(
cls,
model: TableColumn,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> TableColumn:
return DatasetColumnDAO.update(model, properties, commit=commit)
@classmethod
def create_column(
- cls, properties: Dict[str, Any], commit: bool = True
+ cls, properties: dict[str, Any], commit: bool = True
) -> TableColumn:
"""
Creates a Dataset model on the metadata DB
@@ -346,7 +346,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
def update_metric(
cls,
model: SqlMetric,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SqlMetric:
return DatasetMetricDAO.update(model, properties, commit=commit)
@@ -354,7 +354,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
@classmethod
def create_metric(
cls,
- properties: Dict[str, Any],
+ properties: dict[str, Any],
commit: bool = True,
) -> SqlMetric:
"""
@@ -363,7 +363,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
return DatasetMetricDAO.create(properties, commit=commit)
@staticmethod
- def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
diff --git a/superset/datasets/models.py b/superset/datasets/models.py
index b433709f2c..50aeea7b51 100644
--- a/superset/datasets/models.py
+++ b/superset/datasets/models.py
@@ -24,7 +24,6 @@ dataset, new models for columns, metrics, and tables were also introduced.
These models are not fully implemented, and shouldn't be used yet.
"""
-from typing import List
import sqlalchemy as sa
from flask_appbuilder import Model
@@ -87,7 +86,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
# The relationship between datasets and columns is 1:n, but we use a
# many-to-many association table to avoid adding two mutually exclusive
# columns(dataset_id and table_id) to Column
- columns: List[Column] = relationship(
+ columns: list[Column] = relationship(
"Column",
secondary=dataset_column_association_table,
cascade="all, delete-orphan",
@@ -97,7 +96,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
owners = relationship(
security_manager.user_model, secondary=dataset_user_association_table
)
- tables: List[Table] = relationship(
+ tables: list[Table] = relationship(
"Table", secondary=dataset_table_association_table, backref="datasets"
)
diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py
index f248fc70ff..9a2af98066 100644
--- a/superset/datasets/schemas.py
+++ b/superset/datasets/schemas.py
@@ -16,7 +16,7 @@
# under the License.
import json
import re
-from typing import Any, Dict
+from typing import Any
from flask_babel import lazy_gettext as _
from marshmallow import fields, pre_load, Schema, ValidationError
@@ -150,7 +150,7 @@ class DatasetRelatedObjectsResponse(Schema):
class ImportV1ColumnSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
@@ -176,7 +176,7 @@ class ImportV1ColumnSchema(Schema):
class ImportV1MetricSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
@@ -198,7 +198,7 @@ class ImportV1MetricSchema(Schema):
class ImportV1DatasetSchema(Schema):
# pylint: disable=no-self-use, unused-argument
@pre_load
- def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
"""
Fix for extra initially being exported as a string.
"""
diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py
index 158a32c7fd..4682f070e2 100644
--- a/superset/datasource/dao.py
+++ b/superset/datasource/dao.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Dict, Type, Union
+from typing import Union
from sqlalchemy.orm import Session
@@ -34,7 +34,7 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery]
class DatasourceDAO(BaseDAO):
- sources: Dict[Union[DatasourceType, str], Type[Datasource]] = {
+ sources: dict[Union[DatasourceType, str], type[Datasource]] = {
DatasourceType.TABLE: SqlaTable,
DatasourceType.QUERY: Query,
DatasourceType.SAVEDQUERY: SavedQuery,
diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py
index f19dffd4a3..20cdfcc51f 100644
--- a/superset/db_engine_specs/__init__.py
+++ b/superset/db_engine_specs/__init__.py
@@ -33,7 +33,7 @@ import pkgutil
from collections import defaultdict
from importlib import import_module
from pathlib import Path
-from typing import Any, Dict, List, Optional, Set, Type
+from typing import Any, Optional
import sqlalchemy.databases
import sqlalchemy.dialects
@@ -58,11 +58,11 @@ def is_engine_spec(obj: Any) -> bool:
)
-def load_engine_specs() -> List[Type[BaseEngineSpec]]:
+def load_engine_specs() -> list[type[BaseEngineSpec]]:
"""
Load all engine specs, native and 3rd party.
"""
- engine_specs: List[Type[BaseEngineSpec]] = []
+ engine_specs: list[type[BaseEngineSpec]] = []
# load standard engines
db_engine_spec_dir = str(Path(__file__).parent)
@@ -85,7 +85,7 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]:
return engine_specs
-def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]:
+def get_engine_spec(backend: str, driver: Optional[str] = None) -> type[BaseEngineSpec]:
"""
Return the DB engine spec associated with a given SQLAlchemy URL.
@@ -120,11 +120,11 @@ backend_replacements = {
}
-def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]:
+def get_available_engine_specs() -> dict[type[BaseEngineSpec], set[str]]:
"""
Return available engine specs and installed drivers for them.
"""
- drivers: Dict[str, Set[str]] = defaultdict(set)
+ drivers: dict[str, set[str]] = defaultdict(set)
# native SQLAlchemy dialects
for attr in sqlalchemy.databases.__all__:
diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py
index 047952402d..ad6bed113d 100644
--- a/superset/db_engine_specs/athena.py
+++ b/superset/db_engine_specs/athena.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from flask_babel import gettext as __
from sqlalchemy import types
@@ -51,7 +52,7 @@ class AthenaEngineSpec(BaseEngineSpec):
date_add('day', 1, CAST({col} AS TIMESTAMP))))",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
SYNTAX_ERROR_REGEX: (
__(
"Please check your query for syntax errors at or "
@@ -64,7 +65,7 @@ class AthenaEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index a7ff862272..ef922a5e63 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -22,22 +22,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import (
- Any,
- Callable,
- ContextManager,
- Dict,
- List,
- Match,
- NamedTuple,
- Optional,
- Pattern,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from re import Match, Pattern
+from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union
import pandas as pd
import sqlparse
@@ -77,7 +63,7 @@ if TYPE_CHECKING:
from superset.models.core import Database
from superset.models.sql_lab import Query
-ColumnTypeMapping = Tuple[
+ColumnTypeMapping = tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
@@ -90,10 +76,10 @@ class TimeGrain(NamedTuple):
name: str # TODO: redundant field, remove
label: str
function: str
- duration: Optional[str]
+ duration: str | None
-builtin_time_grains: Dict[Optional[str], str] = {
+builtin_time_grains: dict[str | None, str] = {
"PT1S": __("Second"),
"PT5S": __("5 second"),
"PT30S": __("30 second"),
@@ -160,12 +146,12 @@ class MetricType(TypedDict, total=False):
metric_name: str
expression: str
- verbose_name: Optional[str]
- metric_type: Optional[str]
- description: Optional[str]
- d3format: Optional[str]
- warning_text: Optional[str]
- extra: Optional[str]
+ verbose_name: str | None
+ metric_type: str | None
+ description: str | None
+ d3format: str | None
+ warning_text: str | None
+ extra: str | None
class BaseEngineSpec: # pylint: disable=too-many-public-methods
@@ -182,19 +168,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
having to add the same aggregation in SELECT.
"""
- engine_name: Optional[str] = None # for user messages, overridden in child classes
+ engine_name: str | None = None # for user messages, overridden in child classes
# These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers;
# see the ``supports_url`` and ``supports_backend`` methods below.
engine = "base" # str as defined in sqlalchemy.engine.engine
- engine_aliases: Set[str] = set()
- drivers: Dict[str, str] = {}
- default_driver: Optional[str] = None
+ engine_aliases: set[str] = set()
+ drivers: dict[str, str] = {}
+ default_driver: str | None = None
disable_ssh_tunneling = False
- _date_trunc_functions: Dict[str, str] = {}
- _time_grain_expressions: Dict[Optional[str], str] = {}
- _default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = (
+ _date_trunc_functions: dict[str, str] = {}
+ _time_grain_expressions: dict[str | None, str] = {}
+ _default_column_type_mappings: tuple[ColumnTypeMapping, ...] = (
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
@@ -312,7 +298,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
),
)
# engine-specific type mappings to check prior to the defaults
- column_type_mappings: Tuple[ColumnTypeMapping, ...] = ()
+ column_type_mappings: tuple[ColumnTypeMapping, ...] = ()
# Does database support join-free timeslot grouping
time_groupby_inline = False
@@ -351,23 +337,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
allow_limit_clause = True
# This set will give keywords for select statements
# to consider for the engines with TOP SQL parsing
- select_keywords: Set[str] = {"SELECT"}
+ select_keywords: set[str] = {"SELECT"}
# This set will give the keywords for data limit statements
# to consider for the engines with TOP SQL parsing
- top_keywords: Set[str] = {"TOP"}
+ top_keywords: set[str] = {"TOP"}
# A set of disallowed connection query parameters by driver name
- disallow_uri_query_params: Dict[str, Set[str]] = {}
+ disallow_uri_query_params: dict[str, set[str]] = {}
# A Dict of query parameters that will always be used on every connection
# by driver name
- enforce_uri_query_params: Dict[str, Dict[str, Any]] = {}
+ enforce_uri_query_params: dict[str, dict[str, Any]] = {}
force_column_alias_quotes = False
arraysize = 0
max_column_name_length = 0
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
- custom_errors: Dict[
- Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]
+ custom_errors: dict[
+ Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]
] = {}
# Whether the engine supports file uploads
@@ -422,7 +408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.supports_backend(backend, driver)
@classmethod
- def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool:
+ def supports_backend(cls, backend: str, driver: str | None = None) -> bool:
"""
Returns true if the DB engine spec supports a given SQLAlchemy backend/driver.
"""
@@ -439,7 +425,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return driver in cls.drivers
@classmethod
- def get_default_schema(cls, database: Database) -> Optional[str]:
+ def get_default_schema(cls, database: Database) -> str | None:
"""
Return the default schema in a given database.
"""
@@ -450,8 +436,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_schema_from_engine_params( # pylint: disable=unused-argument
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the schema configured in a SQLALchemy URI and connection argments, if any.
"""
@@ -462,7 +448,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
query: Query,
- ) -> Optional[str]:
+ ) -> str | None:
"""
Return the default schema for a given query.
@@ -501,7 +487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.get_default_schema(database)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
"""
Each engine can implement and converge its own specific exceptions into
Superset DBAPI exceptions
@@ -541,7 +527,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_allow_cost_estimate( # pylint: disable=unused-argument
cls,
- extra: Dict[str, Any],
+ extra: dict[str, Any],
) -> bool:
return False
@@ -561,8 +547,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_engine(
cls,
database: Database,
- schema: Optional[str] = None,
- source: Optional[utils.QuerySource] = None,
+ schema: str | None = None,
+ source: utils.QuerySource | None = None,
) -> ContextManager[Engine]:
"""
Return an engine context manager.
@@ -578,8 +564,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_timestamp_expr(
cls,
col: ColumnClause,
- pdf: Optional[str],
- time_grain: Optional[str],
+ pdf: str | None,
+ time_grain: str | None,
) -> TimestampExpression:
"""
Construct a TimestampExpression to be used in a SQLAlchemy query.
@@ -616,7 +602,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return TimestampExpression(time_expr, col, type_=col.type)
@classmethod
- def get_time_grains(cls) -> Tuple[TimeGrain, ...]:
+ def get_time_grains(cls) -> tuple[TimeGrain, ...]:
"""
Generate a tuple of supported time grains.
@@ -634,8 +620,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def _sort_time_grains(
- cls, val: Tuple[Optional[str], str], index: int
- ) -> Union[float, int, str]:
+ cls, val: tuple[str | None, str], index: int
+ ) -> float | int | str:
"""
Return an ordered time-based value of a portion of a time grain
for sorting
@@ -695,7 +681,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return plist.get(index, 0)
@classmethod
- def get_time_grain_expressions(cls) -> Dict[Optional[str], str]:
+ def get_time_grain_expressions(cls) -> dict[str | None, str]:
"""
Return a dict of all supported time grains including any potential added grains
but excluding any potentially disabled grains in the config file.
@@ -706,7 +692,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
time_grain_expressions = cls._time_grain_expressions.copy()
grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"]
time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {}))
- denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"]
+ denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"]
for key in denylist:
time_grain_expressions.pop(key, None)
@@ -723,9 +709,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
)
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
"""
:param cursor: Cursor instance
@@ -743,9 +727,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def expand_data(
- cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
- ) -> Tuple[
- List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
+ cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
+ ) -> tuple[
+ list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
]:
"""
Some engines support expanding nested fields. See implementation in Presto
@@ -759,7 +743,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return columns, data, []
@classmethod
- def alter_new_orm_column(cls, orm_col: "TableColumn") -> None:
+ def alter_new_orm_column(cls, orm_col: TableColumn) -> None:
"""Allow altering default column attributes when first detected/added
For instance special column like `__time` for Druid can be
@@ -789,7 +773,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return cls.epoch_to_dttm().replace("{col}", "({col}/1000)")
@classmethod
- def get_datatype(cls, type_code: Any) -> Optional[str]:
+ def get_datatype(cls, type_code: Any) -> str | None:
"""
Change column type code from cursor description to string representation.
@@ -802,7 +786,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
@deprecated(deprecated_in="3.0")
- def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@@ -818,8 +802,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
table_name: str,
- schema_name: Optional[str],
- ) -> Dict[str, Any]:
+ schema_name: str | None,
+ ) -> dict[str, Any]:
"""
Returns engine-specific table metadata
@@ -872,7 +856,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
sql_remainder = None
sql = sql.strip(" \t\n;")
sql_statement = sqlparse.format(sql, strip_comments=True)
- query_limit: Optional[int] = sql_parse.extract_top_from_query(
+ query_limit: int | None = sql_parse.extract_top_from_query(
sql_statement, cls.top_keywords
)
if not limit:
@@ -928,7 +912,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return True
@classmethod
- def get_limit_from_sql(cls, sql: str) -> Optional[int]:
+ def get_limit_from_sql(cls, sql: str) -> int | None:
"""
Extract limit from SQL query
@@ -951,7 +935,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return parsed_query.set_or_update_query_limit(limit)
@classmethod
- def get_cte_query(cls, sql: str) -> Optional[str]:
+ def get_cte_query(cls, sql: str) -> str | None:
"""
Convert the input CTE based SQL to the SQL for virtual table conversion
@@ -981,7 +965,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -1012,8 +996,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def convert_dttm( # pylint: disable=unused-argument
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
"""
Convert a Python `datetime` object to a SQL expression.
@@ -1044,8 +1028,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def extract_errors(
- cls, ex: Exception, context: Optional[Dict[str, Any]] = None
- ) -> List[SupersetError]:
+ cls, ex: Exception, context: dict[str, Any] | None = None
+ ) -> list[SupersetError]:
raw_message = cls._extract_error_message(ex)
context = context or {}
@@ -1076,10 +1060,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def adjust_engine_params( # pylint: disable=unused-argument
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
"""
Return a new URL and ``connect_args`` for a specific catalog/schema.
@@ -1116,7 +1100,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs from database.
@@ -1126,7 +1110,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> List[str]:
+ def get_schema_names(cls, inspector: Inspector) -> list[str]:
"""
Get all schemas from database
@@ -1140,8 +1124,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the real table names within the specified schema.
@@ -1168,8 +1152,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -1197,8 +1181,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database, # pylint: disable=unused-argument
inspector: Inspector,
table_name: str,
- schema: Optional[str],
- ) -> List[Dict[str, Any]]:
+ schema: str | None,
+ ) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -1213,8 +1197,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_table_comment(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> Optional[str]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> str | None:
"""
Get comment of table from a given schema and table
@@ -1237,8 +1221,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Get all columns from a given schema and table
@@ -1255,8 +1239,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
inspector: Inspector,
table_name: str,
- schema: Optional[str],
- ) -> List[MetricType]:
+ schema: str | None,
+ ) -> list[MetricType]:
"""
Get all metrics from a given schema and table.
"""
@@ -1273,11 +1257,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
"""
Add a where clause to a query to reference only the most recent partition
@@ -1293,7 +1277,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
return [column(c["name"]) for c in cols]
@classmethod
@@ -1302,12 +1286,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
"""
Generate a "SELECT * from [schema.]table_name" query with appropriate limit.
@@ -1326,7 +1310,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:return: SQL query
"""
# pylint: disable=redefined-outer-name
- fields: Union[str, List[Any]] = "*"
+ fields: str | list[Any] = "*"
cols = cols or []
if (show_cols or latest_partition) and not cols:
cols = database.get_columns(table_name, schema)
@@ -1355,7 +1339,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return sql
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
"""
Generate a SQL query that estimates the cost of a given statement.
@@ -1367,8 +1351,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
"""
Format cost estimate.
@@ -1405,8 +1389,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
database: Database,
schema: str,
sql: str,
- source: Optional[utils.QuerySource] = None,
- ) -> List[Dict[str, Any]]:
+ source: utils.QuerySource | None = None,
+ ) -> list[dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
@@ -1433,7 +1417,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: Optional[str]
+ cls, url: URL, impersonate_user: bool, username: str | None
) -> URL:
"""
Return a modified URL with the username set.
@@ -1450,9 +1434,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -1490,7 +1474,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
raise cls.get_dbapi_mapped_exception(ex) from ex
@classmethod
- def make_label_compatible(cls, label: str) -> Union[str, quoted_name]:
+ def make_label_compatible(cls, label: str) -> str | quoted_name:
"""
Conditionally mutate and/or quote a sqlalchemy expression label. If
force_column_alias_quotes is set to True, return the label as a
@@ -1515,8 +1499,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_column_types(
cls,
- column_type: Optional[str],
- ) -> Optional[Tuple[TypeEngine, GenericDataType]]:
+ column_type: str | None,
+ ) -> tuple[TypeEngine, GenericDataType] | None:
"""
Return a sqlalchemy native column type and generic data type that corresponds
to the column type defined in the data source (return None to use default type
@@ -1598,7 +1582,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
def get_function_names( # pylint: disable=unused-argument
cls,
database: Database,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -1609,7 +1593,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return []
@staticmethod
- def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]:
+ def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]:
"""
Convert pyodbc.Row objects from `fetch_data` to tuples.
@@ -1634,7 +1618,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return None
@staticmethod
- def get_extra_params(database: Database) -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
@@ -1642,7 +1626,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
- extra: Dict[str, Any] = {}
+ extra: dict[str, Any] = {}
if database.extra:
try:
extra = json.loads(database.extra)
@@ -1653,7 +1637,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@staticmethod
def update_params_from_encrypted_extra( # pylint: disable=invalid-name
- database: Database, params: Dict[str, Any]
+ database: Database, params: dict[str, Any]
) -> None:
"""
Some databases require some sensitive information which do not conform to
@@ -1691,10 +1675,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_column_spec( # pylint: disable=unused-argument
cls,
- native_type: Optional[str],
- db_extra: Optional[Dict[str, Any]] = None,
+ native_type: str | None,
+ db_extra: dict[str, Any] | None = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
- ) -> Optional[ColumnSpec]:
+ ) -> ColumnSpec | None:
"""
Get generic type related specs regarding a native column type.
@@ -1714,10 +1698,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
@classmethod
def get_sqla_column_type(
cls,
- native_type: Optional[str],
- db_extra: Optional[Dict[str, Any]] = None,
+ native_type: str | None,
+ db_extra: dict[str, Any] | None = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
- ) -> Optional[TypeEngine]:
+ ) -> TypeEngine | None:
"""
Converts native database type to sqlalchemy column type.
@@ -1761,7 +1745,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
cursor: Any,
query: Query,
- ) -> Optional[str]:
+ ) -> str | None:
"""
Select identifiers from the database engine that uniquely identifies the
queries to cancel. The identifier is typically a session id, process id
@@ -1794,11 +1778,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return False
@classmethod
- def parse_sql(cls, sql: str) -> List[str]:
+ def parse_sql(cls, sql: str) -> list[str]:
return [str(s).strip(" ;") for s in sqlparse.parse(sql)]
@classmethod
- def get_impersonation_key(cls, user: Optional[User]) -> Any:
+ def get_impersonation_key(cls, user: User | None) -> Any:
"""
Construct an impersonation key, by default it's the given username.
@@ -1809,7 +1793,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return user.username if user else None
@classmethod
- def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]:
+ def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None:
"""
Mask ``encrypted_extra``.
@@ -1822,9 +1806,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# pylint: disable=unused-argument
@classmethod
- def unmask_encrypted_extra(
- cls, old: Optional[str], new: Optional[str]
- ) -> Optional[str]:
+ def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None:
"""
Remove masks from ``encrypted_extra``.
@@ -1835,7 +1817,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return new
@classmethod
- def get_public_information(cls) -> Dict[str, Any]:
+ def get_public_information(cls) -> dict[str, Any]:
"""
Construct a Dict with properties we want to expose.
@@ -1891,12 +1873,12 @@ class BasicParametersSchema(Schema):
class BasicParametersType(TypedDict, total=False):
- username: Optional[str]
- password: Optional[str]
+ username: str | None
+ password: str | None
host: str
port: int
database: str
- query: Dict[str, Any]
+ query: dict[str, Any]
encryption: bool
@@ -1929,13 +1911,13 @@ class BasicParametersMixin:
# query parameter to enable encryption in the database connection
# for Postgres this would be `{"sslmode": "verify-ca"}`, eg.
- encryption_parameters: Dict[str, str] = {}
+ encryption_parameters: dict[str, str] = {}
@classmethod
def build_sqlalchemy_uri( # pylint: disable=unused-argument
cls,
parameters: BasicParametersType,
- encrypted_extra: Optional[Dict[str, str]] = None,
+ encrypted_extra: dict[str, str] | None = None,
) -> str:
# make a copy so that we don't update the original
query = parameters.get("query", {}).copy()
@@ -1958,7 +1940,7 @@ class BasicParametersMixin:
@classmethod
def get_parameters_from_uri( # pylint: disable=unused-argument
- cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
+ cls, uri: str, encrypted_extra: dict[str, Any] | None = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = {
@@ -1982,14 +1964,14 @@ class BasicParametersMixin:
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
"""
Validates any number of parameters, for progressive validation.
If only the hostname is present it will check if the name is resolvable. As more
parameters are present in the request, more validation is done.
"""
- errors: List[SupersetError] = []
+ errors: list[SupersetError] = []
required = {"host", "port", "username", "database"}
parameters = properties.get("parameters", {})
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index 1f5068ad04..3b62f4bbb8 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -18,7 +18,8 @@ import json
import re
import urllib
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
import pandas as pd
from apispec import APISpec
@@ -99,8 +100,8 @@ class BigQueryParametersSchema(Schema):
class BigQueryParametersType(TypedDict):
- credentials_info: Dict[str, Any]
- query: Dict[str, Any]
+ credentials_info: dict[str, Any]
+ query: dict[str, Any]
class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods
@@ -173,7 +174,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
"P1Y": "{func}({col}, YEAR)",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_DATABASE_PERMISSIONS_REGEX: (
__(
"Unable to connect. Verify that the following roles are set "
@@ -219,7 +220,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -235,7 +236,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Support type BigQuery Row, introduced here PR #4071
# google.cloud.bigquery.table.Row
@@ -280,7 +281,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
@deprecated(deprecated_in="3.0")
- def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""
Normalizes indexes for more consistency across db engines
@@ -305,7 +306,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
inspector: Inspector,
table_name: str,
schema: Optional[str],
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""
Get the indexes associated with the specified schema/table.
@@ -321,7 +322,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def extra_table_metadata(
cls, database: "Database", table_name: str, schema_name: Optional[str]
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
indexes = database.get_indexes(table_name, schema_name)
if not indexes:
return {}
@@ -354,7 +355,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
database: "Database",
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -421,7 +422,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
schema: str,
sql: str,
source: Optional[utils.QuerySource] = None,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
"""
Estimate the cost of a multiple statement SQL query.
@@ -448,7 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs.
@@ -462,11 +463,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return sorted(project.project_id for project in projects)
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
with cls.get_engine(cursor) as engine:
client = cls._get_client(engine)
job_config = bigquery.QueryJobConfig(dry_run=True)
@@ -503,15 +504,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]
@classmethod
def build_sqlalchemy_uri(
cls,
parameters: BigQueryParametersType,
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> str:
query = parameters.get("query", {})
query_params = urllib.parse.urlencode(query)
@@ -533,7 +534,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def get_parameters_from_uri(
cls,
uri: str,
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> Any:
value = make_url_safe(uri)
@@ -592,7 +593,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
return json.dumps(new_config)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
from google.auth.exceptions import DefaultCredentialsError
@@ -602,7 +603,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
def validate_parameters(
cls,
properties: BasicPropertiesType, # pylint: disable=unused-argument
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
return []
@classmethod
@@ -636,7 +637,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: Optional[list[dict[str, Any]]] = None,
) -> str:
"""
Remove array structures from `SELECT *`.
@@ -699,7 +700,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met
)
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]:
"""
Label columns using their fully qualified name.
diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py
index a62087bc6a..af38c15e0b 100644
--- a/superset/db_engine_specs/clickhouse.py
+++ b/superset/db_engine_specs/clickhouse.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
import re
from datetime import datetime
-from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, cast, TYPE_CHECKING
from flask import current_app
from flask_babel import gettext as __
@@ -124,8 +124,8 @@ class ClickHouseBaseEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -145,7 +145,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
supports_file_upload = False
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
return {NewConnectionError: SupersetDBAPIDatabaseError}
@classmethod
@@ -159,7 +159,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec):
@classmethod
@cache_manager.cache.memoize()
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -256,7 +256,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
engine_name = "ClickHouse Connect (Superset)"
default_driver = "connect"
- _function_names: List[str] = []
+ _function_names: list[str] = []
sqlalchemy_uri_placeholder = (
"clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]"
@@ -265,7 +265,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
encryption_parameters = {"secure": "true"}
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
return {}
@classmethod
@@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
return new_exception(str(exception))
@classmethod
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver.exceptions import ClickHouseError
@@ -304,7 +304,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
def build_sqlalchemy_uri(
cls,
parameters: BasicParametersType,
- encrypted_extra: Optional[Dict[str, str]] = None,
+ encrypted_extra: dict[str, str] | None = None,
) -> str:
url_params = parameters.copy()
if url_params.get("encryption"):
@@ -318,7 +318,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
@classmethod
def get_parameters_from_uri(
- cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None
+ cls, uri: str, encrypted_extra: dict[str, Any] | None = None
) -> BasicParametersType:
url = make_url_safe(uri)
query = url.query
@@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin):
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
+ ) -> list[SupersetError]:
# pylint: disable=import-outside-toplevel,import-error
from clickhouse_connect.driver import default_port
diff --git a/superset/db_engine_specs/crate.py b/superset/db_engine_specs/crate.py
index 6eafae829e..d8d91c6796 100644
--- a/superset/db_engine_specs/crate.py
+++ b/superset/db_engine_specs/crate.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from sqlalchemy import types
@@ -53,8 +53,8 @@ class CrateEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.TIMESTAMP):
diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py
index 5f12f3174d..5df24be65d 100644
--- a/superset/db_engine_specs/databricks.py
+++ b/superset/db_engine_specs/databricks.py
@@ -17,7 +17,7 @@
import json
from datetime import datetime
-from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@@ -135,7 +135,7 @@ class DatabricksODBCEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra)
@@ -160,14 +160,14 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
encryption_parameters = {"ssl": "1"}
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
Trim whitespace from connect_args to avoid databricks driver errors
"""
- extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
- engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
- connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
+ connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)])
connect_args.setdefault("_user_agent_entry", USER_AGENT)
@@ -184,7 +184,7 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
database: "Database",
inspector: Inspector,
schema: Optional[str],
- ) -> Set[str]:
+ ) -> set[str]:
return super().get_table_names(
database, inspector, schema
) - cls.get_view_names(database, inspector, schema)
@@ -213,8 +213,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
@classmethod
def extract_errors(
- cls, ex: Exception, context: Optional[Dict[str, Any]] = None
- ) -> List[SupersetError]:
+ cls, ex: Exception, context: Optional[dict[str, Any]] = None
+ ) -> list[SupersetError]:
raw_message = cls._extract_error_message(ex)
context = context or {}
@@ -271,8 +271,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin)
def validate_parameters( # type: ignore
cls,
properties: DatabricksPropertiesType,
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
required = {"access_token", "host", "port", "database", "extra"}
extra = json.loads(properties.get("extra", "{}"))
engine_params = extra.get("engine_params", {})
diff --git a/superset/db_engine_specs/dremio.py b/superset/db_engine_specs/dremio.py
index 7fae3014d6..7b4c0458cd 100644
--- a/superset/db_engine_specs/dremio.py
+++ b/superset/db_engine_specs/dremio.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -46,7 +46,7 @@ class DremioEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py
index 16ac89212a..946544863d 100644
--- a/superset/db_engine_specs/drill.py
+++ b/superset/db_engine_specs/drill.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Optional
from urllib import parse
from sqlalchemy import types
@@ -59,7 +59,7 @@ class DrillEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -74,10 +74,10 @@ class DrillEngineSpec(BaseEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe=""))
@@ -87,7 +87,7 @@ class DrillEngineSpec(BaseEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py
index 83829ec22a..43ce310a40 100644
--- a/superset/db_engine_specs/druid.py
+++ b/superset/db_engine_specs/druid.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import json
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from sqlalchemy import types
from sqlalchemy.engine.reflection import Inspector
@@ -79,7 +79,7 @@ class DruidEngineSpec(BaseEngineSpec):
orm_col.is_dttm = True
@staticmethod
- def get_extra_params(database: Database) -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
For Druid, the path to a SSL certificate is placed in `connect_args`.
@@ -104,8 +104,8 @@ class DruidEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -130,15 +130,15 @@ class DruidEngineSpec(BaseEngineSpec):
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Update the Druid type map.
"""
return super().get_columns(inspector, table_name, schema)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
from requests import exceptions as requests_exceptions
diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py
index 1248287b84..3bbf9ecc38 100644
--- a/superset/db_engine_specs/duckdb.py
+++ b/superset/db_engine_specs/duckdb.py
@@ -18,7 +18,8 @@ from __future__ import annotations
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
@@ -51,7 +52,7 @@ class DuckDBEngineSpec(BaseEngineSpec):
"P1Y": "DATE_TRUNC('year', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__('We can\'t seem to resolve the column "%(column_name)s"'),
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
@@ -65,8 +66,8 @@ class DuckDBEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
@@ -75,6 +76,6 @@ class DuckDBEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
- cls, database: Database, inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ cls, database: Database, inspector: Inspector, schema: str | None
+ ) -> set[str]:
return set(inspector.get_table_names(schema))
diff --git a/superset/db_engine_specs/dynamodb.py b/superset/db_engine_specs/dynamodb.py
index c398a9c1df..5f7a9e2b71 100644
--- a/superset/db_engine_specs/dynamodb.py
+++ b/superset/db_engine_specs/dynamodb.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -55,7 +55,7 @@ class DynamoDBEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/elasticsearch.py b/superset/db_engine_specs/elasticsearch.py
index 934aa0bb03..d717c52bf5 100644
--- a/superset/db_engine_specs/elasticsearch.py
+++ b/superset/db_engine_specs/elasticsearch.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
from packaging.version import Version
from sqlalchemy import types
@@ -50,10 +50,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
"P1Y": "HISTOGRAM({col}, INTERVAL 1 YEAR)",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-error,import-outside-toplevel
import es.exceptions as es_exceptions
@@ -65,7 +65,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
db_extra = db_extra or {}
@@ -117,7 +117,7 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py
index c06fbd826d..6da56e2fee 100644
--- a/superset/db_engine_specs/exasol.py
+++ b/superset/db_engine_specs/exasol.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
from superset.db_engine_specs.base import BaseEngineSpec
@@ -42,7 +42,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py
index 306a642dc3..4448074157 100644
--- a/superset/db_engine_specs/firebird.py
+++ b/superset/db_engine_specs/firebird.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -72,7 +72,7 @@ class FirebirdEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/firebolt.py b/superset/db_engine_specs/firebolt.py
index 65cd714352..ace3d6b3b2 100644
--- a/superset/db_engine_specs/firebolt.py
+++ b/superset/db_engine_specs/firebolt.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -43,7 +43,7 @@ class FireboltEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index 73a66c464f..abf5bac48f 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -16,7 +16,8 @@
# under the License.
import json
import re
-from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from apispec import APISpec
from apispec.ext.marshmallow import MarshmallowPlugin
@@ -56,12 +57,12 @@ class GSheetsParametersSchema(Schema):
class GSheetsParametersType(TypedDict):
service_account_info: str
- catalog: Optional[Dict[str, str]]
+ catalog: Optional[dict[str, str]]
class GSheetsPropertiesType(TypedDict):
parameters: GSheetsParametersType
- catalog: Dict[str, str]
+ catalog: dict[str, str]
class GSheetsEngineSpec(SqliteEngineSpec):
@@ -77,7 +78,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
default_driver = "apsw"
sqlalchemy_uri_placeholder = "gsheets://"
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
SYNTAX_ERROR_REGEX: (
__(
'Please check your query for syntax errors near "%(server_error)s". '
@@ -110,7 +111,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
database: "Database",
table_name: str,
schema_name: Optional[str],
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
with database.get_raw_connection(schema=schema_name) as conn:
cursor = conn.cursor()
cursor.execute(f'SELECT GET_METADATA("{table_name}")')
@@ -127,7 +128,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
cls,
_: GSheetsParametersType,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, Any]
+ dict[str, Any]
] = None,
) -> str:
return "gsheets://"
@@ -136,7 +137,7 @@ class GSheetsEngineSpec(SqliteEngineSpec):
def get_parameters_from_uri(
cls,
uri: str, # pylint: disable=unused-argument
- encrypted_extra: Optional[Dict[str, Any]] = None,
+ encrypted_extra: Optional[dict[str, Any]] = None,
) -> Any:
# Building parameters from encrypted_extra and uri
if encrypted_extra:
@@ -214,8 +215,8 @@ class GSheetsEngineSpec(SqliteEngineSpec):
def validate_parameters(
cls,
properties: GSheetsPropertiesType,
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
# backwards compatible just incase people are send data
# via parameters for validation
diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py
index e579550b2e..108838f9d2 100644
--- a/superset/db_engine_specs/hana.py
+++ b/superset/db_engine_specs/hana.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -45,7 +45,7 @@ class HanaEngineSpec(PostgresBaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index 6d8986c1c7..7601ebb2cd 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -22,7 +22,7 @@ import re
import tempfile
import time
from datetime import datetime
-from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from urllib import parse
import numpy as np
@@ -150,9 +150,7 @@ class HiveEngineSpec(PrestoEngineSpec):
hive.Cursor.fetch_logs = fetch_logs
@classmethod
- def fetch_data(
- cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]:
# pylint: disable=import-outside-toplevel
import pyhive
from TCLIService import ttypes
@@ -168,10 +166,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def df_to_sql(
cls,
- database: "Database",
+ database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
@@ -248,8 +246,8 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
@@ -263,10 +261,10 @@ class HiveEngineSpec(PrestoEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
if schema:
uri = uri.set(database=parse.quote(schema, safe=""))
@@ -276,8 +274,8 @@ class HiveEngineSpec(PrestoEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the configured schema.
"""
@@ -292,10 +290,10 @@ class HiveEngineSpec(PrestoEngineSpec):
return msg
@classmethod
- def progress(cls, log_lines: List[str]) -> int:
+ def progress(cls, log_lines: list[str]) -> int:
total_jobs = 1 # assuming there's at least 1 job
current_job = 1
- stages: Dict[int, float] = {}
+ stages: dict[int, float] = {}
for line in log_lines:
match = cls.jobs_stats_r.match(line)
if match:
@@ -323,7 +321,7 @@ class HiveEngineSpec(PrestoEngineSpec):
return int(progress)
@classmethod
- def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]:
+ def get_tracking_url_from_logs(cls, log_lines: list[str]) -> str | None:
lkp = "Tracking URL = "
for line in log_lines:
if lkp in line:
@@ -407,19 +405,19 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
return inspector.get_columns(table_name, schema)
@classmethod
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
- database: "Database",
+ schema: str | None,
+ database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
try:
col_names, values = cls.latest_partition(
table_name, schema, database, show_first=True
@@ -437,18 +435,18 @@ class HiveEngineSpec(PrestoEngineSpec):
return None
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access
@classmethod
def latest_sub_partition( # type: ignore
- cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any
+ cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
) -> str:
# TODO(bogdan): implement`
pass
@classmethod
- def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
"""Hive partitions look like ds={partition name}/ds={partition name}"""
if not df.empty:
return [
@@ -461,12 +459,12 @@ class HiveEngineSpec(PrestoEngineSpec):
def _partition_query( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
- indexes: List[Dict[str, Any]],
- database: "Database",
+ schema: str | None,
+ indexes: list[dict[str, Any]],
+ database: Database,
limit: int = 0,
- order_by: Optional[List[Tuple[str, bool]]] = None,
- filters: Optional[Dict[Any, Any]] = None,
+ order_by: list[tuple[str, bool]] | None = None,
+ filters: dict[Any, Any] | None = None,
) -> str:
full_table_name = f"{schema}.{table_name}" if schema else table_name
return f"SHOW PARTITIONS {full_table_name}"
@@ -474,15 +472,15 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def select_star( # pylint: disable=too-many-arguments
cls,
- database: "Database",
+ database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
return super( # pylint: disable=bad-super-call
PrestoEngineSpec, cls
@@ -500,7 +498,7 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: Optional[str]
+ cls, url: URL, impersonate_user: bool, username: str | None
) -> URL:
"""
Return a modified URL with the username set.
@@ -516,9 +514,9 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -549,7 +547,7 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
@cache_manager.cache.memoize()
- def get_function_names(cls, database: "Database") -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -600,10 +598,10 @@ class HiveEngineSpec(PrestoEngineSpec):
@classmethod
def get_view_names(
cls,
- database: "Database",
+ database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -635,9 +633,9 @@ class HiveEngineSpec(PrestoEngineSpec):
# TODO: contribute back to pyhive.
def fetch_logs( # pylint: disable=protected-access
- self: "Cursor",
+ self: Cursor,
_max_rows: int = 1024,
- orientation: Optional["TFetchOrientation"] = None,
+ orientation: TFetchOrientation | None = None,
) -> str:
"""Mocked. Retrieve the logs produced by the execution of the query.
Can be called multiple times to fetch the logs produced after
diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py
index e59c2b74fb..cd1c9e4732 100644
--- a/superset/db_engine_specs/impala.py
+++ b/superset/db_engine_specs/impala.py
@@ -18,7 +18,7 @@ import logging
import re
import time
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import current_app
from sqlalchemy import types
@@ -57,7 +57,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -68,7 +68,7 @@ class ImpalaEngineSpec(BaseEngineSpec):
return None
@classmethod
- def get_schema_names(cls, inspector: Inspector) -> List[str]:
+ def get_schema_names(cls, inspector: Inspector) -> list[str]:
schemas = [
row[0]
for row in inspector.engine.execute("SHOW SCHEMAS")
diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py
index 9fddb23d26..17147d5cc0 100644
--- a/superset/db_engine_specs/kusto.py
+++ b/superset/db_engine_specs/kusto.py
@@ -16,7 +16,7 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
from sqlalchemy import types
from sqlalchemy.dialects.mssql.base import SMALLDATETIME
@@ -61,7 +61,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
" DATEDIFF(week, 0, DATEADD(day, -1, {col})), 0)",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
column_type_mappings = (
(
@@ -72,7 +72,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
)
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
import sqlalchemy_kusto.errors as kusto_exceptions
@@ -84,7 +84,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -128,10 +128,10 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "datetime_diff('year',CreateDate, datetime(0001-01-01 00:00:00))+1",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel,import-error
import sqlalchemy_kusto.errors as kusto_exceptions
@@ -143,7 +143,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -168,7 +168,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
return not parsed_query.sql.startswith(".")
@classmethod
- def parse_sql(cls, sql: str) -> List[str]:
+ def parse_sql(cls, sql: str) -> list[str]:
"""
Kusto supports a single query statement, but it could include sub queries
and variables declared via let keyword.
diff --git a/superset/db_engine_specs/kylin.py b/superset/db_engine_specs/kylin.py
index e340daea51..f522602a48 100644
--- a/superset/db_engine_specs/kylin.py
+++ b/superset/db_engine_specs/kylin.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy import types
@@ -42,7 +42,7 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py
index 8b38ec7421..3e0879b904 100644
--- a/superset/db_engine_specs/mssql.py
+++ b/superset/db_engine_specs/mssql.py
@@ -17,7 +17,8 @@
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from flask_babel import gettext as __
from sqlalchemy import types
@@ -80,7 +81,7 @@ class MssqlEngineSpec(BaseEngineSpec):
),
)
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__(
'Either the username "%(username)s", password, '
@@ -115,7 +116,7 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -132,7 +133,7 @@ class MssqlEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
data = super().fetch_data(cursor, limit)
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py
index 6258f6b21a..9f853d577c 100644
--- a/superset/db_engine_specs/mysql.py
+++ b/superset/db_engine_specs/mysql.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -143,9 +144,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
"INTERVAL 1 DAY)) - 1 DAY))",
}
- type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed
+ type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -186,7 +187,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -201,10 +202,10 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
uri, new_connect_args = super().adjust_engine_params(
uri,
connect_args,
@@ -221,7 +222,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py
index 4b8a59117e..59fa52a656 100644
--- a/superset/db_engine_specs/ocient.py
+++ b/superset/db_engine_specs/ocient.py
@@ -17,7 +17,8 @@
import re
import threading
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Pattern, Set, Tuple
+from re import Pattern
+from typing import Any, Callable, List, NamedTuple, Optional
from flask_babel import gettext as __
from sqlalchemy.engine.reflection import Inspector
@@ -178,15 +179,13 @@ def _polygon_to_geo_json(
# Sanitization function for column values
SanitizeFunc = Callable[[Any], Any]
+
# Represents a pair of a column index and the sanitization function
# to apply to its values.
-PlacedSanitizeFunc = NamedTuple(
- "PlacedSanitizeFunc",
- [
- ("column_index", int),
- ("sanitize_func", SanitizeFunc),
- ],
-)
+class PlacedSanitizeFunc(NamedTuple):
+ column_index: int
+ sanitize_func: SanitizeFunc
+
# This map contains functions used to sanitize values for column types
# that cannot be processed natively by Superset.
@@ -199,7 +198,7 @@ PlacedSanitizeFunc = NamedTuple(
try:
from pyocient import TypeCodes
- _sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = {
+ _sanitized_ocient_type_codes: dict[int, SanitizeFunc] = {
TypeCodes.BINARY: _to_hex,
TypeCodes.ST_POINT: _point_to_geo_json,
TypeCodes.IP: str,
@@ -211,7 +210,7 @@ except ImportError as e:
_sanitized_ocient_type_codes = {}
-def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]:
+def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]:
"""
Cleans the column value for consumption by Superset.
@@ -238,10 +237,10 @@ class OcientEngineSpec(BaseEngineSpec):
# Store mapping of superset Query id -> Ocient ID
# These are inserted into the cache when executing the query
# They are then removed, either upon cancellation or query completion
- query_id_mapping: Dict[str, str] = dict()
+ query_id_mapping: dict[str, str] = dict()
query_id_mapping_lock = threading.Lock()
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_INVALID_USERNAME_REGEX: (
__('The username "%(username)s" does not exist.'),
SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
@@ -309,15 +308,15 @@ class OcientEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: Database, inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
return inspector.get_table_names(schema)
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
try:
- rows: List[Tuple[Any, ...]] = super().fetch_data(cursor, limit)
+ rows: list[tuple[Any, ...]] = super().fetch_data(cursor, limit)
except Exception as exception:
with OcientEngineSpec.query_id_mapping_lock:
del OcientEngineSpec.query_id_mapping[
@@ -329,7 +328,7 @@ class OcientEngineSpec(BaseEngineSpec):
if len(rows) > 0 and type(rows[0]).__name__ == "Row":
# Peek at the schema to determine which column values, if any,
# require sanitization.
- columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize(
+ columns_to_sanitize: list[PlacedSanitizeFunc] = _find_columns_to_sanitize(
cursor
)
@@ -341,7 +340,7 @@ class OcientEngineSpec(BaseEngineSpec):
# Use the identity function if the column type doesn't need to be
# sanitized.
- sanitization_functions: List[SanitizeFunc] = [
+ sanitization_functions: list[SanitizeFunc] = [
identity for _ in range(len(cursor.description))
]
for info in columns_to_sanitize:
diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py
index 4a219919bb..1199b74406 100644
--- a/superset/db_engine_specs/oracle.py
+++ b/superset/db_engine_specs/oracle.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
from sqlalchemy import types
@@ -43,7 +43,7 @@ class OracleEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -68,7 +68,7 @@ class OracleEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
"""
:param cursor: Cursor instance
:param limit: Maximum number of rows to be returned by the cursor
diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py
index cebdd693a4..bfec8b2947 100644
--- a/superset/db_engine_specs/pinot.py
+++ b/superset/db_engine_specs/pinot.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Optional
+from typing import Optional
from sqlalchemy.sql.expression import ColumnClause
@@ -30,7 +30,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_alias_in_orderby = False
# Pinot does its own conversion below
- _time_grain_expressions: Dict[Optional[str], str] = {
+ _time_grain_expressions: dict[Optional[str], str] = {
"PT1S": "1:SECONDS",
"PT1M": "1:MINUTES",
"PT5M": "5:MINUTES",
@@ -45,7 +45,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"P1Y": "year",
}
- _python_to_java_time_patterns: Dict[str, str] = {
+ _python_to_java_time_patterns: dict[str, str] = {
"%Y": "yyyy",
"%m": "MM",
"%d": "dd",
@@ -54,7 +54,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
"%S": "ss",
}
- _use_date_trunc_function: Dict[str, bool] = {
+ _use_date_trunc_function: dict[str, bool] = {
"PT1S": False,
"PT1M": False,
"PT5M": False,
diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index e809187af6..2088782f83 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -18,7 +18,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
@@ -73,7 +74,7 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P.*?)"')
-def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]:
+def parse_options(connect_args: dict[str, Any]) -> dict[str, str]:
"""
Parse ``options`` from ``connect_args`` into a dictionary.
"""
@@ -109,7 +110,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
"P1Y": "DATE_TRUNC('year', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_INVALID_USERNAME_REGEX: (
__('The username "%(username)s" does not exist.'),
SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR,
@@ -169,7 +170,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec):
@classmethod
def fetch_data(
cls, cursor: Any, limit: Optional[int] = None
- ) -> List[Tuple[Any, ...]]:
+ ) -> list[tuple[Any, ...]]:
if not cursor.description:
return []
return super().fetch_data(cursor, limit)
@@ -221,7 +222,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
@@ -253,10 +254,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
if not schema:
return uri, connect_args
@@ -269,11 +270,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
return uri, connect_args
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
sql = f"EXPLAIN {statement}"
cursor.execute(sql)
@@ -289,8 +290,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
return [{k: str(v) for k, v in row.items()} for row in raw_cost]
@classmethod
@@ -298,7 +299,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Return all catalogs.
@@ -317,7 +318,7 @@ WHERE datistemplate = false;
@classmethod
def get_table_names(
cls, database: "Database", inspector: PGInspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
"""Need to consider foreign tables for PostgreSQL"""
return set(inspector.get_table_names(schema)) | set(
inspector.get_foreign_table_names(schema)
@@ -325,7 +326,7 @@ WHERE datistemplate = false;
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -337,7 +338,7 @@ WHERE datistemplate = false;
return None
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
For Postgres, the path to a SSL certificate is placed in `connect_args`.
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index 82b05e53e3..d5a2ab7605 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -23,19 +23,9 @@ import time
from abc import ABCMeta
from collections import defaultdict, deque
from datetime import datetime
+from re import Pattern
from textwrap import dedent
-from typing import (
- Any,
- cast,
- Dict,
- List,
- Optional,
- Pattern,
- Set,
- Tuple,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, cast, Optional, TYPE_CHECKING
from urllib import parse
import pandas as pd
@@ -78,7 +68,7 @@ if TYPE_CHECKING:
# need try/catch because pyhive may not be installed
try:
- from pyhive.presto import Cursor # pylint: disable=unused-import
+ from pyhive.presto import Cursor
except ImportError:
pass
@@ -107,7 +97,7 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile(
logger = logging.getLogger(__name__)
-def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]:
+def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]:
"""
Get the children of a complex Presto type (row or array).
@@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
- ) -> Optional[str]:
+ cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None
+ ) -> str | None:
"""
Convert a Python `datetime` object to a SQL expression.
:param target_type: The target type of expression
@@ -304,10 +294,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
- catalog: Optional[str] = None,
- schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ connect_args: dict[str, Any],
+ catalog: str | None = None,
+ schema: str | None = None,
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
schema = parse.quote(schema, safe="")
@@ -323,8 +313,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
- ) -> Optional[str]:
+ connect_args: dict[str, Any],
+ ) -> str | None:
"""
Return the configured schema.
@@ -341,7 +331,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return parse.unquote(database.split("/")[1])
@classmethod
- def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]:
+ def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]:
"""
Run a SQL query that estimates the cost of a given statement.
:param statement: A single SQL statement
@@ -369,8 +359,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def query_cost_formatter(
- cls, raw_cost: List[Dict[str, Any]]
- ) -> List[Dict[str, str]]:
+ cls, raw_cost: list[dict[str, Any]]
+ ) -> list[dict[str, str]]:
"""
Format cost estimate.
:param raw_cost: JSON estimate from Trino
@@ -401,7 +391,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
("networkCost", "Network cost", ""),
]
for row in raw_cost:
- estimate: Dict[str, float] = row.get("estimate", {})
+ estimate: dict[str, float] = row.get("estimate", {})
statement_cost = {}
for key, label, suffix in columns:
if key in estimate:
@@ -412,7 +402,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
@cache_manager.data_cache.memoize()
- def get_function_names(cls, database: Database) -> List[str]:
+ def get_function_names(cls, database: Database) -> list[str]:
"""
Get a list of function names that are able to be called on the database.
Used for SQL Lab autocomplete.
@@ -426,12 +416,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument
cls,
table_name: str,
- schema: Optional[str],
- indexes: List[Dict[str, Any]],
+ schema: str | None,
+ indexes: list[dict[str, Any]],
database: Database,
limit: int = 0,
- order_by: Optional[List[Tuple[str, bool]]] = None,
- filters: Optional[Dict[Any, Any]] = None,
+ order_by: list[tuple[str, bool]] | None = None,
+ filters: dict[Any, Any] | None = None,
) -> str:
"""
Return a partition query.
@@ -449,7 +439,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
order
:param filters: dict of field name and filter value combinations
"""
- limit_clause = "LIMIT {}".format(limit) if limit else ""
+ limit_clause = f"LIMIT {limit}" if limit else ""
order_by_clause = ""
if order_by:
l = []
@@ -492,11 +482,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def where_latest_partition( # pylint: disable=too-many-arguments
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
query: Select,
- columns: Optional[List[Dict[str, Any]]] = None,
- ) -> Optional[Select]:
+ columns: list[dict[str, Any]] | None = None,
+ ) -> Select | None:
try:
col_names, values = cls.latest_partition(
table_name, schema, database, show_first=True
@@ -525,7 +515,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
return query
@classmethod
- def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]:
+ def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None:
if not df.empty:
return df.to_records(index=False)[0].item()
return None
@@ -535,10 +525,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
def latest_partition(
cls,
table_name: str,
- schema: Optional[str],
+ schema: str | None,
database: Database,
show_first: bool = False,
- ) -> Tuple[List[str], Optional[List[str]]]:
+ ) -> tuple[list[str], list[str] | None]:
"""Returns col name and the latest (max) partition value for a table
:param table_name: the name of the table
@@ -589,7 +579,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta):
@classmethod
def latest_sub_partition(
- cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any
+ cls, table_name: str, schema: str | None, database: Database, **kwargs: Any
) -> Any:
"""Returns the latest (max) partition value for a table
@@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
engine_name = "Presto"
allows_alias_to_source_column = False
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__(
'We can\'t seem to resolve the column "%(column_name)s" at '
@@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
}
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
version = extra.get("version")
return version is not None and Version(version) >= Version("0.319")
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -741,8 +731,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the real table names within the specified schema.
@@ -769,8 +759,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- schema: Optional[str],
- ) -> Set[str]:
+ schema: str | None,
+ ) -> set[str]:
"""
Get all the view names within the specified schema.
@@ -817,7 +807,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Get all catalogs.
"""
@@ -826,7 +816,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -836,7 +826,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return {"name": name, "type": f"{data_type}"}
@classmethod
- def _get_full_name(cls, names: List[Tuple[str, str]]) -> str:
+ def _get_full_name(cls, names: list[tuple[str, str]]) -> str:
"""
Get the full column name
:param names: list of all individual column names
@@ -860,7 +850,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
)
@classmethod
- def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]:
+ def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]:
"""
Split data type based on given delimiter. Do not split the string if the
delimiter is enclosed in quotes
@@ -869,16 +859,14 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
comma, whitespace)
:return: list of strings after breaking it by the delimiter
"""
- return re.split(
- r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type
- )
+ return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type)
@classmethod
def _parse_structural_column( # pylint: disable=too-many-locals
cls,
parent_column_name: str,
parent_data_type: str,
- result: List[Dict[str, Any]],
+ result: list[dict[str, Any]],
) -> None:
"""
Parse a row or array column
@@ -893,7 +881,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# split on open parenthesis ( to get the structural
# data type and its component types
data_types = cls._split_data_type(full_data_type, r"\(")
- stack: List[Tuple[str, str]] = []
+ stack: list[tuple[str, str]] = []
for data_type in data_types:
# split on closed parenthesis ) to track which component
# types belong to what structural data type
@@ -962,8 +950,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def _show_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[ResultRow]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[ResultRow]:
"""
Show presto column names
:param inspector: object that performs database schema inspection
@@ -974,13 +962,13 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
- full_table = "{}.{}".format(quote(schema), full_table)
+ full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def get_columns(
- cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ cls, inspector: Inspector, table_name: str, schema: str | None
+ ) -> list[dict[str, Any]]:
"""
Get columns from a Presto data source. This includes handling row and
array data types
@@ -991,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
(i.e. column name and data type)
"""
columns = cls._show_columns(inspector, table_name, schema)
- result: List[Dict[str, Any]] = []
+ result: list[dict[str, Any]] = []
for column in columns:
# parse column if it is a row or array
if is_feature_enabled("PRESTO_EXPAND_DATA") and (
@@ -1031,7 +1019,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return column_name.startswith('"') and column_name.endswith('"')
@classmethod
- def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]:
+ def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]:
"""
Format column clauses where names are in quotes and labels are specified
:param cols: columns
@@ -1053,7 +1041,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# quote each column name if it is not already quoted
for index, col_name in enumerate(col_names):
if not cls._is_column_name_quoted(col_name):
- col_names[index] = '"{}"'.format(col_name)
+ col_names[index] = f'"{col_name}"'
quoted_col_name = ".".join(
col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"'
for col_name in col_names
@@ -1069,12 +1057,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
database: Database,
table_name: str,
engine: Engine,
- schema: Optional[str] = None,
+ schema: str | None = None,
limit: int = 100,
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = True,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: list[dict[str, Any]] | None = None,
) -> str:
"""
Include selecting properties of row objects. We cannot easily break arrays into
@@ -1102,9 +1090,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def expand_data( # pylint: disable=too-many-locals
- cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]]
- ) -> Tuple[
- List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType]
+ cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]]
+ ) -> tuple[
+ list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType]
]:
"""
We do not immediately display rows and arrays clearly in the data grid. This
@@ -1133,7 +1121,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# process each column, unnesting ARRAY types and
# expanding ROW types into new columns
to_process = deque((column, 0) for column in columns)
- all_columns: List[ResultSetColumnType] = []
+ all_columns: list[ResultSetColumnType] = []
expanded_columns = []
current_array_level = None
while to_process:
@@ -1147,11 +1135,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
# added by the first. every time we change a level in the nested arrays
# we reinitialize this.
if level != current_array_level:
- unnested_rows: Dict[int, int] = defaultdict(int)
+ unnested_rows: dict[int, int] = defaultdict(int)
current_array_level = level
name = column["name"]
- values: Optional[Union[str, List[Any]]]
+ values: str | list[Any] | None
if column["type"] and column["type"].startswith("ARRAY("):
# keep processing array children; we append to the right so that
@@ -1198,7 +1186,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
for row in data:
values = row.get(name) or []
if isinstance(values, str):
- values = cast(Optional[List[Any]], destringify(values))
+ values = cast(Optional[list[Any]], destringify(values))
row[name] = values
for value, col in zip(values or [], expanded):
row[col["name"]] = value
@@ -1211,8 +1199,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def extra_table_metadata(
- cls, database: Database, table_name: str, schema_name: Optional[str]
- ) -> Dict[str, Any]:
+ cls, database: Database, table_name: str, schema_name: str | None
+ ) -> dict[str, Any]:
metadata = {}
if indexes := database.get_indexes(table_name, schema_name):
@@ -1243,8 +1231,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_create_view(
- cls, database: Database, schema: Optional[str], table: str
- ) -> Optional[str]:
+ cls, database: Database, schema: str | None, table: str
+ ) -> str | None:
"""
Return a CREATE VIEW statement, or `None` if not a view.
@@ -1267,7 +1255,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return rows[0][0]
@classmethod
- def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
+ def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
if cursor.last_query_id:
# pylint: disable=protected-access, line-too-long
@@ -1277,7 +1265,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec):
return None
@classmethod
- def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None:
+ def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None:
"""Updates progress information"""
if tracking_url := cls.get_tracking_url(cursor):
query.tracking_url = tracking_url
diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py
index 27b749e418..2e746a6349 100644
--- a/superset/db_engine_specs/redshift.py
+++ b/superset/db_engine_specs/redshift.py
@@ -16,7 +16,8 @@
# under the License.
import logging
import re
-from typing import Any, Dict, Optional, Pattern, Tuple
+from re import Pattern
+from typing import Any, Optional
import pandas as pd
from flask_babel import gettext as __
@@ -66,7 +67,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
encryption_parameters = {"sslmode": "verify-ca"}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -106,7 +107,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin):
database: Database,
table: Table,
df: pd.DataFrame,
- to_sql_kwargs: Dict[str, Any],
+ to_sql_kwargs: dict[str, Any],
) -> None:
"""
Upload data from a Pandas DataFrame to a database.
diff --git a/superset/db_engine_specs/rockset.py b/superset/db_engine_specs/rockset.py
index cc215054be..71adca0b10 100644
--- a/superset/db_engine_specs/rockset.py
+++ b/superset/db_engine_specs/rockset.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
from sqlalchemy import types
@@ -51,7 +51,7 @@ class RocksetEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py
index 69ccf55931..32ade649b0 100644
--- a/superset/db_engine_specs/snowflake.py
+++ b/superset/db_engine_specs/snowflake.py
@@ -18,7 +18,8 @@ import json
import logging
import re
from datetime import datetime
-from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from urllib import parse
from apispec import APISpec
@@ -107,7 +108,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
"P1Y": "DATE_TRUNC('YEAR', {col})",
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
OBJECT_DOES_NOT_EXIST_REGEX: (
__("%(object)s does not exist in this database."),
SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR,
@@ -124,13 +125,13 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
}
@staticmethod
- def get_extra_params(database: "Database") -> Dict[str, Any]:
+ def get_extra_params(database: "Database") -> dict[str, Any]:
"""
Add a user agent to be used in the requests.
"""
- extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
- engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
- connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
+ connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("application", USER_AGENT)
@@ -140,10 +141,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if "/" in database:
database = database.split("/")[0]
@@ -157,7 +158,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
@@ -174,7 +175,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
database: "Database",
inspector: Inspector,
- ) -> List[str]:
+ ) -> list[str]:
"""
Return all catalogs.
@@ -197,7 +198,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
@@ -261,7 +262,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
parameters: SnowflakeParametersType,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, Any]
+ dict[str, Any]
] = None,
) -> str:
return str(
@@ -283,7 +284,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
cls,
uri: str,
encrypted_extra: Optional[ # pylint: disable=unused-argument
- Dict[str, str]
+ dict[str, str]
] = None,
) -> Any:
url = make_url_safe(uri)
@@ -300,8 +301,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@classmethod
def validate_parameters(
cls, properties: BasicPropertiesType
- ) -> List[SupersetError]:
- errors: List[SupersetError] = []
+ ) -> list[SupersetError]:
+ errors: list[SupersetError] = []
required = {
"warehouse",
"username",
@@ -346,7 +347,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
@staticmethod
def update_params_from_encrypted_extra(
database: "Database",
- params: Dict[str, Any],
+ params: dict[str, Any],
) -> None:
if not database.encrypted_extra:
return
diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py
index a414143296..767d0a20ad 100644
--- a/superset/db_engine_specs/sqlite.py
+++ b/superset/db_engine_specs/sqlite.py
@@ -16,7 +16,8 @@
# under the License.
import re
from datetime import datetime
-from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING
+from re import Pattern
+from typing import Any, Optional, TYPE_CHECKING
from flask_babel import gettext as __
from sqlalchemy import types
@@ -60,7 +61,7 @@ class SqliteEngineSpec(BaseEngineSpec):
),
}
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
COLUMN_DOES_NOT_EXIST_REGEX: (
__('We can\'t seem to resolve the column "%(column_name)s"'),
SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR,
@@ -74,7 +75,7 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def convert_dttm(
- cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, (types.String, types.DateTime)):
@@ -84,6 +85,6 @@ class SqliteEngineSpec(BaseEngineSpec):
@classmethod
def get_table_names(
cls, database: "Database", inspector: Inspector, schema: Optional[str]
- ) -> Set[str]:
+ ) -> set[str]:
"""Need to disregard the schema for Sqlite"""
return set(inspector.get_table_names())
diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py
index f687fdbdb3..63269439af 100644
--- a/superset/db_engine_specs/starrocks.py
+++ b/superset/db_engine_specs/starrocks.py
@@ -17,7 +17,8 @@
import logging
import re
-from typing import Any, Dict, List, Optional, Pattern, Tuple, Type
+from re import Pattern
+from typing import Any, Optional
from urllib import parse
from flask_babel import gettext as __
@@ -40,11 +41,11 @@ CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile("Unknown database '(?P.
logger = logging.getLogger(__name__)
-class TINYINT(Integer): # pylint: disable=no-init
+class TINYINT(Integer):
__visit_name__ = "TINYINT"
-class DOUBLE(Numeric): # pylint: disable=no-init
+class DOUBLE(Numeric):
__visit_name__ = "DOUBLE"
@@ -52,7 +53,7 @@ class ARRAY(TypeEngine): # pylint: disable=no-init
__visit_name__ = "ARRAY"
@property
- def python_type(self) -> Optional[Type[List[Any]]]:
+ def python_type(self) -> Optional[type[list[Any]]]:
return list
@@ -60,7 +61,7 @@ class MAP(TypeEngine): # pylint: disable=no-init
__visit_name__ = "MAP"
@property
- def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
+ def python_type(self) -> Optional[type[dict[Any, Any]]]:
return dict
@@ -68,7 +69,7 @@ class STRUCT(TypeEngine): # pylint: disable=no-init
__visit_name__ = "STRUCT"
@property
- def python_type(self) -> Optional[Type[Any]]:
+ def python_type(self) -> Optional[type[Any]]:
return None
@@ -117,7 +118,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
(re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING),
)
- custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = {
+ custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = {
CONNECTION_ACCESS_DENIED_REGEX: (
__('Either the username "%(username)s" or the password is incorrect.'),
SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR,
@@ -134,10 +135,10 @@ class StarRocksEngineSpec(MySQLEngineSpec):
def adjust_engine_params(
cls,
uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
catalog: Optional[str] = None,
schema: Optional[str] = None,
- ) -> Tuple[URL, Dict[str, Any]]:
+ ) -> tuple[URL, dict[str, Any]]:
database = uri.database
if schema and database:
schema = parse.quote(schema, safe="")
@@ -152,9 +153,9 @@ class StarRocksEngineSpec(MySQLEngineSpec):
@classmethod
def get_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
columns = cls._show_columns(inspector, table_name, schema)
- result: List[Dict[str, Any]] = []
+ result: list[dict[str, Any]] = []
for column in columns:
column_spec = cls.get_column_spec(column.Type)
column_type = column_spec.sqla_type if column_spec else None
@@ -174,7 +175,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
@classmethod
def _show_columns(
cls, inspector: Inspector, table_name: str, schema: Optional[str]
- ) -> List[ResultRow]:
+ ) -> list[ResultRow]:
"""
Show starrocks column names
:param inspector: object that performs database schema inspection
@@ -185,13 +186,13 @@ class StarRocksEngineSpec(MySQLEngineSpec):
quote = inspector.engine.dialect.identifier_preparer.quote_identifier
full_table = quote(table_name)
if schema:
- full_table = "{}.{}".format(quote(schema), full_table)
+ full_table = f"{quote(schema)}.{full_table}"
return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall()
@classmethod
def _create_column_info(
cls, name: str, data_type: types.TypeEngine
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""
Create column info object
:param name: column name
@@ -204,7 +205,7 @@ class StarRocksEngineSpec(MySQLEngineSpec):
def get_schema_from_engine_params(
cls,
sqlalchemy_uri: URL,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
) -> Optional[str]:
"""
Return the configured schema.
diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py
index 0fa4d05cbc..f05bd67ec3 100644
--- a/superset/db_engine_specs/trino.py
+++ b/superset/db_engine_specs/trino.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import simplejson as json
from flask import current_app
@@ -53,8 +53,8 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
cls,
database: Database,
table_name: str,
- schema_name: Optional[str],
- ) -> Dict[str, Any]:
+ schema_name: str | None,
+ ) -> dict[str, Any]:
metadata = {}
if indexes := database.get_indexes(table_name, schema_name):
@@ -68,12 +68,12 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
metadata["partitions"] = {
"cols": sorted(
list(
- set(
+ {
column_name
for index in indexes
if index.get("name") == "partition"
for column_name in index.get("column_names", [])
- )
+ }
)
),
"latest": dict(zip(col_names, latest_parts)),
@@ -95,9 +95,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def update_impersonation_config(
cls,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
uri: str,
- username: Optional[str],
+ username: str | None,
) -> None:
"""
Update a configuration dictionary
@@ -118,7 +118,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@classmethod
def get_url_for_impersonation(
- cls, url: URL, impersonate_user: bool, username: Optional[str]
+ cls, url: URL, impersonate_user: bool, username: str | None
) -> URL:
"""
Return a modified URL with the username set.
@@ -131,11 +131,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return url
@classmethod
- def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
+ def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool:
return True
@classmethod
- def get_tracking_url(cls, cursor: Cursor) -> Optional[str]:
+ def get_tracking_url(cls, cursor: Cursor) -> str | None:
try:
return cursor.info_uri
except AttributeError:
@@ -199,7 +199,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
return True
@staticmethod
- def get_extra_params(database: Database) -> Dict[str, Any]:
+ def get_extra_params(database: Database) -> dict[str, Any]:
"""
Some databases require adding elements to connection parameters,
like passing certificates to `extra`. This can be done here.
@@ -207,9 +207,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
:param database: database instance from which to extract extras
:raises CertificateException: If certificate is not valid/unparseable
"""
- extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database)
- engine_params: Dict[str, Any] = extra.setdefault("engine_params", {})
- connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {})
+ extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database)
+ engine_params: dict[str, Any] = extra.setdefault("engine_params", {})
+ connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {})
connect_args.setdefault("source", USER_AGENT)
@@ -222,7 +222,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
@staticmethod
def update_params_from_encrypted_extra(
database: Database,
- params: Dict[str, Any],
+ params: dict[str, Any],
) -> None:
if not database.encrypted_extra:
return
@@ -262,7 +262,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec):
raise ex
@classmethod
- def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
+ def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]:
# pylint: disable=import-outside-toplevel
from requests import exceptions as requests_exceptions
diff --git a/superset/embedded/dao.py b/superset/embedded/dao.py
index 957a7242a7..27ca338502 100644
--- a/superset/embedded/dao.py
+++ b/superset/embedded/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from superset.dao.base import BaseDAO
from superset.extensions import db
@@ -31,7 +31,7 @@ class EmbeddedDAO(BaseDAO):
id_column_name = "uuid"
@staticmethod
- def upsert(dashboard: Dashboard, allowed_domains: List[str]) -> EmbeddedDashboard:
+ def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboard:
"""
Sets up a dashboard to be embeddable.
Upsert is used to preserve the embedded_dashboard uuid across updates.
@@ -45,7 +45,7 @@ class EmbeddedDAO(BaseDAO):
return embedded
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> Any:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
"""
Use EmbeddedDAO.upsert() instead.
At least, until we are ok with more than one embedded instance per dashboard.
diff --git a/superset/errors.py b/superset/errors.py
index 5261848687..6f68f2466c 100644
--- a/superset/errors.py
+++ b/superset/errors.py
@@ -16,7 +16,7 @@
# under the License.
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask_babel import lazy_gettext as _
@@ -204,7 +204,7 @@ class SupersetError:
message: str
error_type: SupersetErrorType
level: ErrorLevel
- extra: Optional[Dict[str, Any]] = None
+ extra: Optional[dict[str, Any]] = None
def __post_init__(self) -> None:
"""
@@ -227,7 +227,7 @@ class SupersetError:
}
)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
rv = {"message": self.message, "error_type": self.error_type}
if self.extra:
rv["extra"] = self.extra # type: ignore
diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py
index 5d167b02d0..e18f6e4632 100644
--- a/superset/examples/bart_lines.py
+++ b/superset/examples/bart_lines.py
@@ -55,7 +55,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
index=False,
)
- print("Creating table {} reference".format(tbl_name))
+ print(f"Creating table {tbl_name} reference")
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
diff --git a/superset/examples/big_data.py b/superset/examples/big_data.py
index 8c0f2e267c..ed738d2c96 100644
--- a/superset/examples/big_data.py
+++ b/superset/examples/big_data.py
@@ -16,7 +16,6 @@
# under the License.
import random
import string
-from typing import List
import sqlalchemy.sql.sqltypes
@@ -36,7 +35,7 @@ COLUMN_TYPES = [
def load_big_data() -> None:
print("Creating table `wide_table` with 100 columns")
- columns: List[ColumnInfo] = []
+ columns: list[ColumnInfo] = []
for i in range(100):
column: ColumnInfo = {
"name": f"col{i}",
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index 8da041550e..45a3b39eb3 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -16,7 +16,7 @@
# under the License.
import json
import textwrap
-from typing import Dict, List, Tuple, Union
+from typing import Union
import pandas as pd
from sqlalchemy import DateTime, inspect, String
@@ -42,7 +42,7 @@ from .helpers import (
def gen_filter(
subject: str, comparator: str, operator: str = "=="
-) -> Dict[str, Union[bool, str]]:
+) -> dict[str, Union[bool, str]]:
return {
"clause": "WHERE",
"comparator": comparator,
@@ -152,7 +152,7 @@ def _add_table_metrics(datasource: SqlaTable) -> None:
datasource.metrics = metrics
-def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]:
+def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]:
metrics = [
{
"expressionType": "SIMPLE",
@@ -529,7 +529,7 @@ def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]:
return slices, misc_slices
-def create_dashboard(slices: List[Slice]) -> Dashboard:
+def create_dashboard(slices: list[Slice]) -> Dashboard:
print("Creating a dashboard")
dash = db.session.query(Dashboard).filter_by(slug="births").first()
if not dash:
diff --git a/superset/examples/countries.py b/superset/examples/countries.py
index 8f1d5466ae..2ea12baae7 100644
--- a/superset/examples/countries.py
+++ b/superset/examples/countries.py
@@ -16,9 +16,9 @@
# under the License.
"""This module contains data related to countries and is used for geo mapping"""
# pylint: disable=too-many-lines
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
-countries: List[Dict[str, Any]] = [
+countries: list[dict[str, Any]] = [
{
"name": "Angola",
"area": 1246700,
@@ -2491,7 +2491,7 @@ countries: List[Dict[str, Any]] = [
},
]
-all_lookups: Dict[str, Dict[str, Dict[str, Any]]] = {}
+all_lookups: dict[str, dict[str, dict[str, Any]]] = {}
lookups = ["cioc", "cca2", "cca3", "name"]
for lookup in lookups:
all_lookups[lookup] = {}
@@ -2499,7 +2499,7 @@ for lookup in lookups:
all_lookups[lookup][country[lookup].lower()] = country
-def get(field: str, symbol: str) -> Optional[Dict[str, Any]]:
+def get(field: str, symbol: str) -> Optional[dict[str, Any]]:
"""
Get country data based on a standard code and a symbol
"""
diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py
index e26e05e497..9f893f1ccc 100644
--- a/superset/examples/helpers.py
+++ b/superset/examples/helpers.py
@@ -17,7 +17,7 @@
"""Loads datasets, dashboards and slices in a new superset instance"""
import json
import os
-from typing import Any, Dict, List, Set
+from typing import Any
from superset import app, db
from superset.connectors.sqla.models import SqlaTable
@@ -25,7 +25,7 @@ from superset.models.slice import Slice
BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/"
-misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard
+misc_dash_slices: set[str] = set() # slices assembled in a 'Misc Chart' dashboard
def get_table_connector_registry() -> Any:
@@ -36,7 +36,7 @@ def get_examples_folder() -> str:
return os.path.join(app.config["BASE_DIR"], "examples")
-def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]:
+def update_slice_ids(pos: dict[Any, Any]) -> list[Slice]:
"""Update slice ids in position_json and return the slices found."""
slice_components = [
component
@@ -44,7 +44,7 @@ def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]:
if isinstance(component, dict) and component.get("type") == "CHART"
]
slices = {}
- for name in set(component["meta"]["sliceName"] for component in slice_components):
+ for name in {component["meta"]["sliceName"] for component in slice_components}:
slc = db.session.query(Slice).filter_by(slice_name=name).first()
if slc:
slices[name] = slc
@@ -64,7 +64,7 @@ def merge_slice(slc: Slice) -> None:
db.session.commit()
-def get_slice_json(defaults: Dict[Any, Any], **kwargs: Any) -> str:
+def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str:
defaults_copy = defaults.copy()
defaults_copy.update(kwargs)
return json.dumps(defaults_copy, indent=4, sort_keys=True)
diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py
index de9630ef58..6bad2a7ac2 100644
--- a/superset/examples/multiformat_time_series.py
+++ b/superset/examples/multiformat_time_series.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Optional, Tuple
+from typing import Optional
import pandas as pd
from sqlalchemy import BigInteger, Date, DateTime, inspect, String
@@ -85,7 +85,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
- dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = {
+ dttm_and_expr_dict: dict[str, tuple[Optional[str], None]] = {
"ds": (None, None),
"ds2": (None, None),
"epoch_s": ("epoch_s", None),
diff --git a/superset/examples/paris.py b/superset/examples/paris.py
index a54a3706b1..1180c428fe 100644
--- a/superset/examples/paris.py
+++ b/superset/examples/paris.py
@@ -52,7 +52,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
index=False,
)
- print("Creating table {} reference".format(tbl_name))
+ print(f"Creating table {tbl_name} reference")
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py
index 6011b82b09..76c039afb8 100644
--- a/superset/examples/sf_population_polygons.py
+++ b/superset/examples/sf_population_polygons.py
@@ -54,7 +54,7 @@ def load_sf_population_polygons(
index=False,
)
- print("Creating table {} reference".format(tbl_name))
+ print(f"Creating table {tbl_name} reference")
table = get_table_connector_registry()
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py
index e6d7557deb..bce50c854e 100644
--- a/superset/examples/supported_charts_dashboard.py
+++ b/superset/examples/supported_charts_dashboard.py
@@ -19,7 +19,6 @@
import json
import textwrap
-from typing import List
from sqlalchemy import inspect
@@ -40,7 +39,7 @@ from .helpers import (
DASH_SLUG = "supported_charts_dash"
-def create_slices(tbl: SqlaTable) -> List[Slice]:
+def create_slices(tbl: SqlaTable) -> list[Slice]:
slice_kwargs = {
"datasource_id": tbl.id,
"datasource_type": DatasourceType.TABLE,
diff --git a/superset/examples/utils.py b/superset/examples/utils.py
index 8c2cfea23c..52d58e3e4a 100644
--- a/superset/examples/utils.py
+++ b/superset/examples/utils.py
@@ -17,7 +17,7 @@
import logging
import re
from pathlib import Path
-from typing import Any, Dict
+from typing import Any
import yaml
from pkg_resources import resource_isdir, resource_listdir, resource_stream
@@ -42,13 +42,13 @@ def load_examples_from_configs(
command.run()
-def load_contents(load_test_data: bool = False) -> Dict[str, Any]:
+def load_contents(load_test_data: bool = False) -> dict[str, Any]:
"""Traverse configs directory and load contents"""
root = Path("examples/configs")
resource_names = resource_listdir("superset", str(root))
queue = [root / resource_name for resource_name in resource_names]
- contents: Dict[Path, str] = {}
+ contents: dict[Path, str] = {}
while queue:
path_name = queue.pop()
test_re = re.compile(r"\.test\.|metadata\.yaml$")
@@ -74,7 +74,7 @@ def load_configs_from_directory(
"""
Load all the examples from a given directory.
"""
- contents: Dict[str, str] = {}
+ contents: dict[str, str] = {}
queue = [root]
while queue:
path_name = queue.pop()
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 2972188e02..9f9f6bb700 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -17,7 +17,6 @@
"""Loads datasets, dashboards and slices in a new superset instance"""
import json
import os
-from typing import List
import pandas as pd
from sqlalchemy import DateTime, inspect, String
@@ -139,7 +138,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
db.session.commit()
-def create_slices(tbl: BaseDatasource) -> List[Slice]:
+def create_slices(tbl: BaseDatasource) -> list[Slice]:
metric = "sum__SP_POP_TOTL"
metrics = ["sum__SP_POP_TOTL"]
secondary_metric = {
diff --git a/superset/exceptions.py b/superset/exceptions.py
index 32b06203cd..018d1b6dfb 100644
--- a/superset/exceptions.py
+++ b/superset/exceptions.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_babel import gettext as _
from marshmallow import ValidationError
@@ -47,7 +47,7 @@ class SupersetException(Exception):
def error_type(self) -> Optional[SupersetErrorType]:
return self._error_type
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
rv = {}
if hasattr(self, "message"):
rv["message"] = self.message
@@ -67,7 +67,7 @@ class SupersetErrorException(SupersetException):
if status is not None:
self.status = status
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return self.error.to_dict()
@@ -94,7 +94,7 @@ class SupersetErrorFromParamsException(SupersetErrorException):
error_type: SupersetErrorType,
message: str,
level: ErrorLevel,
- extra: Optional[Dict[str, Any]] = None,
+ extra: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
SupersetError(
@@ -107,7 +107,7 @@ class SupersetErrorsException(SupersetException):
"""Exceptions with multiple SupersetErrorType associated with them"""
def __init__(
- self, errors: List[SupersetError], status: Optional[int] = None
+ self, errors: list[SupersetError], status: Optional[int] = None
) -> None:
super().__init__(str(errors))
self.errors = errors
@@ -119,7 +119,7 @@ class SupersetSyntaxErrorException(SupersetErrorsException):
status = 422
error_type = SupersetErrorType.SYNTAX_ERROR
- def __init__(self, errors: List[SupersetError]) -> None:
+ def __init__(self, errors: list[SupersetError]) -> None:
super().__init__(errors)
@@ -134,7 +134,7 @@ class SupersetGenericDBErrorException(SupersetErrorFromParamsException):
self,
message: str,
level: ErrorLevel = ErrorLevel.ERROR,
- extra: Optional[Dict[str, Any]] = None,
+ extra: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
@@ -152,7 +152,7 @@ class SupersetTemplateParamsErrorException(SupersetErrorFromParamsException):
message: str,
error: SupersetErrorType,
level: ErrorLevel = ErrorLevel.ERROR,
- extra: Optional[Dict[str, Any]] = None,
+ extra: Optional[dict[str, Any]] = None,
) -> None:
super().__init__(
error,
@@ -166,7 +166,7 @@ class SupersetSecurityException(SupersetErrorException):
status = 403
def __init__(
- self, error: SupersetError, payload: Optional[Dict[str, Any]] = None
+ self, error: SupersetError, payload: Optional[dict[str, Any]] = None
) -> None:
super().__init__(error)
self.payload = payload
diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py
index fb690a9d75..490d198360 100644
--- a/superset/explore/commands/get.py
+++ b/superset/explore/commands/get.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from abc import ABC
-from typing import Any, cast, Dict, Optional
+from typing import Any, cast, Optional
import simplejson as json
from flask import current_app, request
@@ -60,7 +60,7 @@ class GetExploreCommand(BaseCommand, ABC):
self._slice_id = params.slice_id
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
- def run(self) -> Optional[Dict[str, Any]]:
+ def run(self) -> Optional[dict[str, Any]]:
initial_form_data = {}
if self._permalink_key is not None:
@@ -147,7 +147,7 @@ class GetExploreCommand(BaseCommand, ABC):
utils.merge_request_params(form_data, request.args)
# TODO: this is a dummy placeholder - should be refactored to being just `None`
- datasource_data: Dict[str, Any] = {
+ datasource_data: dict[str, Any] = {
"type": self._datasource_type,
"name": datasource_name,
"columns": [],
diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py
index 90e64f6df7..97a8bcbf09 100644
--- a/superset/explore/permalink/commands/create.py
+++ b/superset/explore/permalink/commands/create.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand):
- def __init__(self, state: Dict[str, Any]):
+ def __init__(self, state: dict[str, Any]):
self.chart_id: Optional[int] = state["formData"].get("slice_id")
self.datasource: str = state["formData"]["datasource"]
self.state = state
diff --git a/superset/explore/permalink/types.py b/superset/explore/permalink/types.py
index 393f0ed8d5..7eb4a7cb6b 100644
--- a/superset/explore/permalink/types.py
+++ b/superset/explore/permalink/types.py
@@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional, Tuple, TypedDict
+from typing import Any, Optional, TypedDict
class ExplorePermalinkState(TypedDict, total=False):
- formData: Dict[str, Any]
- urlParams: Optional[List[Tuple[str, str]]]
+ formData: dict[str, Any]
+ urlParams: Optional[list[tuple[str, str]]]
class ExplorePermalinkValue(TypedDict):
diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py
index f633385972..c2e84f700f 100644
--- a/superset/extensions/__init__.py
+++ b/superset/extensions/__init__.py
@@ -16,7 +16,7 @@
# under the License.
import json
import os
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
import celery
from cachelib.base import BaseCache
@@ -58,7 +58,7 @@ class ResultsBackendManager:
class UIManifestProcessor:
def __init__(self, app_dir: str) -> None:
self.app: Optional[Flask] = None
- self.manifest: Dict[str, Dict[str, List[str]]] = {}
+ self.manifest: dict[str, dict[str, list[str]]] = {}
self.manifest_file = f"{app_dir}/static/assets/manifest.json"
def init_app(self, app: Flask) -> None:
@@ -70,10 +70,10 @@ class UIManifestProcessor:
def register_processor(self, app: Flask) -> None:
app.template_context_processors[None].append(self.get_manifest)
- def get_manifest(self) -> Dict[str, Callable[[str], List[str]]]:
+ def get_manifest(self) -> dict[str, Callable[[str], list[str]]]:
loaded_chunks = set()
- def get_files(bundle: str, asset_type: str = "js") -> List[str]:
+ def get_files(bundle: str, asset_type: str = "js") -> list[str]:
files = self.get_manifest_files(bundle, asset_type)
filtered_files = [f for f in files if f not in loaded_chunks]
for f in filtered_files:
@@ -88,7 +88,7 @@ class UIManifestProcessor:
def parse_manifest_json(self) -> None:
try:
- with open(self.manifest_file, "r") as f:
+ with open(self.manifest_file) as f:
# the manifest includes non-entry files we only need entries in
# templates
full_manifest = json.load(f)
@@ -96,7 +96,7 @@ class UIManifestProcessor:
except Exception: # pylint: disable=broad-except
pass
- def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
+ def get_manifest_files(self, bundle: str, asset_type: str) -> list[str]:
if self.app and self.app.debug:
self.parse_manifest_json()
return self.manifest.get(bundle, {}).get(asset_type, [])
@@ -117,7 +117,7 @@ cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = SQLA()
-_event_logger: Dict[str, Any] = {}
+_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
feature_flag_manager = FeatureFlagManager()
diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py
index f69276c908..6e928c0d5f 100644
--- a/superset/extensions/metastore_cache.py
+++ b/superset/extensions/metastore_cache.py
@@ -16,7 +16,7 @@
# under the License.
from datetime import datetime, timedelta
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from uuid import UUID, uuid3
from flask import Flask
@@ -37,7 +37,7 @@ class SupersetMetastoreCache(BaseCache):
@classmethod
def factory(
- cls, app: Flask, config: Dict[str, Any], args: List[Any], kwargs: Dict[str, Any]
+ cls, app: Flask, config: dict[str, Any], args: list[Any], kwargs: dict[str, Any]
) -> BaseCache:
seed = config.get("CACHE_KEY_PREFIX", "")
kwargs["namespace"] = get_uuid_namespace(seed)
diff --git a/superset/forms.py b/superset/forms.py
index c9b29dfcd0..f1e220ba95 100644
--- a/superset/forms.py
+++ b/superset/forms.py
@@ -16,7 +16,7 @@
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
import json
-from typing import Any, List, Optional
+from typing import Any, Optional
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
from wtforms import Field
@@ -24,12 +24,12 @@ from wtforms import Field
class JsonListField(Field):
widget = BS3TextFieldWidget()
- data: List[str] = []
+ data: list[str] = []
def _value(self) -> str:
return json.dumps(self.data)
- def process_formdata(self, valuelist: List[str]) -> None:
+ def process_formdata(self, valuelist: list[str]) -> None:
if valuelist and valuelist[0]:
self.data = json.loads(valuelist[0])
else:
@@ -38,7 +38,7 @@ class JsonListField(Field):
class CommaSeparatedListField(Field):
widget = BS3TextFieldWidget()
- data: List[str] = []
+ data: list[str] = []
def _value(self) -> str:
if self.data:
@@ -46,14 +46,14 @@ class CommaSeparatedListField(Field):
return ""
- def process_formdata(self, valuelist: List[str]) -> None:
+ def process_formdata(self, valuelist: list[str]) -> None:
if valuelist:
self.data = [x.strip() for x in valuelist[0].split(",")]
else:
self.data = []
-def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]:
+def filter_not_empty_values(values: Optional[list[Any]]) -> Optional[list[Any]]:
"""Returns a list of non empty values or None"""
if not values:
return None
diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py
index c489cc323c..bbe25f498b 100644
--- a/superset/initialization/__init__.py
+++ b/superset/initialization/__init__.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import logging
import os
import sys
-from typing import Any, Callable, Dict, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
import wtforms_json
from deprecation import deprecated
@@ -68,7 +68,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.superset_app = app
self.config = app.config
- self.manifest: Dict[Any, Any] = {}
+ self.manifest: dict[Any, Any] = {}
@deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore
@property
@@ -597,7 +597,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.app = app
def __call__(
- self, environ: Dict[str, Any], start_response: Callable[..., Any]
+ self, environ: dict[str, Any], start_response: Callable[..., Any]
) -> Any:
# Setting wsgi.input_terminated tells werkzeug.wsgi to ignore
# content-length and read the stream till the end.
diff --git a/superset/jinja_context.py b/superset/jinja_context.py
index 360c4fc1f4..f096b65cd1 100644
--- a/superset/jinja_context.py
+++ b/superset/jinja_context.py
@@ -18,17 +18,7 @@
import json
import re
from functools import lru_cache, partial
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- List,
- Optional,
- Tuple,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
from flask import current_app, g, has_request_context, request
from flask_babel import gettext as _
@@ -71,14 +61,14 @@ COLLECTION_TYPES = ("list", "dict", "tuple", "set")
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
-def context_addons() -> Dict[str, Any]:
+def context_addons() -> dict[str, Any]:
return current_app.config.get("JINJA_CONTEXT_ADDONS", {})
class Filter(TypedDict):
op: str # pylint: disable=C0103
col: str
- val: Union[None, Any, List[Any]]
+ val: Union[None, Any, list[Any]]
class ExtraCache:
@@ -100,9 +90,9 @@ class ExtraCache:
def __init__(
self,
- extra_cache_keys: Optional[List[Any]] = None,
- applied_filters: Optional[List[str]] = None,
- removed_filters: Optional[List[str]] = None,
+ extra_cache_keys: Optional[list[Any]] = None,
+ applied_filters: Optional[list[str]] = None,
+ removed_filters: Optional[list[str]] = None,
dialect: Optional[Dialect] = None,
):
self.extra_cache_keys = extra_cache_keys
@@ -206,7 +196,7 @@ class ExtraCache:
def filter_values(
self, column: str, default: Optional[str] = None, remove_filter: bool = False
- ) -> List[Any]:
+ ) -> list[Any]:
"""Gets a values for a particular filter as a list
This is useful if:
@@ -230,7 +220,7 @@ class ExtraCache:
only apply to the inner query
:return: returns a list of filter values
"""
- return_val: List[Any] = []
+ return_val: list[Any] = []
filters = self.get_filters(column, remove_filter)
for flt in filters:
val = flt.get("val")
@@ -245,7 +235,7 @@ class ExtraCache:
return return_val
- def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]:
+ def get_filters(self, column: str, remove_filter: bool = False) -> list[Filter]:
"""Get the filters applied to the given column. In addition
to returning values like the filter_values function
the get_filters function returns the operator specified in the explorer UI.
@@ -316,10 +306,10 @@ class ExtraCache:
convert_legacy_filters_into_adhoc(form_data)
merge_extra_filters(form_data)
- filters: List[Filter] = []
+ filters: list[Filter] = []
for flt in form_data.get("adhoc_filters", []):
- val: Union[Any, List[Any]] = flt.get("comparator")
+ val: Union[Any, list[Any]] = flt.get("comparator")
op: str = flt["operator"].upper() if flt.get("operator") else None
# fltOpName: str = flt.get("filterOptionName")
if (
@@ -370,7 +360,7 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
return return_value
-def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]:
+def validate_context_types(context: dict[str, Any]) -> dict[str, Any]:
for key in context:
arg_type = type(context[key]).__name__
if arg_type not in ALLOWED_TYPES and key not in context_addons():
@@ -395,8 +385,8 @@ def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]:
def validate_template_context(
- engine: Optional[str], context: Dict[str, Any]
-) -> Dict[str, Any]:
+ engine: Optional[str], context: dict[str, Any]
+) -> dict[str, Any]:
if engine and engine in context:
# validate engine context separately to allow for engine-specific methods
engine_context = validate_context_types(context.pop(engine))
@@ -407,7 +397,7 @@ def validate_template_context(
return validate_context_types(context)
-def where_in(values: List[Any], mark: str = "'") -> str:
+def where_in(values: list[Any], mark: str = "'") -> str:
"""
Given a list of values, build a parenthesis list suitable for an IN expression.
@@ -439,9 +429,9 @@ class BaseTemplateProcessor:
database: "Database",
query: Optional["Query"] = None,
table: Optional["SqlaTable"] = None,
- extra_cache_keys: Optional[List[Any]] = None,
- removed_filters: Optional[List[str]] = None,
- applied_filters: Optional[List[str]] = None,
+ extra_cache_keys: Optional[list[Any]] = None,
+ removed_filters: Optional[list[str]] = None,
+ applied_filters: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
self._database = database
@@ -454,7 +444,7 @@ class BaseTemplateProcessor:
self._extra_cache_keys = extra_cache_keys
self._applied_filters = applied_filters
self._removed_filters = removed_filters
- self._context: Dict[str, Any] = {}
+ self._context: dict[str, Any] = {}
self._env = SandboxedEnvironment(undefined=DebugUndefined)
self.set_context(**kwargs)
@@ -530,7 +520,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
@staticmethod
def _schema_table(
table_name: str, schema: Optional[str]
- ) -> Tuple[str, Optional[str]]:
+ ) -> tuple[str, Optional[str]]:
if "." in table_name:
schema, table_name = table_name.split(".")
return table_name, schema
@@ -547,7 +537,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
latest_partitions = self.latest_partitions(table_name)
return latest_partitions[0] if latest_partitions else None
- def latest_partitions(self, table_name: str) -> Optional[List[str]]:
+ def latest_partitions(self, table_name: str) -> Optional[list[str]]:
"""
Gets the array of all latest partitions
@@ -603,7 +593,7 @@ DEFAULT_PROCESSORS = {
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
-def get_template_processors() -> Dict[str, Any]:
+def get_template_processors() -> dict[str, Any]:
processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {})
for engine, processor in DEFAULT_PROCESSORS.items():
# do not overwrite engine-specific CUSTOM_TEMPLATE_PROCESSORS
@@ -631,7 +621,7 @@ def get_template_processor(
def dataset_macro(
dataset_id: int,
include_metrics: bool = False,
- columns: Optional[List[str]] = None,
+ columns: Optional[list[str]] = None,
) -> str:
"""
Given a dataset ID, return the SQL that represents it.
diff --git a/superset/key_value/types.py b/superset/key_value/types.py
index fb9c31899f..b2a47336c3 100644
--- a/superset/key_value/types.py
+++ b/superset/key_value/types.py
@@ -21,7 +21,7 @@ import pickle
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Optional, TypedDict
+from typing import Any, TypedDict
from uuid import UUID
from marshmallow import Schema, ValidationError
@@ -34,14 +34,14 @@ from superset.key_value.exceptions import (
@dataclass
class Key:
- id: Optional[int]
- uuid: Optional[UUID]
+ id: int | None
+ uuid: UUID | None
class KeyValueFilter(TypedDict, total=False):
resource: str
- id: Optional[int]
- uuid: Optional[UUID]
+ id: int | None
+ uuid: UUID | None
class KeyValueResource(str, Enum):
diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py
index 2468618a81..6b487c278c 100644
--- a/superset/key_value/utils.py
+++ b/superset/key_value/utils.py
@@ -18,7 +18,7 @@ from __future__ import annotations
from hashlib import md5
from secrets import token_urlsafe
-from typing import Any, Union
+from typing import Any
from uuid import UUID, uuid3
import hashids
@@ -35,7 +35,7 @@ def random_key() -> str:
return token_urlsafe(48)
-def get_filter(resource: KeyValueResource, key: Union[int, UUID]) -> KeyValueFilter:
+def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter:
try:
filter_: KeyValueFilter = {"resource": resource.value}
if isinstance(key, UUID):
diff --git a/superset/legacy.py b/superset/legacy.py
index 168b9c0b60..03c1eff7dd 100644
--- a/superset/legacy.py
+++ b/superset/legacy.py
@@ -15,10 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""Code related with dealing with legacy / change management"""
-from typing import Any, Dict
+from typing import Any
-def update_time_range(form_data: Dict[str, Any]) -> None:
+def update_time_range(form_data: dict[str, Any]) -> None:
"""Move since and until to time_range."""
if "since" in form_data or "until" in form_data:
form_data["time_range"] = "{} : {}".format(
diff --git a/superset/migrations/env.py b/superset/migrations/env.py
index e3779bb65b..130fb367fb 100755
--- a/superset/migrations/env.py
+++ b/superset/migrations/env.py
@@ -17,7 +17,6 @@
import logging
import urllib.parse
from logging.config import fileConfig
-from typing import List
from alembic import context
from alembic.operations.ops import MigrationScript
@@ -85,7 +84,7 @@ def run_migrations_online() -> None:
# when there are no changes to the schema
# reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html
def process_revision_directives( # pylint: disable=redefined-outer-name, unused-argument
- context: MigrationContext, revision: str, directives: List[MigrationScript]
+ context: MigrationContext, revision: str, directives: list[MigrationScript]
) -> None:
if getattr(config.cmd_opts, "autogenerate", False):
script = directives[0]
diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py
index 5ea23551ea..d3b2efa7a0 100644
--- a/superset/migrations/shared/migrate_viz/base.py
+++ b/superset/migrations/shared/migrate_viz/base.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import copy
import json
-from typing import Any, Dict, Set
+from typing import Any
from alembic import op
from sqlalchemy import and_, Column, Integer, String, Text
@@ -44,8 +44,8 @@ FORM_DATA_BAK_FIELD_NAME = "form_data_bak"
class MigrateViz:
- remove_keys: Set[str] = set()
- rename_keys: Dict[str, str] = {}
+ remove_keys: set[str] = set()
+ rename_keys: dict[str, str] = {}
source_viz_type: str
target_viz_type: str
has_x_axis_control: bool = False
@@ -85,7 +85,7 @@ class MigrateViz:
def _post_action(self) -> None:
"""Some actions after migrate"""
- def _migrate_temporal_filter(self, rv_data: Dict[str, Any]) -> None:
+ def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
"""Adds a temporal filter."""
granularity_sqla = rv_data.pop("granularity_sqla", None)
time_range = rv_data.pop("time_range", None) or conf.get("DEFAULT_TIME_FILTER")
diff --git a/superset/migrations/shared/security_converge.py b/superset/migrations/shared/security_converge.py
index 19caa3932b..9b1730a2a1 100644
--- a/superset/migrations/shared/security_converge.py
+++ b/superset/migrations/shared/security_converge.py
@@ -16,7 +16,6 @@
# under the License.
import logging
from dataclasses import dataclass
-from typing import Dict, List, Tuple
from sqlalchemy import (
Column,
@@ -41,7 +40,7 @@ class Pvm:
permission: str
-PvmMigrationMapType = Dict[Pvm, Tuple[Pvm, ...]]
+PvmMigrationMapType = dict[Pvm, tuple[Pvm, ...]]
# Partial freeze of the current metadata db schema
@@ -162,8 +161,8 @@ def _find_pvm(session: Session, view_name: str, permission_name: str) -> Permiss
def add_pvms(
- session: Session, pvm_data: Dict[str, Tuple[str, ...]], commit: bool = False
-) -> List[PermissionView]:
+ session: Session, pvm_data: dict[str, tuple[str, ...]], commit: bool = False
+) -> list[PermissionView]:
"""
Checks if exists and adds new Permissions, Views and PermissionView's
"""
@@ -181,7 +180,7 @@ def add_pvms(
def _delete_old_permissions(
- session: Session, pvm_map: Dict[PermissionView, List[PermissionView]]
+ session: Session, pvm_map: dict[PermissionView, list[PermissionView]]
) -> None:
"""
Delete old permissions:
@@ -222,7 +221,7 @@ def migrate_roles(
Migrates all existing roles that have the permissions to be migrated
"""
# Collect a map of PermissionView objects for migration
- pvm_map: Dict[PermissionView, List[PermissionView]] = {}
+ pvm_map: dict[PermissionView, list[PermissionView]] = {}
for old_pvm_key, new_pvms_ in pvm_key_map.items():
old_pvm = _find_pvm(session, old_pvm_key.view, old_pvm_key.permission)
if old_pvm:
@@ -252,8 +251,8 @@ def migrate_roles(
session.commit()
-def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> Dict[str, Tuple[str, ...]]:
- reversed_pvms: Dict[str, Tuple[str, ...]] = {}
+def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> dict[str, tuple[str, ...]]:
+ reversed_pvms: dict[str, tuple[str, ...]] = {}
for old_pvm, new_pvms in pvm_map.items():
if old_pvm.view not in reversed_pvms:
reversed_pvms[old_pvm.view] = (old_pvm.permission,)
diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py
index e05b1d357f..32e7dc1a39 100644
--- a/superset/migrations/shared/utils.py
+++ b/superset/migrations/shared/utils.py
@@ -18,7 +18,8 @@ import json
import logging
import os
import time
-from typing import Any, Callable, Dict, Iterator, Optional, Union
+from collections.abc import Iterator
+from typing import Any, Callable, Optional, Union
from uuid import uuid4
from alembic import op
@@ -127,7 +128,7 @@ def paginated_update(
print_page_progress(processed, total)
-def try_load_json(data: Optional[str]) -> Dict[str, Any]:
+def try_load_json(data: Optional[str]) -> dict[str, Any]:
try:
return data and json.loads(data) or {}
except json.decoder.JSONDecodeError:
diff --git a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py
index 56d5f887b3..1f3dbab636 100644
--- a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py
+++ b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py
@@ -59,7 +59,7 @@ def upgrade():
slc.params = json.dumps(d, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
- print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name))
+ print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
print(slc.slice_name + " error: " + str(ex))
diff --git a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py
index 04a39a31f5..8e97ada3cd 100644
--- a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py
+++ b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py
@@ -82,7 +82,7 @@ def upgrade():
url.url = newurl
session.merge(url)
session.commit()
- print("Updating url ({}/{})".format(i, urls_len))
+ print(f"Updating url ({i}/{urls_len})")
session.close()
diff --git a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py
index 26cfb93b99..f6d5610d97 100644
--- a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py
+++ b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py
@@ -69,7 +69,7 @@ def upgrade():
batch_op.add_column(sa.Column("datasource_id", sa.Integer))
batch_op.create_foreign_key(
- "fk_{}_datasource_id_datasources".format(foreign),
+ f"fk_{foreign}_datasource_id_datasources",
"datasources",
["datasource_id"],
["id"],
@@ -102,7 +102,7 @@ def upgrade():
for name in names:
batch_op.drop_constraint(
- name or "fk_{}_datasource_name_datasources".format(foreign),
+ name or f"fk_{foreign}_datasource_name_datasources",
type_="foreignkey",
)
@@ -148,7 +148,7 @@ def downgrade():
batch_op.add_column(sa.Column("datasource_name", sa.String(255)))
batch_op.create_foreign_key(
- "fk_{}_datasource_name_datasources".format(foreign),
+ f"fk_{foreign}_datasource_name_datasources",
"datasources",
["datasource_name"],
["datasource_name"],
@@ -174,7 +174,7 @@ def downgrade():
with op.batch_alter_table(foreign, naming_convention=conv) as batch_op:
# Drop the datasource_id column and associated constraint.
batch_op.drop_constraint(
- "fk_{}_datasource_id_datasources".format(foreign), type_="foreignkey"
+ f"fk_{foreign}_datasource_id_datasources", type_="foreignkey"
)
batch_op.drop_column("datasource_id")
@@ -201,7 +201,7 @@ def downgrade():
# Re-create the foreign key associated with the cluster_name column.
batch_op.create_foreign_key(
- "fk_{}_datasource_id_datasources".format(foreign),
+ f"fk_{foreign}_datasource_id_datasources",
"clusters",
["cluster_name"],
["cluster_name"],
diff --git a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py
index 5593af0eb6..4b1b807a6f 100644
--- a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py
+++ b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py
@@ -59,7 +59,7 @@ def upgrade():
{
"annotationType": "INTERVAL",
"style": "solid",
- "name": "Layer {}".format(layer),
+ "name": f"Layer {layer}",
"show": True,
"overrides": {"since": None, "until": None},
"value": layer,
diff --git a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py
index 286be8a5fc..bf6276d702 100644
--- a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py
+++ b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py
@@ -51,7 +51,7 @@ def upgrade():
dashboards = session.query(Dashboard).all()
for i, dashboard in enumerate(dashboards):
- print("Upgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id))
+ print(f"Upgrading ({i}/{len(dashboards)}): {dashboard.id}")
positions = json.loads(dashboard.position_json or "{}")
for pos in positions:
if pos.get("v", 0) == 0:
@@ -74,7 +74,7 @@ def downgrade():
dashboards = session.query(Dashboard).all()
for i, dashboard in enumerate(dashboards):
- print("Downgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id))
+ print(f"Downgrading ({i}/{len(dashboards)}): {dashboard.id}")
positions = json.loads(dashboard.position_json or "{}")
for pos in positions:
if pos.get("v", 0) == 1:
diff --git a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py
index dbe3f0ace4..c73399fb92 100644
--- a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py
+++ b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py
@@ -49,7 +49,7 @@ def upgrade():
for table, column in names.items():
with op.batch_alter_table(table, naming_convention=conv) as batch_op:
batch_op.create_unique_constraint(
- "uq_{}_{}".format(table, column), [column, "datasource_id"]
+ f"uq_{table}_{column}", [column, "datasource_id"]
)
@@ -71,6 +71,6 @@ def downgrade():
with op.batch_alter_table(table, naming_convention=conv) as batch_op:
batch_op.drop_constraint(
generic_find_uq_constraint_name(table, {column, "datasource_id"}, insp)
- or "uq_{}_{}".format(table, column),
+ or f"uq_{table}_{column}",
type_="unique",
)
diff --git a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py
index 3e2b81c17a..49b19b9c69 100644
--- a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py
+++ b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py
@@ -61,7 +61,7 @@ def upgrade():
slc.params = json.dumps(params, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
- print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name))
+ print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
print(slc.slice_name + " error: " + str(ex))
diff --git a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py
index ec03328271..6292e2860a 100644
--- a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py
+++ b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py
@@ -28,8 +28,6 @@ down_revision = "80a67c5192fa"
import json
-import uuid
-from collections import defaultdict
from alembic import op
from sqlalchemy import Column, Integer, Text
diff --git a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py
index 2e491e9303..a2dd50bf9c 100644
--- a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py
+++ b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py
@@ -77,24 +77,24 @@ def isodate_duration_to_string(obj):
if obj.tdelta:
if not obj.months and not obj.years:
return format_seconds(obj.tdelta.total_seconds())
- raise Exception("Unable to convert: {0}".format(obj))
+ raise Exception(f"Unable to convert: {obj}")
if obj.months % 12 != 0:
months = obj.months + 12 * obj.years
- return "{0} months".format(months)
+ return f"{months} months"
- return "{0} years".format(obj.years + obj.months // 12)
+ return f"{obj.years + obj.months // 12} years"
def timedelta_to_string(obj):
if obj.microseconds:
- raise Exception("Unable to convert: {0}".format(obj))
+ raise Exception(f"Unable to convert: {obj}")
elif obj.seconds:
return format_seconds(obj.total_seconds())
elif obj.days % 7 == 0:
- return "{0} weeks".format(obj.days // 7)
+ return f"{obj.days // 7} weeks"
else:
- return "{0} days".format(obj.days)
+ return f"{obj.days} days"
def format_seconds(value):
@@ -106,7 +106,7 @@ def format_seconds(value):
else:
period = "second"
- return "{0} {1}{2}".format(value, period, "s" if value > 1 else "")
+ return "{} {}{}".format(value, period, "s" if value > 1 else "")
def compute_time_compare(granularity, periods):
@@ -120,11 +120,11 @@ def compute_time_compare(granularity, periods):
obj = isodate.parse_duration(granularity) * periods
except isodate.isoerror.ISO8601Error:
# if parse_human_timedelta can parse it, return it directly
- delta = "{0} {1}{2}".format(periods, granularity, "s" if periods > 1 else "")
+ delta = "{} {}{}".format(periods, granularity, "s" if periods > 1 else "")
obj = parse_human_timedelta(delta)
if obj:
return delta
- raise Exception("Unable to parse: {0}".format(granularity))
+ raise Exception(f"Unable to parse: {granularity}")
if isinstance(obj, isodate.duration.Duration):
return isodate_duration_to_string(obj)
diff --git a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py
index 13c4e61718..620e2c5008 100644
--- a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py
+++ b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py
@@ -173,7 +173,7 @@ def get_header_component(title):
def get_row_container():
return {
"type": ROW_TYPE,
- "id": "DASHBOARD_ROW_TYPE-{}".format(generate_id()),
+ "id": f"DASHBOARD_ROW_TYPE-{generate_id()}",
"children": [],
"meta": {"background": BACKGROUND_TRANSPARENT},
}
@@ -182,7 +182,7 @@ def get_row_container():
def get_col_container():
return {
"type": COLUMN_TYPE,
- "id": "DASHBOARD_COLUMN_TYPE-{}".format(generate_id()),
+ "id": f"DASHBOARD_COLUMN_TYPE-{generate_id()}",
"children": [],
"meta": {"background": BACKGROUND_TRANSPARENT},
}
@@ -203,18 +203,18 @@ def get_chart_holder(position):
if len(code):
markdown_content = code
elif slice_name.strip():
- markdown_content = "##### {}".format(slice_name)
+ markdown_content = f"##### {slice_name}"
return {
"type": MARKDOWN_TYPE,
- "id": "DASHBOARD_MARKDOWN_TYPE-{}".format(generate_id()),
+ "id": f"DASHBOARD_MARKDOWN_TYPE-{generate_id()}",
"children": [],
"meta": {"width": width, "height": height, "code": markdown_content},
}
return {
"type": CHART_TYPE,
- "id": "DASHBOARD_CHART_TYPE-{}".format(generate_id()),
+ "id": f"DASHBOARD_CHART_TYPE-{generate_id()}",
"children": [],
"meta": {"width": width, "height": height, "chartId": int(slice_id)},
}
@@ -584,10 +584,10 @@ def upgrade():
dashboards = session.query(Dashboard).all()
for i, dashboard in enumerate(dashboards):
- print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards)))
+ print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>")
position_json = json.loads(dashboard.position_json or "[]")
if not is_v2_dash(position_json):
- print("Converting dashboard... dash_id: {}".format(dashboard.id))
+ print(f"Converting dashboard... dash_id: {dashboard.id}")
position_dict = {}
positions = []
slices = dashboard.slices
@@ -650,7 +650,7 @@ def upgrade():
session.merge(dashboard)
session.commit()
else:
- print("Skip converted dash_id: {}".format(dashboard.id))
+ print(f"Skip converted dash_id: {dashboard.id}")
session.close()
diff --git a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py
index bfb7a66161..3c6979f961 100644
--- a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py
+++ b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py
@@ -51,7 +51,7 @@ def upgrade():
dashboards = session.query(Dashboard).all()
for i, dashboard in enumerate(dashboards):
- print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards)))
+ print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>")
if dashboard.json_metadata:
json_metadata = json.loads(dashboard.json_metadata)
has_update = False
@@ -74,7 +74,7 @@ def upgrade():
# if user already defined __time_range,
# just abandon __from and __to
if "__time_range" not in val:
- val["__time_range"] = "{} : {}".format(__from, __to)
+ val["__time_range"] = f"{__from} : {__to}"
json_metadata["default_filters"] = json.dumps(filters)
has_update = True
except Exception:
diff --git a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py
index 1d0690c5e0..073bfdc474 100644
--- a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py
+++ b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py
@@ -29,8 +29,6 @@ down_revision = "c2acd2cf3df2"
import copy
import json
import logging
-import uuid
-from collections import defaultdict
from alembic import op
from sqlalchemy import Column, Integer, Text
diff --git a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py
index 404ea96e44..3b7c3951cd 100644
--- a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py
+++ b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py
@@ -26,14 +26,13 @@ Create Date: 2020-03-25 10:49:10.883065
revision = "b5998378c225"
down_revision = "72428d1ea401"
-from typing import Dict
import sqlalchemy as sa
from alembic import op
def upgrade():
- kwargs: Dict[str, str] = {}
+ kwargs: dict[str, str] = {}
bind = op.get_bind()
op.add_column(
"dbs",
diff --git a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py
index 6b63c468ec..4202de4560 100644
--- a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py
+++ b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py
@@ -21,7 +21,6 @@ Revises: f2672aa8350a
Create Date: 2020-08-12 00:24:39.617899
"""
-import collections
import json
import logging
import uuid
@@ -77,7 +76,7 @@ class Dashboard(Base):
def create_new_markdown_component(chart_position, url):
return {
"type": "MARKDOWN",
- "id": "MARKDOWN-{}".format(uuid.uuid4().hex[:8]),
+ "id": f"MARKDOWN-{uuid.uuid4().hex[:8]}",
"children": [],
"parents": chart_position["parents"],
"meta": {
diff --git a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py
index a6db4c2cb6..45f091c38e 100644
--- a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py
+++ b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py
@@ -167,7 +167,7 @@ def upgrade():
orphaned_faulty_view_menus = []
for faulty_view_menu in faulty_view_menus:
# Get the dataset id from the view_menu name
- match_ds_id = re.match("\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name)
+ match_ds_id = re.match(r"\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name)
if match_ds_id:
dataset_id = int(match_ds_id.group(1))
dataset = session.query(SqlaTable).get(dataset_id)
diff --git a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py
index 79b032894f..64396b6abe 100644
--- a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py
+++ b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py
@@ -27,7 +27,8 @@ revision = "fc3a3a8ff221"
down_revision = "085f06488938"
import json
-from typing import Any, Dict, Iterable
+from collections.abc import Iterable
+from typing import Any
from alembic import op
from sqlalchemy import Column, Integer, Text
@@ -77,7 +78,7 @@ EXTRA_FORM_DATA_OVERRIDE_KEYS = (
)
-def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None:
+def upgrade_select_filters(native_filters: Iterable[dict[str, Any]]) -> None:
"""
Add `defaultToFirstItem` to `controlValues` of `select_filter` components
"""
@@ -89,7 +90,7 @@ def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None:
control_values["defaultToFirstItem"] = value
-def upgrade_filter_set(filter_set: Dict[str, Any]) -> int:
+def upgrade_filter_set(filter_set: dict[str, Any]) -> int:
changed_filters = 0
upgrade_select_filters(filter_set.get("nativeFilters", {}).values())
data_mask = filter_set.get("dataMask", {})
@@ -124,7 +125,7 @@ def upgrade_filter_set(filter_set: Dict[str, Any]) -> int:
return changed_filters
-def downgrade_filter_set(filter_set: Dict[str, Any]) -> int:
+def downgrade_filter_set(filter_set: dict[str, Any]) -> int:
changed_filters = 0
old_data_mask = filter_set.pop("dataMask", {})
native_filters = {}
diff --git a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py
index ec8f8e1cc0..42368ce896 100644
--- a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py
+++ b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py
@@ -27,7 +27,8 @@ revision = "f1410ed7ec95"
down_revision = "d416d0d715cc"
import json
-from typing import Any, Dict, Iterable, Tuple
+from collections.abc import Iterable
+from typing import Any
from alembic import op
from sqlalchemy import Column, Integer, Text
@@ -46,7 +47,7 @@ class Dashboard(Base):
json_metadata = Column(Text)
-def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
+def upgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int:
"""
Move `defaultValue` into `defaultDataMask.filterState`
"""
@@ -61,7 +62,7 @@ def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
return changed_filters
-def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
+def downgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int:
"""
Move `defaultDataMask.filterState` into `defaultValue`
"""
@@ -76,7 +77,7 @@ def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int:
return changed_filters
-def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
+def upgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]:
changed_filters, changed_filter_sets = 0, 0
# upgrade native select filter metadata
# upgrade native select filter metadata
@@ -119,7 +120,7 @@ def upgrade():
print(f"Upgraded {changed_filters} filters and {changed_filter_sets} filter sets.")
-def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]:
+def downgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]:
changed_filters, changed_filter_sets = 0, 0
# upgrade native select filter metadata
if native_filters := dashboard.get("native_filter_configuration"):
diff --git a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py
index 888925a888..8be11d3cf6 100644
--- a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py
+++ b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py
@@ -27,7 +27,6 @@ revision = "143b6f2815da"
down_revision = "e323605f370a"
import json
-from typing import Any, Dict, List, Tuple
from alembic import op
from sqlalchemy import and_, Column, Integer, String, Text
diff --git a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py
index e44bdae782..ab852c324b 100644
--- a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py
+++ b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py
@@ -27,7 +27,6 @@ revision = "60dc453f4e2e"
down_revision = "3ebe0993c770"
import json
-import re
from alembic import op
from sqlalchemy import and_, Column, Integer, String, Text
diff --git a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py
index db1b87e546..b85e9397e6 100644
--- a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py
+++ b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py
@@ -27,7 +27,6 @@ revision = "32646df09c64"
down_revision = "60dc453f4e2e"
import json
-from typing import Dict
from alembic import op
from sqlalchemy import Column, Integer, Text
@@ -45,7 +44,7 @@ class Slice(Base):
params = Column(Text)
-def migrate(mapping: Dict[str, str]) -> None:
+def migrate(mapping: dict[str, str]) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
diff --git a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py
index 286a0731fc..b51f6c78ac 100644
--- a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py
+++ b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py
@@ -29,7 +29,7 @@ down_revision = "ad07e4fdbaba"
import json
import os
from datetime import datetime
-from typing import List, Optional, Set, Type, Union
+from typing import Optional, Union
from uuid import uuid4
import sqlalchemy as sa
@@ -86,7 +86,7 @@ class AuxiliaryColumnsMixin(UUIDMixin):
def insert_from_select(
- target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select
+ target: Union[str, sa.Table, type[Base]], source: sa.sql.expression.Select
) -> None:
"""
Execute INSERT FROM SELECT to copy data from a SELECT query to the target table.
@@ -274,8 +274,8 @@ def find_tables(
session: Session,
database_id: int,
default_schema: Optional[str],
- tables: Set[Table],
-) -> List[int]:
+ tables: set[Table],
+) -> list[int]:
"""
Look for NewTable's of from a specific database
"""
diff --git a/superset/models/annotations.py b/superset/models/annotations.py
index 3185460bf5..54de94e7f6 100644
--- a/superset/models/annotations.py
+++ b/superset/models/annotations.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""a collection of Annotation-related models"""
-from typing import Any, Dict
+from typing import Any
from flask_appbuilder import Model
from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text
@@ -54,7 +54,7 @@ class Annotation(Model, AuditMixinNullable):
__table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),)
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
return {
"layer_id": self.layer_id,
"start_dttm": self.start_dttm,
diff --git a/superset/models/core.py b/superset/models/core.py
index ee50f06345..3c2b12d378 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=line-too-long
"""A collection of ORM sqlalchemy models for Superset"""
+import builtins
import enum
import json
import logging
@@ -25,7 +26,7 @@ from contextlib import closing, contextmanager, nullcontext
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
import numpy
import pandas as pd
@@ -194,7 +195,7 @@ class Database(
return self.db_engine_spec.allows_subqueries
@property
- def function_names(self) -> List[str]:
+ def function_names(self) -> list[str]:
try:
return self.db_engine_spec.get_function_names(self)
except Exception as ex: # pylint: disable=broad-except
@@ -234,7 +235,7 @@ class Database(
return True
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.database_name,
@@ -271,7 +272,7 @@ class Database(
return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra)
@property
- def parameters(self) -> Dict[str, Any]:
+ def parameters(self) -> dict[str, Any]:
# Database parameters are a dictionary of values that are used to make up
# the sqlalchemy_uri
# When returning the parameters we should use the masked SQLAlchemy URI and the
@@ -296,7 +297,7 @@ class Database(
return parameters
@property
- def parameters_schema(self) -> Dict[str, Any]:
+ def parameters_schema(self) -> dict[str, Any]:
try:
parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore
except Exception: # pylint: disable=broad-except
@@ -304,7 +305,7 @@ class Database(
return parameters_schema
@property
- def metadata_cache_timeout(self) -> Dict[str, Any]:
+ def metadata_cache_timeout(self) -> dict[str, Any]:
return self.get_extra().get("metadata_cache_timeout", {})
@property
@@ -324,15 +325,15 @@ class Database(
return self.metadata_cache_timeout.get("table_cache_timeout")
@property
- def default_schemas(self) -> List[str]:
+ def default_schemas(self) -> list[str]:
return self.get_extra().get("default_schemas", [])
@property
- def connect_args(self) -> Dict[str, Any]:
+ def connect_args(self) -> dict[str, Any]:
return self.get_extra().get("engine_params", {}).get("connect_args", {})
@property
- def engine_information(self) -> Dict[str, Any]:
+ def engine_information(self) -> dict[str, Any]:
try:
engine_information = self.db_engine_spec.get_public_information()
except Exception: # pylint: disable=broad-except
@@ -540,7 +541,7 @@ class Database(
"""Add quotes to potential identifiter expressions if needed"""
return self.get_dialect().identifier_preparer.quote
- def get_reserved_words(self) -> Set[str]:
+ def get_reserved_words(self) -> set[str]:
return self.get_dialect().preparer.reserved_words
def get_df( # pylint: disable=too-many-locals
@@ -629,7 +630,7 @@ class Database(
show_cols: bool = False,
indent: bool = True,
latest_partition: bool = False,
- cols: Optional[List[Dict[str, Any]]] = None,
+ cols: Optional[list[dict[str, Any]]] = None,
) -> str:
"""Generates a ``select *`` statement in the proper dialect"""
eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
@@ -670,7 +671,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
- ) -> Set[Tuple[str, str]]:
+ ) -> set[tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@@ -706,7 +707,7 @@ class Database(
cache: bool = False,
cache_timeout: Optional[int] = None,
force: bool = False,
- ) -> Set[Tuple[str, str]]:
+ ) -> set[tuple[str, str]]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@@ -750,7 +751,7 @@ class Database(
cache_timeout: Optional[int] = None,
force: bool = False,
ssh_tunnel: Optional["SSHTunnel"] = None,
- ) -> List[str]:
+ ) -> list[str]:
"""Parameters need to be passed as keyword arguments.
For unused parameters, they are referenced in
@@ -768,13 +769,15 @@ class Database(
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@property
- def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
+ def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]:
url = make_url_safe(self.sqlalchemy_uri_decrypted)
return self.get_db_engine_spec(url)
@classmethod
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
- def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]:
+ def get_db_engine_spec(
+ cls, url: URL
+ ) -> builtins.type[db_engine_specs.BaseEngineSpec]:
backend = url.get_backend_name()
try:
driver = url.get_driver_name()
@@ -784,7 +787,7 @@ class Database(
return db_engine_specs.get_engine_spec(backend, driver)
- def grains(self) -> Tuple[TimeGrain, ...]:
+ def grains(self) -> tuple[TimeGrain, ...]:
"""Defines time granularity database-specific expressions.
The idea here is to make it easy for users to change the time grain
@@ -795,10 +798,10 @@ class Database(
"""
return self.db_engine_spec.get_time_grains()
- def get_extra(self) -> Dict[str, Any]:
+ def get_extra(self) -> dict[str, Any]:
return self.db_engine_spec.get_extra_params(self)
- def get_encrypted_extra(self) -> Dict[str, Any]:
+ def get_encrypted_extra(self) -> dict[str, Any]:
encrypted_extra = {}
if self.encrypted_extra:
try:
@@ -809,7 +812,7 @@ class Database(
return encrypted_extra
# pylint: disable=invalid-name
- def update_params_from_encrypted_extra(self, params: Dict[str, Any]) -> None:
+ def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None:
self.db_engine_spec.update_params_from_encrypted_extra(self, params)
def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
@@ -832,7 +835,7 @@ class Database(
def get_columns(
self, table_name: str, schema: Optional[str] = None
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_columns(inspector, table_name, schema)
@@ -840,19 +843,19 @@ class Database(
self,
table_name: str,
schema: Optional[str] = None,
- ) -> List[MetricType]:
+ ) -> list[MetricType]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_metrics(self, inspector, table_name, schema)
def get_indexes(
self, table_name: str, schema: Optional[str] = None
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
return self.db_engine_spec.get_indexes(self, inspector, table_name, schema)
def get_pk_constraint(
self, table_name: str, schema: Optional[str] = None
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
with self.get_inspector_with_context() as inspector:
pk_constraint = inspector.get_pk_constraint(table_name, schema) or {}
@@ -866,13 +869,13 @@ class Database(
def get_foreign_keys(
self, table_name: str, schema: Optional[str] = None
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
with self.get_inspector_with_context() as inspector:
return inspector.get_foreign_keys(table_name, schema)
def get_schema_access_for_file_upload( # pylint: disable=invalid-name
self,
- ) -> List[str]:
+ ) -> list[str]:
allowed_databases = self.get_extra().get("schemas_allowed_for_file_upload", [])
if isinstance(allowed_databases, str):
@@ -932,7 +935,7 @@ class Database(
view_name: str,
schema: Optional[str] = None,
) -> bool:
- view_names: List[str] = []
+ view_names: list[str] = []
try:
view_names = dialect.get_view_names(connection=conn, schema=schema)
except Exception: # pylint: disable=broad-except
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 9afd74f5e3..f3b9c08794 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -21,7 +21,7 @@ import logging
import uuid
from collections import defaultdict
from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable
import sqlalchemy as sqla
from flask import current_app
@@ -68,9 +68,7 @@ config = app.config
logger = logging.getLogger(__name__)
-def copy_dashboard(
- _mapper: Mapper, connection: Connection, target: "Dashboard"
-) -> None:
+def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -> None:
dashboard_id = config["DASHBOARD_TEMPLATE_ID"]
if dashboard_id is None:
return
@@ -146,7 +144,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
certification_details = Column(Text)
json_metadata = Column(Text)
slug = Column(String(255), unique=True)
- slices: List[Slice] = relationship(
+ slices: list[Slice] = relationship(
Slice, secondary=dashboard_slices, backref="dashboards"
)
owners = relationship(security_manager.user_model, secondary=dashboard_user)
@@ -187,14 +185,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
return f"/superset/dashboard/{self.slug or self.id}/"
@staticmethod
- def get_url(id_: int, slug: Optional[str] = None) -> str:
+ def get_url(id_: int, slug: str | None = None) -> str:
# To be able to generate URL's without instanciating a Dashboard object
return f"/superset/dashboard/{slug or id_}/"
@property
- def datasources(self) -> Set[BaseDatasource]:
+ def datasources(self) -> set[BaseDatasource]:
# Verbose but efficient database enumeration of dashboard datasources.
- datasources_by_cls_model: Dict[Type["BaseDatasource"], Set[int]] = defaultdict(
+ datasources_by_cls_model: dict[type[BaseDatasource], set[int]] = defaultdict(
set
)
@@ -210,14 +208,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
}
@property
- def filter_sets(self) -> Dict[int, FilterSet]:
+ def filter_sets(self) -> dict[int, FilterSet]:
return {fs.id: fs for fs in self._filter_sets}
@property
- def filter_sets_lst(self) -> Dict[int, FilterSet]:
+ def filter_sets_lst(self) -> dict[int, FilterSet]:
if security_manager.is_admin():
return self._filter_sets
- filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []}
+ filter_sets_by_owner_type: dict[str, list[Any]] = {"Dashboard": [], "User": []}
for fs in self._filter_sets:
filter_sets_by_owner_type[fs.owner_type].append(fs)
user_filter_sets = list(
@@ -232,7 +230,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
}
@property
- def charts(self) -> List[str]:
+ def charts(self) -> list[str]:
return [slc.chart for slc in self.slices]
@property
@@ -281,7 +279,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
return f"/superset/profile/{self.changed_by.username}"
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
positions = self.position_json
if positions:
positions = json.loads(positions)
@@ -305,16 +303,16 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
make_name=lambda fname: f"{fname}-v1.0",
unless=lambda: not is_feature_enabled("DASHBOARD_CACHE"),
)
- def datasets_trimmed_for_slices(self) -> List[Dict[str, Any]]:
+ def datasets_trimmed_for_slices(self) -> list[dict[str, Any]]:
# Verbose but efficient database enumeration of dashboard datasources.
- slices_by_datasource: Dict[
- Tuple[Type["BaseDatasource"], int], Set[Slice]
+ slices_by_datasource: dict[
+ tuple[type[BaseDatasource], int], set[Slice]
] = defaultdict(set)
for slc in self.slices:
slices_by_datasource[(slc.cls_model, slc.datasource_id)].add(slc)
- result: List[Dict[str, Any]] = []
+ result: list[dict[str, Any]] = []
for (cls_model, datasource_id), slices in slices_by_datasource.items():
datasource = (
@@ -336,7 +334,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
self.json_metadata = value
@property
- def position(self) -> Dict[str, Any]:
+ def position(self) -> dict[str, Any]:
if self.position_json:
return json.loads(self.position_json)
return {}
@@ -380,7 +378,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
@classmethod
def export_dashboards( # pylint: disable=too-many-locals
- cls, dashboard_ids: List[int]
+ cls, dashboard_ids: list[int]
) -> str:
copied_dashboards = []
datasource_ids = set()
@@ -413,7 +411,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
slices.append(copied_slc)
json_metadata = json.loads(dashboard.json_metadata)
- native_filter_configuration: List[Dict[str, Any]] = json_metadata.get(
+ native_filter_configuration: list[dict[str, Any]] = json_metadata.get(
"native_filter_configuration", []
)
for native_filter in native_filter_configuration:
@@ -449,12 +447,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin):
)
@classmethod
- def get(cls, id_or_slug: Union[str, int]) -> Dashboard:
+ def get(cls, id_or_slug: str | int) -> Dashboard:
qry = db.session.query(Dashboard).filter(id_or_slug_filter(id_or_slug))
return qry.one_or_none()
-def is_uuid(value: Union[str, int]) -> bool:
+def is_uuid(value: str | int) -> bool:
try:
uuid.UUID(str(value))
return True
@@ -462,7 +460,7 @@ def is_uuid(value: Union[str, int]) -> bool:
return False
-def is_int(value: Union[str, int]) -> bool:
+def is_int(value: str | int) -> bool:
try:
int(value)
return True
@@ -470,7 +468,7 @@ def is_int(value: Union[str, int]) -> bool:
return False
-def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression:
+def id_or_slug_filter(id_or_slug: int | str) -> BinaryExpression:
if is_int(id_or_slug):
return Dashboard.id == int(id_or_slug)
if is_uuid(id_or_slug):
@@ -490,7 +488,7 @@ if is_feature_enabled("DASHBOARD_CACHE"):
def clear_dashboard_cache(
_mapper: Mapper,
_connection: Connection,
- obj: Union[Slice, BaseDatasource, Dashboard],
+ obj: Slice | BaseDatasource | Dashboard,
check_modified: bool = True,
) -> None:
if check_modified and not object_session(obj).is_modified(obj):
diff --git a/superset/models/datasource_access_request.py b/superset/models/datasource_access_request.py
index 1f286f96d8..23df4cffae 100644
--- a/superset/models/datasource_access_request.py
+++ b/superset/models/datasource_access_request.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Type, TYPE_CHECKING
+from typing import Optional, TYPE_CHECKING
from flask import Markup
from flask_appbuilder import Model
@@ -41,7 +41,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
ROLES_DENYLIST = set(config["ROBOT_PERMISSION_ROLES"])
@property
- def cls_model(self) -> Type["BaseDatasource"]:
+ def cls_model(self) -> type["BaseDatasource"]:
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO
@@ -77,7 +77,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
f"datasource_id={self.datasource_id}&"
f"created_by={self.created_by.username}&role_to_grant={role.name}"
)
- link = 'Grant {} Role'.format(href, role.name)
+ link = f'Grant {role.name} Role'
action_list = action_list + "" + link + ""
return ""
@@ -90,8 +90,8 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):
f"datasource_id={self.datasource_id}&"
f"created_by={self.created_by.username}&role_to_extend={role.name}"
)
- link = 'Extend {} Role'.format(href, role.name)
+ link = f'Extend {role.name} Role'
if role.name in self.ROLES_DENYLIST:
- link = "{} Role".format(role.name)
+ link = f"{role.name} Role"
action_list = action_list + "" + link + ""
return ""
diff --git a/superset/models/embedded_dashboard.py b/superset/models/embedded_dashboard.py
index 7718bc886f..32a8e4abce 100644
--- a/superset/models/embedded_dashboard.py
+++ b/superset/models/embedded_dashboard.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import uuid
-from typing import List
from flask_appbuilder import Model
from sqlalchemy import Column, ForeignKey, Integer, Text
@@ -49,7 +48,7 @@ class EmbeddedDashboard(Model, AuditMixinNullable):
)
@property
- def allowed_domains(self) -> List[str]:
+ def allowed_domains(self) -> list[str]:
"""
A list of domains which are allowed to embed the dashboard.
An empty list means any domain can embed.
diff --git a/superset/models/filter_set.py b/superset/models/filter_set.py
index 1ace5bca32..ac25b114ff 100644
--- a/superset/models/filter_set.py
+++ b/superset/models/filter_set.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import json
import logging
-from typing import Any, Dict
+from typing import Any
from flask import current_app
from flask_appbuilder import Model
@@ -75,7 +75,7 @@ class FilterSet(Model, AuditMixinNullable):
return ""
return f"/superset/profile/{self.changed_by.username}"
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"name": self.name,
@@ -105,7 +105,7 @@ class FilterSet(Model, AuditMixinNullable):
return qry.all()
@property
- def params(self) -> Dict[str, Any]:
+ def params(self) -> dict[str, Any]:
if self.json_metadata:
return json.loads(self.json_metadata)
return {}
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index 32c6f5ff6a..42d5a24174 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -16,29 +16,17 @@
# under the License.
# pylint: disable=too-many-lines
"""a collection of model-related helper classes and functions"""
+import builtins
import dataclasses
import json
import logging
import re
import uuid
from collections import defaultdict
+from collections.abc import Hashable
from datetime import datetime, timedelta
from json.decoder import JSONDecodeError
-from typing import (
- Any,
- cast,
- Dict,
- Hashable,
- List,
- NamedTuple,
- Optional,
- Set,
- Text,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union
import dateutil.parser
import humanize
@@ -145,7 +133,7 @@ def validate_adhoc_subquery(
return ";\n".join(str(statement) for statement in statements)
-def json_to_dict(json_str: str) -> Dict[Any, Any]:
+def json_to_dict(json_str: str) -> dict[Any, Any]:
if json_str:
val = re.sub(",[ \t\r\n]+}", "}", json_str)
val = re.sub(",[ \t\r\n]+\\]", "]", val)
@@ -179,22 +167,22 @@ class ImportExportMixin:
# The name of the attribute
# with the SQL Alchemy back reference
- export_children: List[str] = []
+ export_children: list[str] = []
# List of (str) names of attributes
# with the SQL Alchemy forward references
- export_fields: List[str] = []
+ export_fields: list[str] = []
# The names of the attributes
# that are available for import and export
- extra_import_fields: List[str] = []
+ extra_import_fields: list[str] = []
# Additional fields that should be imported,
# even though they were not exported
__mapper__: Mapper
@classmethod
- def _unique_constrains(cls) -> List[Set[str]]:
+ def _unique_constrains(cls) -> list[set[str]]:
"""Get all (single column and multi column) unique constraints"""
unique = [
{c.name for c in u.columns}
@@ -207,7 +195,7 @@ class ImportExportMixin:
return unique
@classmethod
- def parent_foreign_key_mappings(cls) -> Dict[str, str]:
+ def parent_foreign_key_mappings(cls) -> dict[str, str]:
"""Get a mapping of foreign name to the local name of foreign keys"""
parent_rel = cls.__mapper__.relationships.get(cls.export_parent)
if parent_rel:
@@ -217,7 +205,7 @@ class ImportExportMixin:
@classmethod
def export_schema(
cls, recursive: bool = True, include_parent_ref: bool = False
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Export schema as a dictionary"""
parent_excludes = set()
if not include_parent_ref:
@@ -227,12 +215,12 @@ class ImportExportMixin:
def formatter(column: sa.Column) -> str:
return (
- "{0} Default ({1})".format(str(column.type), column.default.arg)
+ f"{str(column.type)} Default ({column.default.arg})"
if column.default
else str(column.type)
)
- schema: Dict[str, Any] = {
+ schema: dict[str, Any] = {
column.name: formatter(column)
for column in cls.__table__.columns # type: ignore
if (column.name in cls.export_fields and column.name not in parent_excludes)
@@ -252,10 +240,10 @@ class ImportExportMixin:
# pylint: disable=too-many-arguments,too-many-branches,too-many-locals
cls,
session: Session,
- dict_rep: Dict[Any, Any],
+ dict_rep: dict[Any, Any],
parent: Optional[Any] = None,
recursive: bool = True,
- sync: Optional[List[str]] = None,
+ sync: Optional[list[str]] = None,
) -> Any:
"""Import obj from a dictionary"""
if sync is None:
@@ -281,9 +269,7 @@ class ImportExportMixin:
if cls.export_parent:
for prnt in parent_refs.keys():
if prnt not in dict_rep:
- raise RuntimeError(
- "{0}: Missing field {1}".format(cls.__name__, prnt)
- )
+ raise RuntimeError(f"{cls.__name__}: Missing field {prnt}")
else:
# Set foreign keys to parent obj
for k, v in parent_refs.items():
@@ -371,7 +357,7 @@ class ImportExportMixin:
include_parent_ref: bool = False,
include_defaults: bool = False,
export_uuids: bool = False,
- ) -> Dict[Any, Any]:
+ ) -> dict[Any, Any]:
"""Export obj to dictionary"""
export_fields = set(self.export_fields)
if export_uuids:
@@ -457,18 +443,18 @@ class ImportExportMixin:
self.owners = [g.user]
@property
- def params_dict(self) -> Dict[Any, Any]:
+ def params_dict(self) -> dict[Any, Any]:
return json_to_dict(self.params)
@property
- def template_params_dict(self) -> Dict[Any, Any]:
+ def template_params_dict(self) -> dict[Any, Any]:
return json_to_dict(self.template_params) # type: ignore
def _user_link(user: User) -> Union[Markup, str]:
if not user:
return ""
- url = "/superset/profile/{}/".format(user.username)
+ url = f"/superset/profile/{user.username}/"
return Markup('{}'.format(url, escape(user) or ""))
@@ -505,13 +491,13 @@ class AuditMixinNullable(AuditMixin):
@property
def created_by_name(self) -> str:
if self.created_by:
- return escape("{}".format(self.created_by))
+ return escape(f"{self.created_by}")
return ""
@property
def changed_by_name(self) -> str:
if self.changed_by:
- return escape("{}".format(self.changed_by))
+ return escape(f"{self.changed_by}")
return ""
@renders("created_by")
@@ -565,12 +551,12 @@ class QueryResult: # pylint: disable=too-few-public-methods
df: pd.DataFrame,
query: str,
duration: timedelta,
- applied_template_filters: Optional[List[str]] = None,
- applied_filter_columns: Optional[List[ColumnTyping]] = None,
- rejected_filter_columns: Optional[List[ColumnTyping]] = None,
+ applied_template_filters: Optional[list[str]] = None,
+ applied_filter_columns: Optional[list[ColumnTyping]] = None,
+ rejected_filter_columns: Optional[list[ColumnTyping]] = None,
status: str = QueryStatus.SUCCESS,
error_message: Optional[str] = None,
- errors: Optional[List[Dict[str, Any]]] = None,
+ errors: Optional[list[dict[str, Any]]] = None,
from_dttm: Optional[datetime] = None,
to_dttm: Optional[datetime] = None,
) -> None:
@@ -593,7 +579,7 @@ class ExtraJSONMixin:
extra_json = sa.Column(sa.Text, default="{}")
@property
- def extra(self) -> Dict[str, Any]:
+ def extra(self) -> dict[str, Any]:
try:
return json.loads(self.extra_json or "{}") or {}
except (TypeError, JSONDecodeError) as exc:
@@ -603,7 +589,7 @@ class ExtraJSONMixin:
return {}
@extra.setter
- def extra(self, extras: Dict[str, Any]) -> None:
+ def extra(self, extras: dict[str, Any]) -> None:
self.extra_json = json.dumps(extras)
def set_extra_json_key(self, key: str, value: Any) -> None:
@@ -615,7 +601,7 @@ class ExtraJSONMixin:
def ensure_extra_json_is_not_none( # pylint: disable=no-self-use
self,
_: str,
- value: Optional[Dict[str, Any]],
+ value: Optional[dict[str, Any]],
) -> Any:
if value is None:
return "{}"
@@ -627,7 +613,7 @@ class CertificationMixin:
extra = sa.Column(sa.Text, default="{}")
- def get_extra_dict(self) -> Dict[str, Any]:
+ def get_extra_dict(self) -> dict[str, Any]:
try:
return json.loads(self.extra)
except (TypeError, json.JSONDecodeError):
@@ -652,8 +638,8 @@ class CertificationMixin:
def clone_model(
target: Model,
- ignore: Optional[List[str]] = None,
- keep_relations: Optional[List[str]] = None,
+ ignore: Optional[list[str]] = None,
+ keep_relations: Optional[list[str]] = None,
**kwargs: Any,
) -> Model:
"""
@@ -676,22 +662,22 @@ def clone_model(
# todo(hugh): centralize where this code lives
class QueryStringExtended(NamedTuple):
- applied_template_filters: Optional[List[str]]
- applied_filter_columns: List[ColumnTyping]
- rejected_filter_columns: List[ColumnTyping]
- labels_expected: List[str]
- prequeries: List[str]
+ applied_template_filters: Optional[list[str]]
+ applied_filter_columns: list[ColumnTyping]
+ rejected_filter_columns: list[ColumnTyping]
+ labels_expected: list[str]
+ prequeries: list[str]
sql: str
class SqlaQuery(NamedTuple):
- applied_template_filters: List[str]
- applied_filter_columns: List[ColumnTyping]
- rejected_filter_columns: List[ColumnTyping]
+ applied_template_filters: list[str]
+ applied_filter_columns: list[ColumnTyping]
+ rejected_filter_columns: list[ColumnTyping]
cte: Optional[str]
- extra_cache_keys: List[Any]
- labels_expected: List[str]
- prequeries: List[str]
+ extra_cache_keys: list[Any]
+ labels_expected: list[str]
+ prequeries: list[str]
sqla_query: Select
@@ -719,7 +705,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
- def db_extra(self) -> Optional[Dict[str, Any]]:
+ def db_extra(self) -> Optional[dict[str, Any]]:
raise NotImplementedError()
def query(self, query_obj: QueryObjectDict) -> QueryResult:
@@ -730,11 +716,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
- def owners_data(self) -> List[Any]:
+ def owners_data(self) -> list[Any]:
raise NotImplementedError()
@property
- def metrics(self) -> List[Any]:
+ def metrics(self) -> list[Any]:
return []
@property
@@ -750,7 +736,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
- def column_names(self) -> List[str]:
+ def column_names(self) -> list[str]:
raise NotImplementedError()
@property
@@ -762,15 +748,15 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
- def dttm_cols(self) -> List[str]:
+ def dttm_cols(self) -> list[str]:
raise NotImplementedError()
@property
- def db_engine_spec(self) -> Type["BaseEngineSpec"]:
+ def db_engine_spec(self) -> builtins.type["BaseEngineSpec"]:
raise NotImplementedError()
@property
- def database(self) -> Type["Database"]:
+ def database(self) -> builtins.type["Database"]:
raise NotImplementedError()
@property
@@ -782,7 +768,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
raise NotImplementedError()
@property
- def columns(self) -> List[Any]:
+ def columns(self) -> list[Any]:
raise NotImplementedError()
def get_fetch_values_predicate(
@@ -790,7 +776,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
) -> TextClause:
raise NotImplementedError()
- def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]:
+ def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]:
raise NotImplementedError()
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
@@ -799,7 +785,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
- ) -> List[TextClause]:
+ ) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
@@ -808,8 +794,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
- all_filters: List[TextClause] = []
- filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
+ all_filters: list[TextClause] = []
+ filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
@@ -923,8 +909,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
self,
row: pd.Series,
dimension: str,
- columns_by_name: Dict[str, "TableColumn"],
- ) -> Union[str, int, float, bool, Text]:
+ columns_by_name: dict[str, "TableColumn"],
+ ) -> Union[str, int, float, bool, str]:
"""
Convert a prequery result type to its equivalent Python type.
@@ -944,7 +930,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
value = value.item()
column_ = columns_by_name[dimension]
- db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore
+ db_extra: dict[str, Any] = self.database.get_extra() # type: ignore
if isinstance(column_, dict):
if (
@@ -969,7 +955,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return value
def make_orderby_compatible(
- self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
+ self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
"""
If needed, make sure aliases for selected columns are not used in
@@ -1088,7 +1074,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
- ) -> Tuple[Union[TableClause, Alias], Optional[str]]:
+ ) -> tuple[Union[TableClause, Alias], Optional[str]]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
@@ -1117,7 +1103,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
- columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument
+ columns_by_name: dict[str, "TableColumn"], # pylint: disable=unused-argument
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
@@ -1151,7 +1137,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
return self.make_sqla_column_compatible(sqla_metric, label)
@property
- def template_params_dict(self) -> Dict[Any, Any]:
+ def template_params_dict(self) -> dict[Any, Any]:
return {}
@staticmethod
@@ -1162,9 +1148,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
target_native_type: Optional[str] = None,
is_list_target: bool = False,
db_engine_spec: Optional[
- Type["BaseEngineSpec"]
+ builtins.type["BaseEngineSpec"]
] = None, # fix(hughhh): Optional[Type[BaseEngineSpec]]
- db_extra: Optional[Dict[str, Any]] = None,
+ db_extra: Optional[dict[str, Any]] = None,
) -> Optional[FilterValues]:
if values is None:
return None
@@ -1217,8 +1203,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def _get_series_orderby(
self,
series_limit_metric: Metric,
- metrics_by_name: Dict[str, "SqlMetric"],
- columns_by_name: Dict[str, "TableColumn"],
+ metrics_by_name: dict[str, "SqlMetric"],
+ columns_by_name: dict[str, "TableColumn"],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
@@ -1248,9 +1234,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def _get_top_groups(
self,
df: pd.DataFrame,
- dimensions: List[str],
- groupby_exprs: Dict[str, Any],
- columns_by_name: Dict[str, "TableColumn"],
+ dimensions: list[str],
+ groupby_exprs: dict[str, Any],
+ columns_by_name: dict[str, "TableColumn"],
) -> ColumnElement:
groups = []
for _unused, row in df.iterrows():
@@ -1335,7 +1321,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
)
return and_(*l)
- def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
+ def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@@ -1369,7 +1355,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_timestamp_expression(
self,
- column: Dict[str, Any],
+ column: dict[str, Any],
time_grain: Optional[str],
label: Optional[str] = None,
template_processor: Optional[BaseTemplateProcessor] = None,
@@ -1417,23 +1403,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements
self,
apply_fetch_values_predicate: bool = False,
- columns: Optional[List[Column]] = None,
- extras: Optional[Dict[str, Any]] = None,
+ columns: Optional[list[Column]] = None,
+ extras: Optional[dict[str, Any]] = None,
filter: Optional[ # pylint: disable=redefined-builtin
- List[utils.QueryObjectFilterClause]
+ list[utils.QueryObjectFilterClause]
] = None,
from_dttm: Optional[datetime] = None,
granularity: Optional[str] = None,
- groupby: Optional[List[Column]] = None,
+ groupby: Optional[list[Column]] = None,
inner_from_dttm: Optional[datetime] = None,
inner_to_dttm: Optional[datetime] = None,
is_rowcount: bool = False,
is_timeseries: bool = True,
- metrics: Optional[List[Metric]] = None,
- orderby: Optional[List[OrderBy]] = None,
+ metrics: Optional[list[Metric]] = None,
+ orderby: Optional[list[OrderBy]] = None,
order_desc: bool = True,
to_dttm: Optional[datetime] = None,
- series_columns: Optional[List[Column]] = None,
+ series_columns: Optional[list[Column]] = None,
series_limit: Optional[int] = None,
series_limit_metric: Optional[Metric] = None,
row_limit: Optional[int] = None,
@@ -1464,23 +1450,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
}
columns = columns or []
groupby = groupby or []
- rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
- applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = []
+ rejected_adhoc_filters_columns: list[Union[str, ColumnTyping]] = []
+ applied_adhoc_filters_columns: list[Union[str, ColumnTyping]] = []
series_column_names = utils.get_column_names(series_columns or [])
# deprecated, to be removed in 2.0
if is_timeseries and timeseries_limit:
series_limit = timeseries_limit
series_limit_metric = series_limit_metric or timeseries_limit_metric
template_kwargs.update(self.template_params_dict)
- extra_cache_keys: List[Any] = []
+ extra_cache_keys: list[Any] = []
template_kwargs["extra_cache_keys"] = extra_cache_keys
- removed_filters: List[str] = []
- applied_template_filters: List[str] = []
+ removed_filters: list[str] = []
+ applied_template_filters: list[str] = []
template_kwargs["removed_filters"] = removed_filters
template_kwargs["applied_filters"] = applied_template_filters
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.db_engine_spec
- prequeries: List[str] = []
+ prequeries: list[str] = []
orderby = orderby or []
need_groupby = bool(metrics is not None or groupby)
metrics = metrics or []
@@ -1489,11 +1475,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if granularity not in self.dttm_cols and granularity is not None:
granularity = self.main_dttm_col
- columns_by_name: Dict[str, "TableColumn"] = {
+ columns_by_name: dict[str, "TableColumn"] = {
col.column_name: col for col in self.columns
}
- metrics_by_name: Dict[str, "SqlMetric"] = {
+ metrics_by_name: dict[str, "SqlMetric"] = {
m.metric_name: m for m in self.metrics
}
@@ -1507,7 +1493,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
if not metrics and not columns and not groupby:
raise QueryObjectValidationError(_("Empty query?"))
- metrics_exprs: List[ColumnElement] = []
+ metrics_exprs: list[ColumnElement] = []
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
@@ -1542,7 +1528,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
metrics_exprs_by_expr = {str(m): m for m in metrics_exprs}
# Since orderby may use adhoc metrics, too; we need to process them first
- orderby_exprs: List[ColumnElement] = []
+ orderby_exprs: list[ColumnElement] = []
for orig_col, ascending in orderby:
col: Union[AdhocMetric, ColumnElement] = orig_col
if isinstance(col, dict):
@@ -1582,7 +1568,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods
_("Unknown column used in orderby: %(col)s", col=orig_col)
)
- select_exprs: List[Union[Column, Label]] = []
+ select_exprs: list[Union[Column, Label]] = []
groupby_all_columns = {}
groupby_series_columns = {}
diff --git a/superset/models/slice.py b/superset/models/slice.py
index 6835215338..15dddfc7e1 100644
--- a/superset/models/slice.py
+++ b/superset/models/slice.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import json
import logging
-from typing import Any, Dict, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from urllib import parse
import sqlalchemy as sqla
@@ -70,7 +70,7 @@ class Slice( # pylint: disable=too-many-public-methods
):
"""A slice is essentially a report or a view on data"""
- query_context_factory: Optional[QueryContextFactory] = None
+ query_context_factory: QueryContextFactory | None = None
__tablename__ = "slices"
id = Column(Integer, primary_key=True)
@@ -134,17 +134,17 @@ class Slice( # pylint: disable=too-many-public-methods
return self.slice_name or str(self.id)
@property
- def cls_model(self) -> Type["BaseDatasource"]:
+ def cls_model(self) -> type[BaseDatasource]:
# pylint: disable=import-outside-toplevel
from superset.datasource.dao import DatasourceDAO
return DatasourceDAO.sources[self.datasource_type]
@property
- def datasource(self) -> Optional["BaseDatasource"]:
+ def datasource(self) -> BaseDatasource | None:
return self.get_datasource
- def clone(self) -> "Slice":
+ def clone(self) -> Slice:
return Slice(
slice_name=self.slice_name,
datasource_id=self.datasource_id,
@@ -158,7 +158,7 @@ class Slice( # pylint: disable=too-many-public-methods
# pylint: disable=using-constant-test
@datasource.getter # type: ignore
- def get_datasource(self) -> Optional["BaseDatasource"]:
+ def get_datasource(self) -> BaseDatasource | None:
return (
db.session.query(self.cls_model)
.filter_by(id=self.datasource_id)
@@ -166,20 +166,20 @@ class Slice( # pylint: disable=too-many-public-methods
)
@renders("datasource_name")
- def datasource_link(self) -> Optional[Markup]:
+ def datasource_link(self) -> Markup | None:
# pylint: disable=no-member
datasource = self.datasource
return datasource.link if datasource else None
@renders("datasource_url")
- def datasource_url(self) -> Optional[str]:
+ def datasource_url(self) -> str | None:
# pylint: disable=no-member
if self.table:
return self.table.explore_url
datasource = self.datasource
return datasource.explore_url if datasource else None
- def datasource_name_text(self) -> Optional[str]:
+ def datasource_name_text(self) -> str | None:
# pylint: disable=no-member
if self.table:
if self.table.schema:
@@ -192,7 +192,7 @@ class Slice( # pylint: disable=too-many-public-methods
return None
@property
- def datasource_edit_url(self) -> Optional[str]:
+ def datasource_edit_url(self) -> str | None:
# pylint: disable=no-member
datasource = self.datasource
return datasource.url if datasource else None
@@ -200,7 +200,7 @@ class Slice( # pylint: disable=too-many-public-methods
# pylint: enable=using-constant-test
@property
- def viz(self) -> Optional[BaseViz]:
+ def viz(self) -> BaseViz | None:
form_data = json.loads(self.params)
viz_class = viz_types.get(self.viz_type)
datasource = self.datasource
@@ -213,9 +213,9 @@ class Slice( # pylint: disable=too-many-public-methods
return utils.markdown(self.description)
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
"""Data used to render slice in templates"""
- data: Dict[str, Any] = {}
+ data: dict[str, Any] = {}
self.token = ""
try:
viz = self.viz
@@ -261,8 +261,8 @@ class Slice( # pylint: disable=too-many-public-methods
return json.dumps(self.data)
@property
- def form_data(self) -> Dict[str, Any]:
- form_data: Dict[str, Any] = {}
+ def form_data(self) -> dict[str, Any]:
+ form_data: dict[str, Any] = {}
try:
form_data = json.loads(self.params)
except Exception as ex: # pylint: disable=broad-except
@@ -272,7 +272,7 @@ class Slice( # pylint: disable=too-many-public-methods
{
"slice_id": self.id,
"viz_type": self.viz_type,
- "datasource": "{}__{}".format(self.datasource_id, self.datasource_type),
+ "datasource": f"{self.datasource_id}__{self.datasource_type}",
}
)
@@ -281,7 +281,7 @@ class Slice( # pylint: disable=too-many-public-methods
update_time_range(form_data)
return form_data
- def get_query_context(self) -> Optional[QueryContext]:
+ def get_query_context(self) -> QueryContext | None:
if self.query_context:
try:
return self.get_query_context_factory().create(
@@ -295,13 +295,13 @@ class Slice( # pylint: disable=too-many-public-methods
def get_explore_url(
self,
base_url: str = "/explore",
- overrides: Optional[Dict[str, Any]] = None,
+ overrides: dict[str, Any] | None = None,
) -> str:
return self.build_explore_url(self.id, base_url, overrides)
@staticmethod
def build_explore_url(
- id_: int, base_url: str = "/explore", overrides: Optional[Dict[str, Any]] = None
+ id_: int, base_url: str = "/explore", overrides: dict[str, Any] | None = None
) -> str:
overrides = overrides or {}
form_data = {"slice_id": id_}
diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py
index b2f0c8c1ed..b9ab153798 100644
--- a/superset/models/sql_lab.py
+++ b/superset/models/sql_lab.py
@@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
"""A collection of ORM sqlalchemy models for SQL Lab"""
+import builtins
import inspect
import logging
import re
+from collections.abc import Hashable
from datetime import datetime
-from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
import simplejson as json
import sqlalchemy as sqla
@@ -131,7 +133,7 @@ class Query(
def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
return get_template_processor(query=self, database=self.database, **kwargs)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"changedOn": self.changed_on,
"changed_on": self.changed_on.isoformat(),
@@ -181,11 +183,11 @@ class Query(
return self.user.username
@property
- def sql_tables(self) -> List[Table]:
+ def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
@property
- def columns(self) -> List["TableColumn"]:
+ def columns(self) -> list["TableColumn"]:
from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel
TableColumn,
)
@@ -204,11 +206,11 @@ class Query(
return columns
@property
- def db_extra(self) -> Optional[Dict[str, Any]]:
+ def db_extra(self) -> Optional[dict[str, Any]]:
return None
@property
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
order_by_choices = []
for col in self.columns:
column_name = str(col.column_name or "")
@@ -247,11 +249,13 @@ class Query(
security_manager.raise_for_access(query=self)
@property
- def db_engine_spec(self) -> Type["BaseEngineSpec"]:
+ def db_engine_spec(
+ self,
+ ) -> builtins.type["BaseEngineSpec"]: # pylint: disable=unsubscriptable-object
return self.database.db_engine_spec
@property
- def owners_data(self) -> List[Dict[str, Any]]:
+ def owners_data(self) -> list[dict[str, Any]]:
return []
@property
@@ -267,7 +271,7 @@ class Query(
return 0
@property
- def column_names(self) -> List[Any]:
+ def column_names(self) -> list[Any]:
return [col.column_name for col in self.columns]
@property
@@ -282,7 +286,7 @@ class Query(
return None
@property
- def dttm_cols(self) -> List[Any]:
+ def dttm_cols(self) -> list[Any]:
return [col.column_name for col in self.columns if col.is_dttm]
@property
@@ -298,7 +302,7 @@ class Query(
return ""
@staticmethod
- def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]:
+ def get_extra_cache_keys(query_obj: dict[str, Any]) -> list[Hashable]:
return []
@property
@@ -322,7 +326,7 @@ class Query(
def tracking_url(self, value: str) -> None:
self.tracking_url_raw = value
- def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]:
+ def get_column(self, column_name: Optional[str]) -> Optional[dict[str, Any]]:
if not column_name:
return None
for col in self.columns:
@@ -397,7 +401,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
def __repr__(self) -> str:
return str(self.label)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
}
@@ -421,10 +425,10 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
return self.database.sqlalchemy_uri
def url(self) -> str:
- return "/superset/sqllab?savedQueryId={0}".format(self.id)
+ return f"/superset/sqllab?savedQueryId={self.id}"
@property
- def sql_tables(self) -> List[Table]:
+ def sql_tables(self) -> list[Table]:
return list(ParsedQuery(self.sql).tables)
@property
@@ -483,7 +487,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin):
)
saved_query = relationship("SavedQuery", foreign_keys=[saved_query_id])
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
return {
"id": self.id,
"user_id": self.user_id,
@@ -520,7 +524,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin):
expanded = Column(Boolean, default=False)
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
try:
description = json.loads(self.description)
except json.JSONDecodeError:
diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py
index c496f75039..234581dfb4 100644
--- a/superset/models/sql_types/presto_sql_types.py
+++ b/superset/models/sql_types/presto_sql_types.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=abstract-method, no-init
-from typing import Any, Dict, List, Optional, Type
+from typing import Any, Optional
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql.sqltypes import DATE, Integer, TIMESTAMP
@@ -33,7 +33,7 @@ class TinyInteger(Integer):
"""
@property
- def python_type(self) -> Type[int]:
+ def python_type(self) -> type[int]:
return int
@classmethod
@@ -47,7 +47,7 @@ class Interval(TypeEngine):
"""
@property
- def python_type(self) -> Optional[Type[Any]]:
+ def python_type(self) -> Optional[type[Any]]:
return None
@classmethod
@@ -61,7 +61,7 @@ class Array(TypeEngine):
"""
@property
- def python_type(self) -> Optional[Type[List[Any]]]:
+ def python_type(self) -> Optional[type[list[Any]]]:
return list
@classmethod
@@ -75,7 +75,7 @@ class Map(TypeEngine):
"""
@property
- def python_type(self) -> Optional[Type[Dict[Any, Any]]]:
+ def python_type(self) -> Optional[type[dict[Any, Any]]]:
return dict
@classmethod
@@ -89,7 +89,7 @@ class Row(TypeEngine):
"""
@property
- def python_type(self) -> Optional[Type[Any]]:
+ def python_type(self) -> Optional[type[Any]]:
return None
@classmethod
diff --git a/superset/queries/dao.py b/superset/queries/dao.py
index 642a5dd4cb..e9fe15cac5 100644
--- a/superset/queries/dao.py
+++ b/superset/queries/dao.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
-from typing import Any, Dict, List, Union
+from typing import Any, Union
from superset import sql_lab
from superset.common.db_query_status import QueryStatus
@@ -56,14 +56,14 @@ class QueryDAO(BaseDAO):
db.session.commit()
@staticmethod
- def save_metadata(query: Query, payload: Dict[str, Any]) -> None:
+ def save_metadata(query: Query, payload: dict[str, Any]) -> None:
# pull relevant data from payload and store in extra_json
columns = payload.get("columns", {})
db.session.add(query)
query.set_extra_json_key("columns", columns)
@staticmethod
- def get_queries_changed_after(last_updated_ms: Union[float, int]) -> List[Query]:
+ def get_queries_changed_after(last_updated_ms: Union[float, int]) -> list[Query]:
# UTC date time, same that is stored in the DB.
last_updated_dt = datetime.utcfromtimestamp(last_updated_ms / 1000)
diff --git a/superset/queries/saved_queries/commands/bulk_delete.py b/superset/queries/saved_queries/commands/bulk_delete.py
index c96afd31e5..fb230180c8 100644
--- a/superset/queries/saved_queries/commands/bulk_delete.py
+++ b/superset/queries/saved_queries/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
@@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteSavedQueryCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[Dashboard]] = None
+ self._models: Optional[list[Dashboard]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/queries/saved_queries/commands/export.py b/superset/queries/saved_queries/commands/export.py
index 8c5357159e..323256306a 100644
--- a/superset/queries/saved_queries/commands/export.py
+++ b/superset/queries/saved_queries/commands/export.py
@@ -18,7 +18,7 @@
import json
import logging
-from typing import Iterator, Tuple
+from collections.abc import Iterator
import yaml
from werkzeug.utils import secure_filename
@@ -39,7 +39,7 @@ class ExportSavedQueriesCommand(ExportModelsCommand):
@staticmethod
def _export(
model: SavedQuery, export_related: bool = True
- ) -> Iterator[Tuple[str, str]]:
+ ) -> Iterator[tuple[str, str]]:
# build filename based on database, optional schema, and label
database_slug = secure_filename(model.database.database_name)
schema_slug = secure_filename(model.schema)
diff --git a/superset/queries/saved_queries/commands/importers/dispatcher.py b/superset/queries/saved_queries/commands/importers/dispatcher.py
index 8283202225..c2208f0e2a 100644
--- a/superset/queries/saved_queries/commands/importers/dispatcher.py
+++ b/superset/queries/saved_queries/commands/importers/dispatcher.py
@@ -16,7 +16,7 @@
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from marshmallow.exceptions import ValidationError
@@ -40,7 +40,7 @@ class ImportSavedQueriesCommand(BaseCommand):
until it finds one that matches.
"""
- def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
+ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
diff --git a/superset/queries/saved_queries/commands/importers/v1/__init__.py b/superset/queries/saved_queries/commands/importers/v1/__init__.py
index 1412dbd356..79ec04f54b 100644
--- a/superset/queries/saved_queries/commands/importers/v1/__init__.py
+++ b/superset/queries/saved_queries/commands/importers/v1/__init__.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Set
+from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@@ -38,7 +38,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
dao = SavedQueryDAO
model_name = "saved_queries"
prefix = "queries/"
- schemas: Dict[str, Schema] = {
+ schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"queries/": ImportV1SavedQuerySchema(),
}
@@ -46,16 +46,16 @@ class ImportSavedQueriesCommand(ImportModelsCommand):
@staticmethod
def _import(
- session: Session, configs: Dict[str, Any], overwrite: bool = False
+ session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover databases associated with saved queries
- database_uuids: Set[str] = set()
+ database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("queries/"):
database_uuids.add(config["database_uuid"])
# import related databases
- database_ids: Dict[str, int] = {}
+ database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
diff --git a/superset/queries/saved_queries/commands/importers/v1/utils.py b/superset/queries/saved_queries/commands/importers/v1/utils.py
index f2d090bf11..813f3c2295 100644
--- a/superset/queries/saved_queries/commands/importers/v1/utils.py
+++ b/superset/queries/saved_queries/commands/importers/v1/utils.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from sqlalchemy.orm import Session
@@ -23,7 +23,7 @@ from superset.models.sql_lab import SavedQuery
def import_saved_query(
- session: Session, config: Dict[str, Any], overwrite: bool = False
+ session: Session, config: dict[str, Any], overwrite: bool = False
) -> SavedQuery:
existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first()
if existing:
diff --git a/superset/queries/saved_queries/dao.py b/superset/queries/saved_queries/dao.py
index c6bcfa035c..daae1de8f5 100644
--- a/superset/queries/saved_queries/dao.py
+++ b/superset/queries/saved_queries/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -33,7 +33,7 @@ class SavedQueryDAO(BaseDAO):
base_filter = SavedQueryFilter
@staticmethod
- def bulk_delete(models: Optional[List[SavedQuery]], commit: bool = True) -> None:
+ def bulk_delete(models: Optional[list[SavedQuery]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(SavedQuery).filter(SavedQuery.id.in_(item_ids)).delete(
diff --git a/superset/queries/schemas.py b/superset/queries/schemas.py
index b139784c5b..850664e92f 100644
--- a/superset/queries/schemas.py
+++ b/superset/queries/schemas.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List
from marshmallow import fields, Schema
@@ -73,7 +72,7 @@ class QuerySchema(Schema):
include_relationships = True
# pylint: disable=no-self-use
- def get_sql_tables(self, obj: Query) -> List[Table]:
+ def get_sql_tables(self, obj: Query) -> list[Table]:
return obj.sql_tables
diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py
index c5b4709447..41163dc064 100644
--- a/superset/reports/commands/alert.py
+++ b/superset/reports/commands/alert.py
@@ -20,7 +20,7 @@ import json
import logging
from operator import eq, ge, gt, le, lt, ne
from timeit import default_timer
-from typing import Any, Optional
+from typing import Any
import numpy as np
import pandas as pd
@@ -54,7 +54,7 @@ OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne}
class AlertCommand(BaseCommand):
def __init__(self, report_schedule: ReportSchedule):
self._report_schedule = report_schedule
- self._result: Optional[float] = None
+ self._result: float | None = None
def run(self) -> bool:
"""
diff --git a/superset/reports/commands/base.py b/superset/reports/commands/base.py
index 4fee6a8824..598370576b 100644
--- a/superset/reports/commands/base.py
+++ b/superset/reports/commands/base.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List
+from typing import Any
from marshmallow import ValidationError
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
class BaseReportScheduleCommand(BaseCommand):
- _properties: Dict[str, Any]
+ _properties: dict[str, Any]
def run(self) -> Any:
pass
@@ -45,7 +45,7 @@ class BaseReportScheduleCommand(BaseCommand):
pass
def validate_chart_dashboard(
- self, exceptions: List[ValidationError], update: bool = False
+ self, exceptions: list[ValidationError], update: bool = False
) -> None:
"""Validate chart or dashboard relation"""
chart_id = self._properties.get("chart")
diff --git a/superset/reports/commands/bulk_delete.py b/superset/reports/commands/bulk_delete.py
index 28a39a2fb6..7d6e1ed791 100644
--- a/superset/reports/commands/bulk_delete.py
+++ b/superset/reports/commands/bulk_delete.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from superset import security_manager
from superset.commands.base import BaseCommand
@@ -33,9 +33,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteReportScheduleCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: Optional[List[ReportSchedule]] = None
+ self._models: Optional[list[ReportSchedule]] = None
def run(self) -> None:
self.validate()
diff --git a/superset/reports/commands/create.py b/superset/reports/commands/create.py
index 27626170d6..04cf6ef43f 100644
--- a/superset/reports/commands/create.py
+++ b/superset/reports/commands/create.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_babel import gettext as _
from marshmallow import ValidationError
@@ -46,7 +46,7 @@ logger = logging.getLogger(__name__)
class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> ReportSchedule:
@@ -59,8 +59,8 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
return report_schedule
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
name = self._properties.get("name", "")
report_type = self._properties.get("type")
creation_method = self._properties.get("creation_method")
@@ -119,7 +119,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand):
if exceptions:
raise ReportScheduleInvalidError(exceptions=exceptions)
- def _validate_report_extra(self, exceptions: List[ValidationError]) -> None:
+ def _validate_report_extra(self, exceptions: list[ValidationError]) -> None:
extra: Optional[ReportScheduleExtra] = self._properties.get("extra")
dashboard = self._properties.get("dashboard")
diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py
index 22aff0727d..cba12e0786 100644
--- a/superset/reports/commands/exceptions.py
+++ b/superset/reports/commands/exceptions.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List
from flask_babel import lazy_gettext as _
@@ -263,13 +262,13 @@ class ReportScheduleStateNotFoundError(CommandException):
class ReportScheduleSystemErrorsException(CommandException, SupersetErrorsException):
- errors: List[SupersetError] = []
+ errors: list[SupersetError] = []
message = _("Report schedule system error")
class ReportScheduleClientErrorsException(CommandException, SupersetErrorsException):
status = 400
- errors: List[SupersetError] = []
+ errors: list[SupersetError] = []
message = _("Report schedule client error")
diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py
index f5f7bf4130..608b2564a2 100644
--- a/superset/reports/commands/execute.py
+++ b/superset/reports/commands/execute.py
@@ -17,7 +17,7 @@
import json
import logging
from datetime import datetime, timedelta
-from typing import Any, List, Optional, Union
+from typing import Any, Optional, Union
from uuid import UUID
import pandas as pd
@@ -80,7 +80,7 @@ logger = logging.getLogger(__name__)
class BaseReportState:
- current_states: List[ReportState] = []
+ current_states: list[ReportState] = []
initial: bool = False
def __init__(
@@ -195,7 +195,7 @@ class BaseReportState:
**kwargs,
)
- def _get_screenshots(self) -> List[bytes]:
+ def _get_screenshots(self) -> list[bytes]:
"""
Get chart or dashboard screenshots
:raises: ReportScheduleScreenshotFailedError
@@ -394,14 +394,14 @@ class BaseReportState:
def _send(
self,
notification_content: NotificationContent,
- recipients: List[ReportRecipients],
+ recipients: list[ReportRecipients],
) -> None:
"""
Sends a notification to all recipients
:raises: CommandException
"""
- notification_errors: List[SupersetError] = []
+ notification_errors: list[SupersetError] = []
for recipient in recipients:
notification = create_notification(recipient, notification_content)
try:
diff --git a/superset/reports/commands/update.py b/superset/reports/commands/update.py
index 0c4f18f1b8..5ca3ac849a 100644
--- a/superset/reports/commands/update.py
+++ b/superset/reports/commands/update.py
@@ -16,7 +16,7 @@
# under the License.
import json
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[ReportSchedule] = None
@@ -57,8 +57,8 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand):
return report_schedule
def validate(self) -> None:
- exceptions: List[ValidationError] = []
- owner_ids: Optional[List[int]] = self._properties.get("owners")
+ exceptions: list[ValidationError] = []
+ owner_ids: Optional[list[int]] = self._properties.get("owners")
report_type = self._properties.get("type", ReportScheduleType.ALERT)
name = self._properties.get("name", "")
diff --git a/superset/reports/dao.py b/superset/reports/dao.py
index be5ee8053c..64777e959a 100644
--- a/superset/reports/dao.py
+++ b/superset/reports/dao.py
@@ -17,7 +17,7 @@
import json
import logging
from datetime import datetime
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_appbuilder import Model
from sqlalchemy.exc import SQLAlchemyError
@@ -47,7 +47,7 @@ class ReportScheduleDAO(BaseDAO):
base_filter = ReportScheduleFilter
@staticmethod
- def find_by_chart_id(chart_id: int) -> List[ReportSchedule]:
+ def find_by_chart_id(chart_id: int) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.chart_id == chart_id)
@@ -55,7 +55,7 @@ class ReportScheduleDAO(BaseDAO):
)
@staticmethod
- def find_by_chart_ids(chart_ids: List[int]) -> List[ReportSchedule]:
+ def find_by_chart_ids(chart_ids: list[int]) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.chart_id.in_(chart_ids))
@@ -63,7 +63,7 @@ class ReportScheduleDAO(BaseDAO):
)
@staticmethod
- def find_by_dashboard_id(dashboard_id: int) -> List[ReportSchedule]:
+ def find_by_dashboard_id(dashboard_id: int) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.dashboard_id == dashboard_id)
@@ -71,7 +71,7 @@ class ReportScheduleDAO(BaseDAO):
)
@staticmethod
- def find_by_dashboard_ids(dashboard_ids: List[int]) -> List[ReportSchedule]:
+ def find_by_dashboard_ids(dashboard_ids: list[int]) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.dashboard_id.in_(dashboard_ids))
@@ -79,7 +79,7 @@ class ReportScheduleDAO(BaseDAO):
)
@staticmethod
- def find_by_database_id(database_id: int) -> List[ReportSchedule]:
+ def find_by_database_id(database_id: int) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.database_id == database_id)
@@ -87,7 +87,7 @@ class ReportScheduleDAO(BaseDAO):
)
@staticmethod
- def find_by_database_ids(database_ids: List[int]) -> List[ReportSchedule]:
+ def find_by_database_ids(database_ids: list[int]) -> list[ReportSchedule]:
return (
db.session.query(ReportSchedule)
.filter(ReportSchedule.database_id.in_(database_ids))
@@ -96,7 +96,7 @@ class ReportScheduleDAO(BaseDAO):
@staticmethod
def bulk_delete(
- models: Optional[List[ReportSchedule]], commit: bool = True
+ models: Optional[list[ReportSchedule]], commit: bool = True
) -> None:
item_ids = [model.id for model in models] if models else []
try:
@@ -156,7 +156,7 @@ class ReportScheduleDAO(BaseDAO):
return found_id is None or found_id == expect_id
@classmethod
- def create(cls, properties: Dict[str, Any], commit: bool = True) -> ReportSchedule:
+ def create(cls, properties: dict[str, Any], commit: bool = True) -> ReportSchedule:
"""
create a report schedule and nested recipients
:raises: DAOCreateFailedError
@@ -187,7 +187,7 @@ class ReportScheduleDAO(BaseDAO):
@classmethod
def update(
- cls, model: Model, properties: Dict[str, Any], commit: bool = True
+ cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> ReportSchedule:
"""
create a report schedule and nested recipients
@@ -219,7 +219,7 @@ class ReportScheduleDAO(BaseDAO):
raise DAOCreateFailedError(str(ex)) from ex
@staticmethod
- def find_active(session: Optional[Session] = None) -> List[ReportSchedule]:
+ def find_active(session: Optional[Session] = None) -> list[ReportSchedule]:
"""
Find all active reports. If session is passed it will be used instead of the
default `db.session`, this is useful when on a celery worker session context
diff --git a/superset/reports/filters.py b/superset/reports/filters.py
index 5fb87e0563..a03238b640 100644
--- a/superset/reports/filters.py
+++ b/superset/reports/filters.py
@@ -52,6 +52,6 @@ class ReportScheduleAllTextFilter(BaseFilter): # pylint: disable=too-few-public
or_(
ReportSchedule.name.ilike(ilike_value),
ReportSchedule.description.ilike(ilike_value),
- ReportSchedule.sql.ilike((ilike_value)),
+ ReportSchedule.sql.ilike(ilike_value),
)
)
diff --git a/superset/reports/logs/api.py b/superset/reports/logs/api.py
index 8ad8455cc7..f0c272caee 100644
--- a/superset/reports/logs/api.py
+++ b/superset/reports/logs/api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from flask import Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
@@ -83,7 +83,7 @@ class ReportExecutionLogRestApi(BaseSupersetModelRestApi):
@staticmethod
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
- layer_id: int, rison_parameters: Dict[str, Any]
+ layer_id: int, rison_parameters: dict[str, Any]
) -> None:
if "filters" not in rison_parameters:
rison_parameters["filters"] = []
diff --git a/superset/reports/notifications/__init__.py b/superset/reports/notifications/__init__.py
index c466f59abd..f2ac40bb46 100644
--- a/superset/reports/notifications/__init__.py
+++ b/superset/reports/notifications/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py
index 6eb2405d0f..640b326fc5 100644
--- a/superset/reports/notifications/base.py
+++ b/superset/reports/notifications/base.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
-from typing import Any, List, Optional, Type
+from typing import Any, Optional
import pandas as pd
@@ -29,7 +28,7 @@ class NotificationContent:
name: str
header_data: HeaderDataType # this is optional to account for error states
csv: Optional[bytes] = None # bytes for csv file
- screenshots: Optional[List[bytes]] = None # bytes for a list of screenshots
+ screenshots: Optional[list[bytes]] = None # bytes for a list of screenshots
text: Optional[str] = None
description: Optional[str] = ""
url: Optional[str] = None # url to chart/dashboard for this screenshot
@@ -44,7 +43,7 @@ class BaseNotification: # pylint: disable=too-few-public-methods
notification type
"""
- plugins: List[Type["BaseNotification"]] = []
+ plugins: list[type["BaseNotification"]] = []
type: Optional[ReportRecipientType] = None
"""
Child classes set their notification type ex: `type = "email"` this string will be
diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py
index 10a76e7573..1b9e4ade72 100644
--- a/superset/reports/notifications/email.py
+++ b/superset/reports/notifications/email.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -20,7 +19,7 @@ import logging
import textwrap
from dataclasses import dataclass
from email.utils import make_msgid, parseaddr
-from typing import Any, Dict, Optional
+from typing import Any, Optional
import nh3
from flask_babel import gettext as __
@@ -69,8 +68,8 @@ ALLOWED_ATTRIBUTES = {
class EmailContent:
body: str
header_data: Optional[HeaderDataType] = None
- data: Optional[Dict[str, Any]] = None
- images: Optional[Dict[str, bytes]] = None
+ data: Optional[dict[str, Any]] = None
+ images: Optional[dict[str, bytes]] = None
class EmailNotification(BaseNotification): # pylint: disable=too-few-public-methods
diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py
index b89a700ef9..4c3f2ee419 100644
--- a/superset/reports/notifications/slack.py
+++ b/superset/reports/notifications/slack.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -17,8 +16,9 @@
# under the License.
import json
import logging
+from collections.abc import Sequence
from io import IOBase
-from typing import Sequence, Union
+from typing import Union
import backoff
from flask_babel import gettext as __
diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py
index a45ee4cc38..83dea02f8f 100644
--- a/superset/reports/schemas.py
+++ b/superset/reports/schemas.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Union
+from typing import Any, Union
from croniter import croniter
from flask_babel import gettext as _
@@ -212,7 +212,7 @@ class ReportSchedulePostSchema(Schema):
@validates_schema
def validate_report_references( # pylint: disable=unused-argument,no-self-use
- self, data: Dict[str, Any], **kwargs: Any
+ self, data: dict[str, Any], **kwargs: Any
) -> None:
if data["type"] == ReportScheduleType.REPORT:
if "database" in data:
diff --git a/superset/result_set.py b/superset/result_set.py
index 9aa06bba09..f707b91dce 100644
--- a/superset/result_set.py
+++ b/superset/result_set.py
@@ -19,7 +19,7 @@
import datetime
import json
import logging
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Optional
import numpy as np
import pandas as pd
@@ -33,7 +33,7 @@ from superset.utils import core as utils
logger = logging.getLogger(__name__)
-def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List[str]:
+def dedup(l: list[str], suffix: str = "__", case_sensitive: bool = True) -> list[str]:
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
@@ -46,8 +46,8 @@ def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List
)
foo,bar,bar__1,bar__2,Bar__3
"""
- new_l: List[str] = []
- seen: Dict[str, int] = {}
+ new_l: list[str] = []
+ seen: dict[str, int] = {}
for item in l:
s_fixed_case = item if case_sensitive else item.lower()
if s_fixed_case in seen:
@@ -104,14 +104,14 @@ class SupersetResultSet:
self,
data: DbapiResult,
cursor_description: DbapiDescription,
- db_engine_spec: Type[BaseEngineSpec],
+ db_engine_spec: type[BaseEngineSpec],
):
self.db_engine_spec = db_engine_spec
data = data or []
- column_names: List[str] = []
- pa_data: List[pa.Array] = []
- deduped_cursor_desc: List[Tuple[Any, ...]] = []
- numpy_dtype: List[Tuple[str, ...]] = []
+ column_names: list[str] = []
+ pa_data: list[pa.Array] = []
+ deduped_cursor_desc: list[tuple[Any, ...]] = []
+ numpy_dtype: list[tuple[str, ...]] = []
stringified_arr: NDArray[Any]
if cursor_description:
@@ -181,7 +181,7 @@ class SupersetResultSet:
column_names = []
self.table = pa.Table.from_arrays(pa_data, names=column_names)
- self._type_dict: Dict[str, Any] = {}
+ self._type_dict: dict[str, Any] = {}
try:
# The driver may not be passing a cursor.description
self._type_dict = {
@@ -245,7 +245,7 @@ class SupersetResultSet:
return self.table.num_rows
@property
- def columns(self) -> List[ResultSetColumnType]:
+ def columns(self) -> list[ResultSetColumnType]:
if not self.table.column_names:
return []
diff --git a/superset/row_level_security/commands/bulk_delete.py b/superset/row_level_security/commands/bulk_delete.py
index a6d4625a91..a3703346cc 100644
--- a/superset/row_level_security/commands/bulk_delete.py
+++ b/superset/row_level_security/commands/bulk_delete.py
@@ -16,7 +16,6 @@
# under the License.
import logging
-from typing import List
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
@@ -31,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteRLSRuleCommand(BaseCommand):
- def __init__(self, model_ids: List[int]):
+ def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
- self._models: List[ReportSchedule] = []
+ self._models: list[ReportSchedule] = []
def run(self) -> None:
self.validate()
diff --git a/superset/row_level_security/commands/create.py b/superset/row_level_security/commands/create.py
index 0c348e10c0..5552feeda0 100644
--- a/superset/row_level_security/commands/create.py
+++ b/superset/row_level_security/commands/create.py
@@ -17,7 +17,7 @@
import logging
-from typing import Any, Dict
+from typing import Any
from superset.commands.base import BaseCommand
from superset.commands.exceptions import DatasourceNotFoundValidationError
@@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class CreateRLSRuleCommand(BaseCommand):
- def __init__(self, data: Dict[str, Any]):
+ def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._tables = self._properties.get("tables", [])
self._roles = self._properties.get("roles", [])
diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py
index 8c276ee2c4..a206fc3a39 100644
--- a/superset/row_level_security/commands/update.py
+++ b/superset/row_level_security/commands/update.py
@@ -17,7 +17,7 @@
import logging
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from superset.commands.base import BaseCommand
from superset.commands.exceptions import DatasourceNotFoundValidationError
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
class UpdateRLSRuleCommand(BaseCommand):
- def __init__(self, model_id: int, data: Dict[str, Any]):
+ def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._tables = self._properties.get("tables", [])
diff --git a/superset/security/api.py b/superset/security/api.py
index 7aac6ae22b..aff536519d 100644
--- a/superset/security/api.py
+++ b/superset/security/api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from flask import request, Response
from flask_appbuilder import expose
@@ -56,8 +56,8 @@ class ResourceSchema(PermissiveSchema):
@post_load
def convert_enum_to_value( # pylint: disable=no-self-use
- self, data: Dict[str, Any], **kwargs: Any # pylint: disable=unused-argument
- ) -> Dict[str, Any]:
+ self, data: dict[str, Any], **kwargs: Any # pylint: disable=unused-argument
+ ) -> dict[str, Any]:
# we don't care about the enum, we want the value inside
data["type"] = data["type"].value
return data
diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py
index 44b59c1dbb..a8dc2e3393 100644
--- a/superset/security/guest_token.py
+++ b/superset/security/guest_token.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
-from typing import List, Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union
from flask_appbuilder.security.sqla.models import Role
from flask_login import AnonymousUserMixin
@@ -36,7 +36,7 @@ class GuestTokenResource(TypedDict):
id: Union[str, int]
-GuestTokenResources = List[GuestTokenResource]
+GuestTokenResources = list[GuestTokenResource]
class GuestTokenRlsRule(TypedDict):
@@ -49,7 +49,7 @@ class GuestToken(TypedDict):
exp: float
user: GuestTokenUser
resources: GuestTokenResources
- rls_rules: List[GuestTokenRlsRule]
+ rls_rules: list[GuestTokenRlsRule]
class GuestUser(AnonymousUserMixin):
@@ -76,7 +76,7 @@ class GuestUser(AnonymousUserMixin):
"""
return False
- def __init__(self, token: GuestToken, roles: List[Role]):
+ def __init__(self, token: GuestToken, roles: list[Role]):
user = token["user"]
self.guest_token = token
self.username = user.get("username", "guest_user")
diff --git a/superset/security/manager.py b/superset/security/manager.py
index db6e631d91..94a731a3ff 100644
--- a/superset/security/manager.py
+++ b/superset/security/manager.py
@@ -20,18 +20,7 @@ import logging
import re
import time
from collections import defaultdict
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- List,
- NamedTuple,
- Optional,
- Set,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union
from flask import current_app, Flask, g, Request
from flask_appbuilder import Model
@@ -479,7 +468,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
def get_table_access_error_msg( # pylint: disable=no-self-use
- self, tables: Set["Table"]
+ self, tables: set["Table"]
) -> str:
"""
Return the error message for the denied SQL tables.
@@ -492,7 +481,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return f"""You need access to the following tables: {", ".join(quoted_tables)},
`all_database_access` or `all_datasource_access` permission"""
- def get_table_access_error_object(self, tables: Set["Table"]) -> SupersetError:
+ def get_table_access_error_object(self, tables: set["Table"]) -> SupersetError:
"""
Return the error object for the denied SQL tables.
@@ -510,7 +499,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
def get_table_access_link( # pylint: disable=unused-argument,no-self-use
- self, tables: Set["Table"]
+ self, tables: set["Table"]
) -> Optional[str]:
"""
Return the access link for the denied SQL tables.
@@ -521,7 +510,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return current_app.config.get("PERMISSION_INSTRUCTIONS_LINK")
- def get_user_datasources(self) -> List["BaseDatasource"]:
+ def get_user_datasources(self) -> list["BaseDatasource"]:
"""
Collect datasources which the user has explicit permissions to.
@@ -542,7 +531,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
# group all datasources by database
session = self.get_session
all_datasources = SqlaTable.get_all_datasources(session)
- datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set)
+ datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set)
for datasource in all_datasources:
datasources_by_database[datasource.database].add(datasource)
@@ -569,7 +558,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return True
- def user_view_menu_names(self, permission_name: str) -> Set[str]:
+ def user_view_menu_names(self, permission_name: str) -> set[str]:
base_query = (
self.get_session.query(self.viewmenu_model.name)
.join(self.permissionview_model)
@@ -599,7 +588,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
return {s.name for s in view_menu_names}
return set()
- def get_accessible_databases(self) -> List[int]:
+ def get_accessible_databases(self) -> list[int]:
"""
Return the list of databases accessible by the user.
@@ -613,8 +602,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
]
def get_schemas_accessible_by_user(
- self, database: "Database", schemas: List[str], hierarchical: bool = True
- ) -> List[str]:
+ self, database: "Database", schemas: list[str], hierarchical: bool = True
+ ) -> list[str]:
"""
Return the list of SQL schemas accessible by the user.
@@ -654,9 +643,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
def get_datasources_accessible_by_user( # pylint: disable=invalid-name
self,
database: "Database",
- datasource_names: List[DatasourceName],
+ datasource_names: list[DatasourceName],
schema: Optional[str] = None,
- ) -> List[DatasourceName]:
+ ) -> list[DatasourceName]:
"""
Return the list of SQL tables accessible by the user.
@@ -802,7 +791,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self.get_session.commit()
self.clean_perms()
- def _get_pvms_from_builtin_role(self, role_name: str) -> List[PermissionView]:
+ def _get_pvms_from_builtin_role(self, role_name: str) -> list[PermissionView]:
"""
Gets a list of model PermissionView permissions inferred from a builtin role
definition
@@ -821,7 +810,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
role_from_permissions.append(pvm)
return role_from_permissions
- def find_roles_by_id(self, role_ids: List[int]) -> List[Role]:
+ def find_roles_by_id(self, role_ids: list[int]) -> list[Role]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
@@ -1179,7 +1168,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
connection: Connection,
old_database_name: str,
target: "Database",
- ) -> List[ViewMenu]:
+ ) -> list[ViewMenu]:
"""
Helper method that Updates all datasource access permission
when a database name changes.
@@ -1205,7 +1194,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
.filter(SqlaTable.database_id == target.id)
.all()
)
- updated_view_menus: List[ViewMenu] = []
+ updated_view_menus: list[ViewMenu] = []
for dataset in datasets:
old_dataset_vm_name = self.get_dataset_perm(
dataset.id, dataset.table_name, old_database_name
@@ -1768,7 +1757,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
"""
@staticmethod
- def get_exclude_users_from_lists() -> List[str]:
+ def get_exclude_users_from_lists() -> list[str]:
"""
Override to dynamically identify a list of usernames to exclude from
all UI dropdown lists, owners, created_by filters etc...
@@ -1896,7 +1885,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
def get_anonymous_user(self) -> User: # pylint: disable=no-self-use
return AnonymousUserMixin()
- def get_user_roles(self, user: Optional[User] = None) -> List[Role]:
+ def get_user_roles(self, user: Optional[User] = None) -> list[Role]:
if not user:
user = g.user
if user.is_anonymous:
@@ -1906,7 +1895,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
def get_guest_rls_filters(
self, dataset: "BaseDatasource"
- ) -> List[GuestTokenRlsRule]:
+ ) -> list[GuestTokenRlsRule]:
"""
Retrieves the row level security filters for the current user and the dataset,
if the user is authenticated with a guest token.
@@ -1922,7 +1911,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
]
return []
- def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]:
+ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
"""
Retrieves the appropriate row level security filters for the current user and
the passed table.
@@ -1990,7 +1979,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
)
return query.all()
- def get_rls_ids(self, table: "BaseDatasource") -> List[int]:
+ def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
"""
Retrieves the appropriate row level security filters IDs for the current user
and the passed table.
@@ -2002,10 +1991,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
ids.sort() # Combinations rather than permutations
return ids
- def get_guest_rls_filters_str(self, table: "BaseDatasource") -> List[str]:
+ def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]
- def get_rls_cache_key(self, datasource: "BaseDatasource") -> List[str]:
+ def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
rls_ids = []
if datasource.is_rls_supported:
rls_ids = self.get_rls_ids(datasource)
@@ -2122,7 +2111,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
self,
user: GuestTokenUser,
resources: GuestTokenResources,
- rls: List[GuestTokenRlsRule],
+ rls: list[GuestTokenRlsRule],
) -> bytes:
secret = current_app.config["GUEST_TOKEN_JWT_SECRET"]
algo = current_app.config["GUEST_TOKEN_JWT_ALGO"]
@@ -2183,7 +2172,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])],
)
- def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]:
+ def parse_jwt_guest_token(self, raw_token: str) -> dict[str, Any]:
"""
Parses a guest token. Raises an error if the jwt fails standard claims checks.
:param raw_token: the token gotten from the request
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 9ea881fadf..678da79fa7 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -20,7 +20,7 @@ import uuid
from contextlib import closing
from datetime import datetime
from sys import getsizeof
-from typing import Any, cast, Dict, List, Optional, Tuple, Union
+from typing import Any, cast, Optional, Union
import backoff
import msgpack
@@ -88,9 +88,9 @@ def handle_query_error(
ex: Exception,
query: Query,
session: Session,
- payload: Optional[Dict[str, Any]] = None,
+ payload: Optional[dict[str, Any]] = None,
prefix_message: str = "",
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Local method handling error while processing the SQL"""
payload = payload or {}
msg = f"{prefix_message} {str(ex)}".strip()
@@ -122,7 +122,7 @@ def handle_query_error(
return payload
-def get_query_backoff_handler(details: Dict[Any, Any]) -> None:
+def get_query_backoff_handler(details: dict[Any, Any]) -> None:
query_id = details["kwargs"]["query_id"]
logger.error(
"Query with id `%s` could not be retrieved", str(query_id), exc_info=True
@@ -168,8 +168,8 @@ def get_sql_results( # pylint: disable=too-many-arguments
username: Optional[str] = None,
start_time: Optional[float] = None,
expand_data: bool = False,
- log_params: Optional[Dict[str, Any]] = None,
-) -> Optional[Dict[str, Any]]:
+ log_params: Optional[dict[str, Any]] = None,
+) -> Optional[dict[str, Any]]:
"""Executes the sql query returns the results."""
with session_scope(not ctask.request.called_directly) as session:
with override_user(security_manager.find_user(username)):
@@ -196,7 +196,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem
query: Query,
session: Session,
cursor: Any,
- log_params: Optional[Dict[str, Any]],
+ log_params: Optional[dict[str, Any]],
apply_ctas: bool = False,
) -> SupersetResultSet:
"""Executes a single SQL statement"""
@@ -332,7 +332,7 @@ def apply_limit_if_exists(
def _serialize_payload(
- payload: Dict[Any, Any], use_msgpack: Optional[bool] = False
+ payload: dict[Any, Any], use_msgpack: Optional[bool] = False
) -> Union[bytes, str]:
logger.debug("Serializing to msgpack: %r", use_msgpack)
if use_msgpack:
@@ -346,10 +346,10 @@ def _serialize_and_expand_data(
db_engine_spec: BaseEngineSpec,
use_msgpack: Optional[bool] = False,
expand_data: bool = False,
-) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]:
+) -> tuple[Union[bytes, str], list[Any], list[Any], list[Any]]:
selected_columns = result_set.columns
- all_columns: List[Any]
- expanded_columns: List[Any]
+ all_columns: list[Any]
+ expanded_columns: list[Any]
if use_msgpack:
with stats_timing(
@@ -383,15 +383,15 @@ def execute_sql_statements(
session: Session,
start_time: Optional[float],
expand_data: bool,
- log_params: Optional[Dict[str, Any]],
-) -> Optional[Dict[str, Any]]:
+ log_params: Optional[dict[str, Any]],
+) -> Optional[dict[str, Any]]:
"""Executes the sql query returns the results."""
if store_results and start_time:
# only asynchronous queries
stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time)
query = get_query(query_id, session)
- payload: Dict[str, Any] = dict(query_id=query_id)
+ payload: dict[str, Any] = dict(query_id=query_id)
database = query.database
db_engine_spec = database.db_engine_spec
db_engine_spec.patch()
diff --git a/superset/sql_parse.py b/superset/sql_parse.py
index 034daeb7af..974d7eacd4 100644
--- a/superset/sql_parse.py
+++ b/superset/sql_parse.py
@@ -16,9 +16,10 @@
# under the License.
import logging
import re
+from collections.abc import Iterator
from dataclasses import dataclass
from enum import Enum
-from typing import Any, cast, Iterator, List, Optional, Set, Tuple
+from typing import Any, cast, Optional
from urllib import parse
import sqlparse
@@ -97,7 +98,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
def extract_top_from_query(
- statement: TokenList, top_keywords: Set[str]
+ statement: TokenList, top_keywords: set[str]
) -> Optional[int]:
"""
Extract top clause value from SQL statement.
@@ -122,7 +123,7 @@ def extract_top_from_query(
return top
-def get_cte_remainder_query(sql: str) -> Tuple[Optional[str], str]:
+def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
"""
parse the SQL and return the CTE and rest of the block to the caller
@@ -192,8 +193,8 @@ class ParsedQuery:
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
self.sql: str = sql_statement
- self._tables: Set[Table] = set()
- self._alias_names: Set[str] = set()
+ self._tables: set[Table] = set()
+ self._alias_names: set[str] = set()
self._limit: Optional[int] = None
logger.debug("Parsing with sqlparse statement: %s", self.sql)
@@ -202,7 +203,7 @@ class ParsedQuery:
self._limit = _extract_limit_from_query(statement)
@property
- def tables(self) -> Set[Table]:
+ def tables(self) -> set[Table]:
if not self._tables:
for statement in self._parsed:
self._extract_from_token(statement)
@@ -282,7 +283,7 @@ class ParsedQuery:
def strip_comments(self) -> str:
return sqlparse.format(self.stripped(), strip_comments=True)
- def get_statements(self) -> List[str]:
+ def get_statements(self) -> list[str]:
"""Returns a list of SQL statements as strings, stripped"""
statements = []
for statement in self._parsed:
@@ -737,7 +738,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
def extract_table_references(
sql_text: str, sqla_dialect: str, show_warning: bool = True
-) -> Set["Table"]:
+) -> set["Table"]:
"""
Return all the dependencies from a SQL sql_text.
"""
diff --git a/superset/sql_validators/__init__.py b/superset/sql_validators/__init__.py
index c448f696a1..ad048a86a5 100644
--- a/superset/sql_validators/__init__.py
+++ b/superset/sql_validators/__init__.py
@@ -14,13 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Type
+from typing import Optional
from . import base, postgres, presto_db
from .base import SQLValidationAnnotation
-def get_validator_by_name(name: str) -> Optional[Type[base.BaseSQLValidator]]:
+def get_validator_by_name(name: str) -> Optional[type[base.BaseSQLValidator]]:
return {
"PrestoDBSQLValidator": presto_db.PrestoDBSQLValidator,
"PostgreSQLValidator": postgres.PostgreSQLValidator,
diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py
index de29a96e8e..8344fc9264 100644
--- a/superset/sql_validators/base.py
+++ b/superset/sql_validators/base.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from superset.models.core import Database
@@ -34,7 +34,7 @@ class SQLValidationAnnotation: # pylint: disable=too-few-public-methods
self.start_column = start_column
self.end_column = end_column
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
"""Return a dictionary representation of this annotation"""
return {
"line_number": self.line_number,
@@ -53,6 +53,6 @@ class BaseSQLValidator: # pylint: disable=too-few-public-methods
@classmethod
def validate(
cls, sql: str, schema: Optional[str], database: Database
- ) -> List[SQLValidationAnnotation]:
+ ) -> list[SQLValidationAnnotation]:
"""Check that the given SQL querystring is valid for the given engine"""
raise NotImplementedError
diff --git a/superset/sql_validators/postgres.py b/superset/sql_validators/postgres.py
index f62be39f03..60c15ca034 100644
--- a/superset/sql_validators/postgres.py
+++ b/superset/sql_validators/postgres.py
@@ -16,7 +16,7 @@
# under the License.
import re
-from typing import List, Optional
+from typing import Optional
from pgsanity.pgsanity import check_string
@@ -32,8 +32,8 @@ class PostgreSQLValidator(BaseSQLValidator): # pylint: disable=too-few-public-m
@classmethod
def validate(
cls, sql: str, schema: Optional[str], database: Database
- ) -> List[SQLValidationAnnotation]:
- annotations: List[SQLValidationAnnotation] = []
+ ) -> list[SQLValidationAnnotation]:
+ annotations: list[SQLValidationAnnotation] = []
valid, error = check_string(sql, add_semicolon=True)
if valid:
return annotations
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index c5ecf4c96e..9d3e7641a6 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -18,7 +18,7 @@
import logging
import time
from contextlib import closing
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from superset import app, security_manager
from superset.models.core import Database
@@ -109,7 +109,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
raise PrestoSQLValidationError(
"The pyhive presto client returned an unhandled " "database error."
) from db_error
- error_args: Dict[str, Any] = db_error.args[0]
+ error_args: dict[str, Any] = db_error.args[0]
# Confirm the two fields we need to be able to present an annotation
# are present in the error response -- a message, and a location.
@@ -148,7 +148,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
@classmethod
def validate(
cls, sql: str, schema: Optional[str], database: Database
- ) -> List[SQLValidationAnnotation]:
+ ) -> list[SQLValidationAnnotation]:
"""
Presto supports query-validation queries by running them with a
prepended explain.
@@ -167,7 +167,7 @@ class PrestoDBSQLValidator(BaseSQLValidator):
) as engine:
# Sharing a single connection and cursor across the
# execution of all statements (if many)
- annotations: List[SQLValidationAnnotation] = []
+ annotations: list[SQLValidationAnnotation] = []
with closing(engine.raw_connection()) as conn:
cursor = conn.cursor()
for statement in parsed_query.get_statements():
diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py
index 3c24bf1c26..35d110d8fc 100644
--- a/superset/sqllab/api.py
+++ b/superset/sqllab/api.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, cast, Dict, Optional
+from typing import Any, cast, Optional
from urllib import parse
import simplejson as json
@@ -326,7 +326,7 @@ class SqlLabRestApi(BaseSupersetApi):
@staticmethod
def _create_sql_json_command(
- execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]]
+ execution_context: SqlJsonExecutionContext, log_params: Optional[dict[str, Any]]
) -> ExecuteSqlCommand:
query_dao = QueryDAO()
sql_json_executor = SqlLabRestApi._create_sql_json_executor(
diff --git a/superset/sqllab/commands/estimate.py b/superset/sqllab/commands/estimate.py
index 2b8c5814b9..bf1d6c4fa5 100644
--- a/superset/sqllab/commands/estimate.py
+++ b/superset/sqllab/commands/estimate.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, List
+from typing import Any
from flask_babel import gettext as __
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class QueryEstimationCommand(BaseCommand):
_database_id: int
_sql: str
- _template_params: Dict[str, Any]
+ _template_params: dict[str, Any]
_schema: str
_database: Database
@@ -64,7 +64,7 @@ class QueryEstimationCommand(BaseCommand):
def run(
self,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
self.validate()
sql = self._sql
@@ -96,7 +96,7 @@ class QueryEstimationCommand(BaseCommand):
) from ex
spec = self._database.db_engine_spec
- query_cost_formatters: Dict[str, Any] = app.config[
+ query_cost_formatters: dict[str, Any] = app.config[
"QUERY_COST_FORMATTERS_BY_ENGINE"
]
query_cost_formatter = query_cost_formatters.get(
diff --git a/superset/sqllab/commands/execute.py b/superset/sqllab/commands/execute.py
index 97c8514d5d..09b0769ce2 100644
--- a/superset/sqllab/commands/execute.py
+++ b/superset/sqllab/commands/execute.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import copy
import logging
-from typing import Any, Dict, Optional, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from flask_babel import gettext as __
@@ -51,7 +51,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-CommandResult = Dict[str, Any]
+CommandResult = dict[str, Any]
class ExecuteSqlCommand(BaseCommand):
@@ -63,7 +63,7 @@ class ExecuteSqlCommand(BaseCommand):
_sql_json_executor: SqlJsonExecutor
_execution_context_convertor: ExecutionContextConvertor
_sqllab_ctas_no_limit: bool
- _log_params: Optional[Dict[str, Any]] = None
+ _log_params: dict[str, Any] | None = None
def __init__(
self,
@@ -75,7 +75,7 @@ class ExecuteSqlCommand(BaseCommand):
sql_json_executor: SqlJsonExecutor,
execution_context_convertor: ExecutionContextConvertor,
sqllab_ctas_no_limit_flag: bool,
- log_params: Optional[Dict[str, Any]] = None,
+ log_params: dict[str, Any] | None = None,
) -> None:
self._execution_context = execution_context
self._query_dao = query_dao
@@ -122,7 +122,7 @@ class ExecuteSqlCommand(BaseCommand):
except Exception as ex:
raise SqlLabException(self._execution_context, exception=ex) from ex
- def _try_get_existing_query(self) -> Optional[Query]:
+ def _try_get_existing_query(self) -> Query | None:
return self._query_dao.find_one_or_none(
client_id=self._execution_context.client_id,
user_id=self._execution_context.user_id,
@@ -130,7 +130,7 @@ class ExecuteSqlCommand(BaseCommand):
)
@classmethod
- def is_query_handled(cls, query: Optional[Query]) -> bool:
+ def is_query_handled(cls, query: Query | None) -> bool:
return query is not None and query.status in [
QueryStatus.RUNNING,
QueryStatus.PENDING,
@@ -166,7 +166,7 @@ class ExecuteSqlCommand(BaseCommand):
return mydb
@classmethod
- def _validate_query_db(cls, database: Optional[Database]) -> None:
+ def _validate_query_db(cls, database: Database | None) -> None:
if not database:
raise SupersetGenericErrorException(
__(
diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py
index e9559be3b9..1b9b0e0344 100644
--- a/superset/sqllab/commands/export.py
+++ b/superset/sqllab/commands/export.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, cast, List, TypedDict
+from typing import Any, cast, TypedDict
import pandas as pd
from flask_babel import gettext as __
@@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class SqlExportResult(TypedDict):
query: Query
count: int
- data: List[Any]
+ data: list[Any]
class SqlResultExportCommand(BaseCommand):
diff --git a/superset/sqllab/commands/results.py b/superset/sqllab/commands/results.py
index d6c415a09f..83c8aa8f6a 100644
--- a/superset/sqllab/commands/results.py
+++ b/superset/sqllab/commands/results.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, cast, Dict, Optional
+from typing import Any, cast
from flask_babel import gettext as __
@@ -40,14 +40,14 @@ logger = logging.getLogger(__name__)
class SqlExecutionResultsCommand(BaseCommand):
_key: str
- _rows: Optional[int]
+ _rows: int | None
_blob: Any
_query: Query
def __init__(
self,
key: str,
- rows: Optional[int] = None,
+ rows: int | None = None,
) -> None:
self._key = key
self._rows = rows
@@ -100,7 +100,7 @@ class SqlExecutionResultsCommand(BaseCommand):
def run(
self,
- ) -> Dict[str, Any]:
+ ) -> dict[str, Any]:
"""Runs arbitrary sql and returns data as json"""
self.validate()
payload = utils.zlib_decompress(
diff --git a/superset/sqllab/exceptions.py b/superset/sqllab/exceptions.py
index 70e4fa9752..f06cc8dd2e 100644
--- a/superset/sqllab/exceptions.py
+++ b/superset/sqllab/exceptions.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import os
-from typing import Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from flask_babel import lazy_gettext as _
@@ -31,15 +31,15 @@ if TYPE_CHECKING:
class SqlLabException(SupersetException):
sql_json_execution_context: SqlJsonExecutionContext
failed_reason_msg: str
- suggestion_help_msg: Optional[str]
+ suggestion_help_msg: str | None
def __init__( # pylint: disable=too-many-arguments
self,
sql_json_execution_context: SqlJsonExecutionContext,
- error_type: Optional[SupersetErrorType] = None,
- reason_message: Optional[str] = None,
- exception: Optional[Exception] = None,
- suggestion_help_msg: Optional[str] = None,
+ error_type: SupersetErrorType | None = None,
+ reason_message: str | None = None,
+ exception: Exception | None = None,
+ suggestion_help_msg: str | None = None,
) -> None:
self.sql_json_execution_context = sql_json_execution_context
self.failed_reason_msg = self._get_reason(reason_message, exception)
@@ -68,21 +68,21 @@ class SqlLabException(SupersetException):
if self.failed_reason_msg:
msg = msg + self.failed_reason_msg
if self.suggestion_help_msg is not None:
- msg = "{} {} {}".format(msg, os.linesep, self.suggestion_help_msg)
+ msg = f"{msg} {os.linesep} {self.suggestion_help_msg}"
return msg
@classmethod
def _get_reason(
- cls, reason_message: Optional[str] = None, exception: Optional[Exception] = None
+ cls, reason_message: str | None = None, exception: Exception | None = None
) -> str:
if reason_message is not None:
- return ": {}".format(reason_message)
+ return f": {reason_message}"
if exception is not None:
if hasattr(exception, "get_message"):
- return ": {}".format(exception.get_message())
+ return f": {exception.get_message()}"
if hasattr(exception, "message"):
- return ": {}".format(exception.message)
- return ": {}".format(str(exception))
+ return f": {exception.message}"
+ return f": {str(exception)}"
return ""
@@ -93,7 +93,7 @@ class QueryIsForbiddenToAccessException(SqlLabException):
def __init__(
self,
sql_json_execution_context: SqlJsonExecutionContext,
- exception: Optional[Exception] = None,
+ exception: Exception | None = None,
) -> None:
super().__init__(
sql_json_execution_context,
diff --git a/superset/sqllab/execution_context_convertor.py b/superset/sqllab/execution_context_convertor.py
index f49fbd9a31..430db0d52f 100644
--- a/superset/sqllab/execution_context_convertor.py
+++ b/superset/sqllab/execution_context_convertor.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Any, Dict, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import simplejson as json
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
class ExecutionContextConvertor:
_max_row_in_display_configuration: int # pylint: disable=invalid-name
_exc_status: SqlJsonExecutionStatus
- payload: Dict[str, Any]
+ payload: dict[str, Any]
def set_max_row_in_display(self, value: int) -> None:
self._max_row_in_display_configuration = value # pylint: disable=invalid-name
diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py
index 1369e78db1..db1adf43ba 100644
--- a/superset/sqllab/query_render.py
+++ b/superset/sqllab/query_render.py
@@ -17,7 +17,7 @@
# pylint: disable=invalid-name, no-self-use, too-few-public-methods, too-many-arguments
from __future__ import annotations
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
from flask_babel import gettext as __, ngettext
from jinja2 import TemplateError
@@ -124,16 +124,16 @@ class SqlQueryRenderImpl(SqlQueryRender):
class SqlQueryRenderException(SqlLabException):
- _extra: Optional[Dict[str, Any]]
+ _extra: dict[str, Any] | None
def __init__(
self,
sql_json_execution_context: SqlJsonExecutionContext,
error_type: SupersetErrorType,
- reason_message: Optional[str] = None,
- exception: Optional[Exception] = None,
- suggestion_help_msg: Optional[str] = None,
- extra: Optional[Dict[str, Any]] = None,
+ reason_message: str | None = None,
+ exception: Exception | None = None,
+ suggestion_help_msg: str | None = None,
+ extra: dict[str, Any] | None = None,
) -> None:
super().__init__(
sql_json_execution_context,
@@ -145,10 +145,10 @@ class SqlQueryRenderException(SqlLabException):
self._extra = extra
@property
- def extra(self) -> Optional[Dict[str, Any]]:
+ def extra(self) -> dict[str, Any] | None:
return self._extra
- def to_dict(self) -> Dict[str, Any]:
+ def to_dict(self) -> dict[str, Any]:
rv = super().to_dict()
if self._extra:
rv["extra"] = self._extra
diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py
index e4e6b60654..124f477e96 100644
--- a/superset/sqllab/sql_json_executer.py
+++ b/superset/sqllab/sql_json_executer.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import dataclasses
import logging
from abc import ABC
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
from flask_babel import gettext as __
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
QueryStatus = utils.QueryStatus
logger = logging.getLogger(__name__)
-SqlResults = Dict[str, Any]
+SqlResults = dict[str, Any]
GetSqlResultsTask = Callable[..., SqlResults]
@@ -53,7 +53,7 @@ class SqlJsonExecutor:
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
- log_params: Optional[Dict[str, Any]],
+ log_params: dict[str, Any] | None,
) -> SqlJsonExecutionStatus:
raise NotImplementedError()
@@ -88,7 +88,7 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
- log_params: Optional[Dict[str, Any]],
+ log_params: dict[str, Any] | None,
) -> SqlJsonExecutionStatus:
query_id = execution_context.query.id
try:
@@ -120,8 +120,8 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
- log_params: Optional[Dict[str, Any]],
- ) -> Optional[SqlResults]:
+ log_params: dict[str, Any] | None,
+ ) -> SqlResults | None:
with utils.timeout(
seconds=self._timeout_duration_in_seconds,
error_message=self._get_timeout_error_msg(),
@@ -132,8 +132,8 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase):
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
- log_params: Optional[Dict[str, Any]],
- ) -> Optional[SqlResults]:
+ log_params: dict[str, Any] | None,
+ ) -> SqlResults | None:
return self._get_sql_results_task(
execution_context.query.id,
rendered_query,
@@ -161,7 +161,7 @@ class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase):
self,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
- log_params: Optional[Dict[str, Any]],
+ log_params: dict[str, Any] | None,
) -> SqlJsonExecutionStatus:
query_id = execution_context.query.id
logger.info("Query %i: Running query on a Celery worker", query_id)
diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py
index 644c978b32..22277804ee 100644
--- a/superset/sqllab/sqllab_execution_context.py
+++ b/superset/sqllab/sqllab_execution_context.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import json
import logging
from dataclasses import dataclass
-from typing import Any, cast, Dict, Optional, TYPE_CHECKING
+from typing import Any, cast, TYPE_CHECKING
from flask import g
from sqlalchemy.orm.exc import DetachedInstanceError
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-SqlResults = Dict[str, Any]
+SqlResults = dict[str, Any]
@dataclass
@@ -45,7 +45,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
database_id: int
schema: str
sql: str
- template_params: Dict[str, Any]
+ template_params: dict[str, Any]
async_flag: bool
limit: int
status: str
@@ -53,14 +53,14 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
client_id_or_short_id: str
sql_editor_id: str
tab_name: str
- user_id: Optional[int]
+ user_id: int | None
expand_data: bool
- create_table_as_select: Optional[CreateTableAsSelect]
- database: Optional[Database]
+ create_table_as_select: CreateTableAsSelect | None
+ database: Database | None
query: Query
- _sql_result: Optional[SqlResults]
+ _sql_result: SqlResults | None
- def __init__(self, query_params: Dict[str, Any]):
+ def __init__(self, query_params: dict[str, Any]):
self.create_table_as_select = None
self.database = None
self._init_from_query_params(query_params)
@@ -70,7 +70,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
def set_query(self, query: Query) -> None:
self.query = query
- def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
+ def _init_from_query_params(self, query_params: dict[str, Any]) -> None:
self.database_id = cast(int, query_params.get("database_id"))
self.schema = cast(str, query_params.get("schema"))
self.sql = cast(str, query_params.get("sql"))
@@ -90,7 +90,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
)
@staticmethod
- def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]:
+ def _get_template_params(query_params: dict[str, Any]) -> dict[str, Any]:
try:
template_params = json.loads(query_params.get("templateParams") or "{}")
except json.JSONDecodeError:
@@ -102,7 +102,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
return template_params
@staticmethod
- def _get_limit_param(query_params: Dict[str, Any]) -> int:
+ def _get_limit_param(query_params: dict[str, Any]) -> int:
limit = apply_max_row_limit(query_params.get("queryLimit") or 0)
if limit < 0:
logger.warning(
@@ -125,7 +125,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
schema_name = self._get_ctas_target_schema_name(database)
self.create_table_as_select.target_schema_name = schema_name # type: ignore
- def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]:
+ def _get_ctas_target_schema_name(self, database: Database) -> str | None:
if database.force_ctas_schema:
return database.force_ctas_schema
return get_cta_schema_name(database, g.user, self.schema, self.sql)
@@ -134,10 +134,10 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
# TODO validate db.id is equal to self.database_id
pass
- def get_execution_result(self) -> Optional[SqlResults]:
+ def get_execution_result(self) -> SqlResults | None:
return self._sql_result
- def set_execution_result(self, sql_result: Optional[SqlResults]) -> None:
+ def set_execution_result(self, sql_result: SqlResults | None) -> None:
self._sql_result = sql_result
def create_query(self) -> Query:
@@ -178,15 +178,15 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
try:
if hasattr(self, "query"):
if self.query.id:
- return "query '{}' - '{}'".format(self.query.id, self.query.sql)
+ return f"query '{self.query.id}' - '{self.query.sql}'"
except DetachedInstanceError:
pass
- return "query '{}'".format(self.sql)
+ return f"query '{self.sql}'"
class CreateTableAsSelect: # pylint: disable=too-few-public-methods
ctas_method: CtasMethod
- target_schema_name: Optional[str]
+ target_schema_name: str | None
target_table_name: str
def __init__(
@@ -197,7 +197,7 @@ class CreateTableAsSelect: # pylint: disable=too-few-public-methods
self.target_table_name = target_table_name
@staticmethod
- def create_from(query_params: Dict[str, Any]) -> CreateTableAsSelect:
+ def create_from(query_params: dict[str, Any]) -> CreateTableAsSelect:
ctas_method = query_params.get("ctas_method", CtasMethod.TABLE)
schema = cast(str, query_params.get("schema"))
tmp_table_name = cast(str, query_params.get("tmp_table_name"))
diff --git a/superset/sqllab/utils.py b/superset/sqllab/utils.py
index 3bcd7308a1..abceaaf136 100644
--- a/superset/sqllab/utils.py
+++ b/superset/sqllab/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
import pyarrow as pa
@@ -22,8 +22,8 @@ from superset.common.db_query_status import QueryStatus
def apply_display_max_row_configuration_if_require( # pylint: disable=invalid-name
- sql_results: Dict[str, Any], max_rows_in_result: int
-) -> Dict[str, Any]:
+ sql_results: dict[str, Any], max_rows_in_result: int
+) -> dict[str, Any]:
"""
Given a `sql_results` nested structure, applies a limit to the number of rows
diff --git a/superset/stats_logger.py b/superset/stats_logger.py
index 4b869042a9..fc223f7529 100644
--- a/superset/stats_logger.py
+++ b/superset/stats_logger.py
@@ -54,22 +54,20 @@ class DummyStatsLogger(BaseStatsLogger):
logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL)
def decr(self, key: str) -> None:
- logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL))
+ logger.debug(Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL)
def timing(self, key: str, value: float) -> None:
logger.debug(
- (Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL)
+ Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL
)
def gauge(self, key: str, value: float) -> None:
logger.debug(
- (
- Fore.CYAN
- + "[stats_logger] (gauge) "
- + f"{key}"
- + f"{value}"
- + Style.RESET_ALL
- )
+ Fore.CYAN
+ + "[stats_logger] (gauge) "
+ + f"{key}"
+ + f"{value}"
+ + Style.RESET_ALL
)
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 8eaea54176..7c21df6a88 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -14,10 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Sequence
from datetime import datetime
-from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
+from typing import Any, Literal, Optional, TYPE_CHECKING, Union
-from typing_extensions import Literal, TypedDict
+from typing_extensions import TypedDict
from werkzeug.wrappers import Response
if TYPE_CHECKING:
@@ -69,8 +70,8 @@ class ResultSetColumnType(TypedDict):
is_dttm: bool
-CacheConfig = Dict[str, Any]
-DbapiDescriptionRow = Tuple[
+CacheConfig = dict[str, Any]
+DbapiDescriptionRow = tuple[
Union[str, bytes],
str,
Optional[str],
@@ -79,27 +80,27 @@ DbapiDescriptionRow = Tuple[
Optional[int],
bool,
]
-DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]]
-DbapiResult = Sequence[Union[List[Any], Tuple[Any, ...]]]
+DbapiDescription = Union[list[DbapiDescriptionRow], tuple[DbapiDescriptionRow, ...]]
+DbapiResult = Sequence[Union[list[Any], tuple[Any, ...]]]
FilterValue = Union[bool, datetime, float, int, str]
-FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]]
-FormData = Dict[str, Any]
-Granularity = Union[str, Dict[str, Union[str, float]]]
+FilterValues = Union[FilterValue, list[FilterValue], tuple[FilterValue]]
+FormData = dict[str, Any]
+Granularity = Union[str, dict[str, Union[str, float]]]
Column = Union[AdhocColumn, str]
Metric = Union[AdhocMetric, str]
-OrderBy = Tuple[Metric, bool]
-QueryObjectDict = Dict[str, Any]
-VizData = Optional[Union[List[Any], Dict[Any, Any]]]
-VizPayload = Dict[str, Any]
+OrderBy = tuple[Metric, bool]
+QueryObjectDict = dict[str, Any]
+VizData = Optional[Union[list[Any], dict[Any, Any]]]
+VizPayload = dict[str, Any]
# Flask response.
Base = Union[bytes, str]
Status = Union[int, str]
-Headers = Dict[str, Any]
+Headers = dict[str, Any]
FlaskResponse = Union[
Response,
Base,
- Tuple[Base, Status],
- Tuple[Base, Status, Headers],
- Tuple[Response, Status],
+ tuple[Base, Status],
+ tuple[Base, Status, Headers],
+ tuple[Response, Status],
]
diff --git a/superset/tables/models.py b/superset/tables/models.py
index 9a0c07fdcf..a24035fb97 100644
--- a/superset/tables/models.py
+++ b/superset/tables/models.py
@@ -24,7 +24,8 @@ addition to a table, new models for columns, metrics, and datasets were also int
These models are not fully implemented, and shouldn't be used yet.
"""
-from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING
+from collections.abc import Iterable
+from typing import Any, Optional, TYPE_CHECKING
import sqlalchemy as sa
from flask_appbuilder import Model
@@ -87,7 +88,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
# The relationship between datasets and columns is 1:n, but we use a
# many-to-many association table to avoid adding two mutually exclusive
# columns(dataset_id and table_id) to Column
- columns: List[Column] = relationship(
+ columns: list[Column] = relationship(
"Column",
secondary=table_column_association_table,
cascade="all, delete-orphan",
@@ -96,7 +97,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
# is loaded.
backref="tables",
)
- datasets: List["Dataset"] # will be populated by Dataset.tables backref
+ datasets: list["Dataset"] # will be populated by Dataset.tables backref
# We use ``sa.Text`` for these attributes because (1) in modern databases the
# performance is the same as ``VARCHAR``[1] and (2) because some table names can be
@@ -130,7 +131,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
existing_columns = {column.name: column for column in self.columns}
quote_identifier = self.database.quote_identifier
- def update_or_create_column(column_meta: Dict[str, Any]) -> Column:
+ def update_or_create_column(column_meta: dict[str, Any]) -> Column:
column_name: str = column_meta["name"]
if column_name in existing_columns:
column = existing_columns[column_name]
@@ -153,8 +154,8 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
table_names: Iterable[TableName],
default_schema: Optional[str] = None,
sync_columns: Optional[bool] = False,
- default_props: Optional[Dict[str, Any]] = None,
- ) -> List["Table"]:
+ default_props: Optional[dict[str, Any]] = None,
+ ) -> list["Table"]:
"""
Load or create multiple Table instances.
"""
diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py
index 1e886e2af6..20327b54f0 100644
--- a/superset/tags/commands/create.py
+++ b/superset/tags/commands/create.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List
from superset.commands.base import BaseCommand, CreateMixin
from superset.dao.exceptions import DAOCreateFailedError
@@ -28,7 +27,7 @@ logger = logging.getLogger(__name__)
class CreateCustomTagCommand(CreateMixin, BaseCommand):
- def __init__(self, object_type: ObjectTypes, object_id: int, tags: List[str]):
+ def __init__(self, object_type: ObjectTypes, object_id: int, tags: list[str]):
self._object_type = object_type
self._object_id = object_id
self._tags = tags
diff --git a/superset/tags/commands/delete.py b/superset/tags/commands/delete.py
index acec016619..08189b5ac5 100644
--- a/superset/tags/commands/delete.py
+++ b/superset/tags/commands/delete.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List
from superset.commands.base import BaseCommand
from superset.dao.exceptions import DAODeleteFailedError
@@ -90,7 +89,7 @@ class DeleteTaggedObjectCommand(DeleteMixin, BaseCommand):
class DeleteTagsCommand(DeleteMixin, BaseCommand):
- def __init__(self, tags: List[str]):
+ def __init__(self, tags: list[str]):
self._tags = tags
def run(self) -> None:
diff --git a/superset/tags/dao.py b/superset/tags/dao.py
index c676b4ab3c..9ea61f5c90 100644
--- a/superset/tags/dao.py
+++ b/superset/tags/dao.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from operator import and_
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from sqlalchemy.exc import SQLAlchemyError
@@ -45,7 +45,7 @@ class TagDAO(BaseDAO):
@staticmethod
def create_custom_tagged_objects(
- object_type: ObjectTypes, object_id: int, tag_names: List[str]
+ object_type: ObjectTypes, object_id: int, tag_names: list[str]
) -> None:
tagged_objects = []
for name in tag_names:
@@ -95,7 +95,7 @@ class TagDAO(BaseDAO):
raise DAODeleteFailedError(exception=ex) from ex
@staticmethod
- def delete_tags(tag_names: List[str]) -> None:
+ def delete_tags(tag_names: list[str]) -> None:
"""
deletes tags from a list of tag names
"""
@@ -158,8 +158,8 @@ class TagDAO(BaseDAO):
@staticmethod
def get_tagged_objects_for_tags(
- tags: Optional[List[str]] = None, obj_types: Optional[List[str]] = None
- ) -> List[Dict[str, Any]]:
+ tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None
+ ) -> list[dict[str, Any]]:
"""
returns a list of tagged objects filtered by tag names and object types
if no filters applied returns all tagged objects
@@ -174,7 +174,7 @@ class TagDAO(BaseDAO):
# filter types
- results: List[Dict[str, Any]] = []
+ results: list[dict[str, Any]] = []
# dashboards
if (not obj_types) or ("dashboard" in obj_types):
diff --git a/superset/tags/models.py b/superset/tags/models.py
index 797308c306..bb845303ff 100644
--- a/superset/tags/models.py
+++ b/superset/tags/models.py
@@ -14,16 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import (
- absolute_import,
- annotations,
- division,
- print_function,
- unicode_literals,
-)
+from __future__ import annotations
import enum
-from typing import List, Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
from flask_appbuilder import Model
from sqlalchemy import Column, Enum, ForeignKey, Integer, String
@@ -122,26 +116,26 @@ def get_object_type(class_name: str) -> ObjectTypes:
try:
return mapping[class_name.lower()]
except KeyError as ex:
- raise Exception("No mapping found for {0}".format(class_name)) from ex
+ raise Exception(f"No mapping found for {class_name}") from ex
class ObjectUpdater:
- object_type: Optional[str] = None
+ object_type: str | None = None
@classmethod
def get_owners_ids(
- cls, target: Union[Dashboard, FavStar, Slice, Query, SqlaTable]
- ) -> List[int]:
+ cls, target: Dashboard | FavStar | Slice | Query | SqlaTable
+ ) -> list[int]:
raise NotImplementedError("Subclass should implement `get_owners_ids`")
@classmethod
def _add_owners(
cls,
session: Session,
- target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
+ target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
for owner_id in cls.get_owners_ids(target):
- name = "owner:{0}".format(owner_id)
+ name = f"owner:{owner_id}"
tag = get_tag(name, session, TagTypes.owner)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
@@ -153,7 +147,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
- target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
+ target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)
@@ -161,7 +155,7 @@ class ObjectUpdater:
cls._add_owners(session, target)
# add `type:` tags
- tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type)
+ tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type)
tagged_object = TaggedObject(
tag_id=tag.id, object_id=target.id, object_type=cls.object_type
)
@@ -174,7 +168,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
- target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
+ target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)
@@ -203,7 +197,7 @@ class ObjectUpdater:
cls,
_mapper: Mapper,
connection: Connection,
- target: Union[Dashboard, FavStar, Slice, Query, SqlaTable],
+ target: Dashboard | FavStar | Slice | Query | SqlaTable,
) -> None:
session = Session(bind=connection)
@@ -220,7 +214,7 @@ class ChartUpdater(ObjectUpdater):
object_type = "chart"
@classmethod
- def get_owners_ids(cls, target: Slice) -> List[int]:
+ def get_owners_ids(cls, target: Slice) -> list[int]:
return [owner.id for owner in target.owners]
@@ -228,7 +222,7 @@ class DashboardUpdater(ObjectUpdater):
object_type = "dashboard"
@classmethod
- def get_owners_ids(cls, target: Dashboard) -> List[int]:
+ def get_owners_ids(cls, target: Dashboard) -> list[int]:
return [owner.id for owner in target.owners]
@@ -236,7 +230,7 @@ class QueryUpdater(ObjectUpdater):
object_type = "query"
@classmethod
- def get_owners_ids(cls, target: Query) -> List[int]:
+ def get_owners_ids(cls, target: Query) -> list[int]:
return [target.user_id]
@@ -244,7 +238,7 @@ class DatasetUpdater(ObjectUpdater):
object_type = "dataset"
@classmethod
- def get_owners_ids(cls, target: SqlaTable) -> List[int]:
+ def get_owners_ids(cls, target: SqlaTable) -> list[int]:
return [owner.id for owner in target.owners]
@@ -254,7 +248,7 @@ class FavStarUpdater:
cls, _mapper: Mapper, connection: Connection, target: FavStar
) -> None:
session = Session(bind=connection)
- name = "favorited_by:{0}".format(target.user_id)
+ name = f"favorited_by:{target.user_id}"
tag = get_tag(name, session, TagTypes.favorited_by)
tagged_object = TaggedObject(
tag_id=tag.id,
diff --git a/superset/tasks/__init__.py b/superset/tasks/__init__.py
index fd9417fe5c..13a83393a9 100644
--- a/superset/tasks/__init__.py
+++ b/superset/tasks/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py
index ffd92c2627..cfcb3e31c6 100644
--- a/superset/tasks/async_queries.py
+++ b/superset/tasks/async_queries.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import copy
import logging
-from typing import Any, cast, Dict, Optional, TYPE_CHECKING
+from typing import Any, cast, TYPE_CHECKING
from celery.exceptions import SoftTimeLimitExceeded
from flask import current_app, g
@@ -45,12 +45,12 @@ query_timeout = current_app.config[
] # TODO: new config key
-def set_form_data(form_data: Dict[str, Any]) -> None:
+def set_form_data(form_data: dict[str, Any]) -> None:
# pylint: disable=assigning-non-slot
g.form_data = form_data
-def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext:
+def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)
except KeyError as ex:
@@ -61,8 +61,8 @@ def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext:
@celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout)
def load_chart_data_into_cache(
- job_metadata: Dict[str, Any],
- form_data: Dict[str, Any],
+ job_metadata: dict[str, Any],
+ form_data: dict[str, Any],
) -> None:
# pylint: disable=import-outside-toplevel
from superset.charts.data.commands.get_data_command import ChartDataCommand
@@ -104,9 +104,9 @@ def load_chart_data_into_cache(
@celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout)
def load_explore_json_into_cache( # pylint: disable=too-many-locals
- job_metadata: Dict[str, Any],
- form_data: Dict[str, Any],
- response_type: Optional[str] = None,
+ job_metadata: dict[str, Any],
+ form_data: dict[str, Any],
+ response_type: str | None = None,
force: bool = False,
) -> None:
cache_key_prefix = "ejr-" # ejr: explore_json request
diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py
index bdbf8add7e..448271269a 100644
--- a/superset/tasks/cache.py
+++ b/superset/tasks/cache.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Optional, Union
from urllib import request
from urllib.error import URLError
@@ -72,7 +72,7 @@ class Strategy: # pylint: disable=too-few-public-methods
def __init__(self) -> None:
pass
- def get_urls(self) -> List[str]:
+ def get_urls(self) -> list[str]:
raise NotImplementedError("Subclasses must implement get_urls!")
@@ -94,7 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dummy"
- def get_urls(self) -> List[str]:
+ def get_urls(self) -> list[str]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
@@ -126,7 +126,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method
self.top_n = top_n
self.since = parse_human_datetime(since) if since else None
- def get_urls(self) -> List[str]:
+ def get_urls(self) -> list[str]:
urls = []
session = db.create_scoped_session()
@@ -165,11 +165,11 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
name = "dashboard_tags"
- def __init__(self, tags: Optional[List[str]] = None) -> None:
+ def __init__(self, tags: Optional[list[str]] = None) -> None:
super().__init__()
self.tags = tags or []
- def get_urls(self) -> List[str]:
+ def get_urls(self) -> list[str]:
urls = []
session = db.create_scoped_session()
@@ -216,7 +216,7 @@ strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy]
@celery_app.task(name="fetch_url")
-def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
+def fetch_url(url: str, headers: dict[str, str]) -> dict[str, str]:
"""
Celery job to fetch url
"""
@@ -242,7 +242,7 @@ def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]:
@celery_app.task(name="cache-warmup")
def cache_warmup(
strategy_name: str, *args: Any, **kwargs: Any
-) -> Union[Dict[str, List[str]], str]:
+) -> Union[dict[str, list[str]], str]:
"""
Warm up cache.
@@ -272,7 +272,7 @@ def cache_warmup(
cookies = MachineAuthProvider.get_auth_cookies(user)
headers = {"Cookie": f"session={cookies.get('session', '')}"}
- results: Dict[str, List[str]] = {"scheduled": [], "errors": []}
+ results: dict[str, list[str]] = {"scheduled": [], "errors": []}
for url in strategy.get_urls():
try:
logger.info("Scheduling %s", url)
diff --git a/superset/tasks/cron_util.py b/superset/tasks/cron_util.py
index 9c275addf6..19d342ebdc 100644
--- a/superset/tasks/cron_util.py
+++ b/superset/tasks/cron_util.py
@@ -16,8 +16,8 @@
# under the License.
import logging
+from collections.abc import Iterator
from datetime import datetime, timedelta, timezone as dt_timezone
-from typing import Iterator
from croniter import croniter
from pytz import timezone as pytz_timezone, UnknownTimeZoneError
diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py
index 9c1dab8220..5012330bbd 100644
--- a/superset/tasks/utils.py
+++ b/superset/tasks/utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
-from typing import List, Optional, Tuple, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
from flask import current_app, g
@@ -32,10 +32,10 @@ if TYPE_CHECKING:
# pylint: disable=too-many-branches
def get_executor(
- executor_types: List[ExecutorType],
- model: Union[Dashboard, ReportSchedule, Slice],
- current_user: Optional[str] = None,
-) -> Tuple[ExecutorType, str]:
+ executor_types: list[ExecutorType],
+ model: Dashboard | ReportSchedule | Slice,
+ current_user: str | None = None,
+) -> tuple[ExecutorType, str]:
"""
Extract the user that should be used to execute a scheduled task. Certain executor
types extract the user from the underlying object (e.g. CREATOR), the constant
@@ -86,7 +86,7 @@ def get_executor(
raise ExecutorNotFoundError()
-def get_current_user() -> Optional[str]:
+def get_current_user() -> str | None:
user = g.user if hasattr(g, "user") and g.user else None
if user and not user.is_anonymous:
return user.username
diff --git a/superset/translations/utils.py b/superset/translations/utils.py
index 79d01539a1..23eca1dd8c 100644
--- a/superset/translations/utils.py
+++ b/superset/translations/utils.py
@@ -16,15 +16,15 @@
# under the License.
import json
import os
-from typing import Any, Dict, Optional
+from typing import Any, Optional
# Global caching for JSON language packs
-ALL_LANGUAGE_PACKS: Dict[str, Dict[str, Any]] = {"en": {}}
+ALL_LANGUAGE_PACKS: dict[str, dict[str, Any]] = {"en": {}}
DIR = os.path.dirname(os.path.abspath(__file__))
-def get_language_pack(locale: str) -> Optional[Dict[str, Any]]:
+def get_language_pack(locale: str) -> Optional[dict[str, Any]]:
"""Get/cache a language pack
Returns the language pack from cache if it exists, caches otherwise
@@ -34,7 +34,7 @@ def get_language_pack(locale: str) -> Optional[Dict[str, Any]]:
"""
pack = ALL_LANGUAGE_PACKS.get(locale)
if not pack:
- filename = DIR + "/{}/LC_MESSAGES/messages.json".format(locale)
+ filename = DIR + f"/{locale}/LC_MESSAGES/messages.json"
try:
with open(filename, encoding="utf8") as f:
pack = json.load(f)
diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py
index 71559aaa3d..1913fc1dec 100644
--- a/superset/utils/async_query_manager.py
+++ b/superset/utils/async_query_manager.py
@@ -17,7 +17,7 @@
import json
import logging
import uuid
-from typing import Any, Dict, List, Literal, Optional, Tuple
+from typing import Any, Literal, Optional
import jwt
import redis
@@ -38,7 +38,7 @@ class AsyncQueryJobException(Exception):
def build_job_metadata(
channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
return {
"channel_id": channel_id,
"job_id": job_id,
@@ -49,7 +49,7 @@ def build_job_metadata(
}
-def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]:
+def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]:
event_id = event_data[0]
event_payload = event_data[1]["data"]
return {"id": event_id, **json.loads(event_payload)}
@@ -149,7 +149,7 @@ class AsyncQueryManager:
return response
- def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]:
+ def parse_jwt_from_request(self, req: Request) -> dict[str, Any]:
token = req.cookies.get(self._jwt_cookie_name)
if not token:
raise AsyncQueryTokenException("Token not preset")
@@ -160,7 +160,7 @@ class AsyncQueryManager:
logger.warning("Parse jwt failed", exc_info=True)
raise AsyncQueryTokenException("Failed to parse token") from ex
- def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]:
+ def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]:
job_id = str(uuid.uuid4())
return build_job_metadata(
channel_id, job_id, user_id, status=self.STATUS_PENDING
@@ -168,14 +168,14 @@ class AsyncQueryManager:
def read_events(
self, channel: str, last_id: Optional[str]
- ) -> List[Optional[Dict[str, Any]]]:
+ ) -> list[Optional[dict[str, Any]]]:
stream_name = f"{self._stream_prefix}{channel}"
start_id = increment_id(last_id) if last_id else "-"
results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT)
return [] if not results else list(map(parse_event, results))
def update_job(
- self, job_metadata: Dict[str, Any], status: str, **kwargs: Any
+ self, job_metadata: dict[str, Any], status: str, **kwargs: Any
) -> None:
if "channel_id" not in job_metadata:
raise AsyncQueryJobException("No channel ID specified")
diff --git a/superset/utils/cache.py b/superset/utils/cache.py
index a632b04b37..693f3a73bc 100644
--- a/superset/utils/cache.py
+++ b/superset/utils/cache.py
@@ -20,7 +20,7 @@ import inspect
import logging
from datetime import datetime, timedelta
from functools import wraps
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, TYPE_CHECKING
from flask import current_app as app, request
from flask_caching import Cache
@@ -41,7 +41,7 @@ stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
-def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str:
+def generate_cache_key(values_dict: dict[str, Any], key_prefix: str = "") -> str:
hash_str = md5_sha_from_dict(values_dict, default=json_int_dttm_ser)
return f"{key_prefix}{hash_str}"
@@ -49,9 +49,9 @@ def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str
def set_and_log_cache(
cache_instance: Cache,
cache_key: str,
- cache_value: Dict[str, Any],
- cache_timeout: Optional[int] = None,
- datasource_uid: Optional[str] = None,
+ cache_value: dict[str, Any],
+ cache_timeout: int | None = None,
+ datasource_uid: str | None = None,
) -> None:
if isinstance(cache_instance.cache, NullCache):
return
@@ -91,11 +91,11 @@ logger = logging.getLogger(__name__)
def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument
args_hash = hash(frozenset(request.args.items()))
- return "view/{}/{}".format(request.path, args_hash)
+ return f"view/{request.path}/{args_hash}"
def memoized_func(
- key: Optional[str] = None, cache: Cache = cache_manager.cache
+ key: str | None = None, cache: Cache = cache_manager.cache
) -> Callable[..., Any]:
"""
Decorator with configurable key and cache backend.
@@ -152,10 +152,10 @@ def memoized_func(
def etag_cache(
cache: Cache = cache_manager.cache,
- get_last_modified: Optional[Callable[..., datetime]] = None,
- max_age: Optional[Union[int, float]] = None,
- raise_for_access: Optional[Callable[..., Any]] = None,
- skip: Optional[Callable[..., bool]] = None,
+ get_last_modified: Callable[..., datetime] | None = None,
+ max_age: int | float | None = None,
+ raise_for_access: Callable[..., Any] | None = None,
+ skip: Callable[..., bool] | None = None,
) -> Callable[..., Any]:
"""
A decorator for caching views and handling etag conditional requests.
diff --git a/superset/utils/celery.py b/superset/utils/celery.py
index 474fc98d94..3577179145 100644
--- a/superset/utils/celery.py
+++ b/superset/utils/celery.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import logging
+from collections.abc import Iterator
from contextlib import contextmanager
-from typing import Iterator
from sqlalchemy import create_engine
from sqlalchemy.exc import SQLAlchemyError
diff --git a/superset/utils/core.py b/superset/utils/core.py
index c537abf459..24e539b2b6 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -35,6 +35,7 @@ import threading
import traceback
import uuid
import zlib
+from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
@@ -47,24 +48,7 @@ from enum import Enum, IntEnum
from io import BytesIO
from timeit import default_timer
from types import TracebackType
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Iterable,
- Iterator,
- List,
- NamedTuple,
- Optional,
- Sequence,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- TypeVar,
- Union,
-)
+from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar
from urllib.parse import unquote_plus
from zipfile import ZipFile
@@ -197,11 +181,11 @@ class LoggerLevel(str, Enum):
class HeaderDataType(TypedDict):
notification_format: str
- owners: List[int]
+ owners: list[int]
notification_type: str
- notification_source: Optional[str]
- chart_id: Optional[int]
- dashboard_id: Optional[int]
+ notification_source: str | None
+ chart_id: int | None
+ dashboard_id: int | None
class DatasourceDict(TypedDict):
@@ -212,20 +196,20 @@ class DatasourceDict(TypedDict):
class AdhocFilterClause(TypedDict, total=False):
clause: str
expressionType: str
- filterOptionName: Optional[str]
- comparator: Optional[FilterValues]
+ filterOptionName: str | None
+ comparator: FilterValues | None
operator: str
subject: str
- isExtra: Optional[bool]
- sqlExpression: Optional[str]
+ isExtra: bool | None
+ sqlExpression: str | None
class QueryObjectFilterClause(TypedDict, total=False):
col: Column
op: str # pylint: disable=invalid-name
- val: Optional[FilterValues]
- grain: Optional[str]
- isExtra: Optional[bool]
+ val: FilterValues | None
+ grain: str | None
+ isExtra: bool | None
class ExtraFiltersTimeColumnType(str, Enum):
@@ -351,9 +335,9 @@ class ReservedUrlParameters(str, Enum):
EDIT_MODE = "edit"
@staticmethod
- def is_standalone_mode() -> Optional[bool]:
+ def is_standalone_mode() -> bool | None:
standalone_param = request.args.get(ReservedUrlParameters.STANDALONE.value)
- standalone: Optional[bool] = bool(
+ standalone: bool | None = bool(
standalone_param and standalone_param != "false" and standalone_param != "0"
)
return standalone
@@ -370,10 +354,10 @@ class ColumnTypeSource(Enum):
class ColumnSpec(NamedTuple):
- sqla_type: Union[TypeEngine, str]
+ sqla_type: TypeEngine | str
generic_type: GenericDataType
is_dttm: bool
- python_date_format: Optional[str] = None
+ python_date_format: str | None = None
try:
@@ -407,8 +391,8 @@ def flasher(msg: str, severity: str = "message") -> None:
def parse_js_uri_path_item(
- item: Optional[str], unquote: bool = True, eval_undefined: bool = False
-) -> Optional[str]:
+ item: str | None, unquote: bool = True, eval_undefined: bool = False
+) -> str | None:
"""Parse a uri path item made with js.
:param item: a uri path component
@@ -421,7 +405,7 @@ def parse_js_uri_path_item(
return unquote_plus(item) if unquote and item else item
-def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float, int]]:
+def cast_to_num(value: float | int | str | None) -> float | int | None:
"""Casts a value to an int/float
>>> cast_to_num('1 ')
@@ -457,7 +441,7 @@ def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float
return None
-def cast_to_boolean(value: Any) -> Optional[bool]:
+def cast_to_boolean(value: Any) -> bool | None:
"""Casts a value to an int/float
>>> cast_to_boolean(1)
@@ -487,7 +471,7 @@ def cast_to_boolean(value: Any) -> Optional[bool]:
return False
-def list_minus(l: List[Any], minus: List[Any]) -> List[Any]:
+def list_minus(l: list[Any], minus: list[Any]) -> list[Any]:
"""Returns l without what is in minus
>>> list_minus([1, 2, 3], [2])
@@ -501,12 +485,12 @@ class DashboardEncoder(json.JSONEncoder):
super().__init__(*args, **kwargs)
self.sort_keys = True
- def default(self, o: Any) -> Union[Dict[Any, Any], str]:
+ def default(self, o: Any) -> dict[Any, Any] | str:
if isinstance(o, uuid.UUID):
return str(o)
try:
vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"}
- return {"__{}__".format(o.__class__.__name__): vals}
+ return {f"__{o.__class__.__name__}__": vals}
except Exception: # pylint: disable=broad-except
if isinstance(o, datetime):
return {"__datetime__": o.replace(microsecond=0).isoformat()}
@@ -519,13 +503,13 @@ class JSONEncodedDict(TypeDecorator): # pylint: disable=abstract-method
impl = TEXT
def process_bind_param(
- self, value: Optional[Dict[Any, Any]], dialect: str
- ) -> Optional[str]:
+ self, value: dict[Any, Any] | None, dialect: str
+ ) -> str | None:
return json.dumps(value) if value is not None else None
def process_result_value(
- self, value: Optional[str], dialect: str
- ) -> Optional[Dict[Any, Any]]:
+ self, value: str | None, dialect: str
+ ) -> dict[Any, Any] | None:
return json.loads(value) if value is not None else None
@@ -634,7 +618,7 @@ def json_int_dttm_ser(obj: Any) -> Any:
return base_json_conv(obj)
-def json_dumps_w_dates(payload: Dict[Any, Any], sort_keys: bool = False) -> str:
+def json_dumps_w_dates(payload: dict[Any, Any], sort_keys: bool = False) -> str:
"""Dumps payload to JSON with Datetime objects properly converted"""
return json.dumps(payload, default=json_int_dttm_ser, sort_keys=sort_keys)
@@ -662,7 +646,7 @@ def error_msg_from_exception(ex: Exception) -> str:
return msg or str(ex)
-def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str:
+def markdown(raw: str, markup_wrap: bool | None = False) -> str:
safe_markdown_tags = {
"h1",
"h2",
@@ -709,15 +693,15 @@ def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str:
return safe
-def readfile(file_path: str) -> Optional[str]:
+def readfile(file_path: str) -> str | None:
with open(file_path) as f:
content = f.read()
return content
def generic_find_constraint_name(
- table: str, columns: Set[str], referenced: str, database: SQLA
-) -> Optional[str]:
+ table: str, columns: set[str], referenced: str, database: SQLA
+) -> str | None:
"""Utility to find a constraint name in alembic migrations"""
tbl = sa.Table(
table, database.metadata, autoload=True, autoload_with=database.engine
@@ -731,8 +715,8 @@ def generic_find_constraint_name(
def generic_find_fk_constraint_name(
- table: str, columns: Set[str], referenced: str, insp: Inspector
-) -> Optional[str]:
+ table: str, columns: set[str], referenced: str, insp: Inspector
+) -> str | None:
"""Utility to find a foreign-key constraint name in alembic migrations"""
for fk in insp.get_foreign_keys(table):
if (
@@ -745,8 +729,8 @@ def generic_find_fk_constraint_name(
def generic_find_fk_constraint_names( # pylint: disable=invalid-name
- table: str, columns: Set[str], referenced: str, insp: Inspector
-) -> Set[str]:
+ table: str, columns: set[str], referenced: str, insp: Inspector
+) -> set[str]:
"""Utility to find foreign-key constraint names in alembic migrations"""
names = set()
@@ -761,8 +745,8 @@ def generic_find_fk_constraint_names( # pylint: disable=invalid-name
def generic_find_uq_constraint_name(
- table: str, columns: Set[str], insp: Inspector
-) -> Optional[str]:
+ table: str, columns: set[str], insp: Inspector
+) -> str | None:
"""Utility to find a unique constraint name in alembic migrations"""
for uq in insp.get_unique_constraints(table):
@@ -773,14 +757,14 @@ def generic_find_uq_constraint_name(
def get_datasource_full_name(
- database_name: str, datasource_name: str, schema: Optional[str] = None
+ database_name: str, datasource_name: str, schema: str | None = None
) -> str:
if not schema:
- return "[{}].[{}]".format(database_name, datasource_name)
- return "[{}].[{}].[{}]".format(database_name, schema, datasource_name)
+ return f"[{database_name}].[{datasource_name}]"
+ return f"[{database_name}].[{schema}].[{datasource_name}]"
-def validate_json(obj: Union[bytes, bytearray, str]) -> None:
+def validate_json(obj: bytes | bytearray | str) -> None:
if obj:
try:
json.loads(obj)
@@ -851,7 +835,7 @@ class TimerTimeout:
# Windows has no support for SIGALRM, so we use the timer based timeout
-timeout: Union[Type[TimerTimeout], Type[SigalrmTimeout]] = (
+timeout: type[TimerTimeout] | type[SigalrmTimeout] = (
TimerTimeout if platform.system() == "Windows" else SigalrmTimeout
)
@@ -897,9 +881,9 @@ def notify_user_about_perm_udate( # pylint: disable=too-many-arguments
granter: User,
user: User,
role: Role,
- datasource: "BaseDatasource",
+ datasource: BaseDatasource,
tpl_name: str,
- config: Dict[str, Any],
+ config: dict[str, Any],
) -> None:
msg = render_template(
tpl_name, granter=granter, user=user, role=role, datasource=datasource
@@ -923,15 +907,15 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many
to: str,
subject: str,
html_content: str,
- config: Dict[str, Any],
- files: Optional[List[str]] = None,
- data: Optional[Dict[str, str]] = None,
- images: Optional[Dict[str, bytes]] = None,
+ config: dict[str, Any],
+ files: list[str] | None = None,
+ data: dict[str, str] | None = None,
+ images: dict[str, bytes] | None = None,
dryrun: bool = False,
- cc: Optional[str] = None,
- bcc: Optional[str] = None,
+ cc: str | None = None,
+ bcc: str | None = None,
mime_subtype: str = "mixed",
- header_data: Optional[HeaderDataType] = None,
+ header_data: HeaderDataType | None = None,
) -> None:
"""
Send an email with html content, eg:
@@ -1000,9 +984,9 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many
def send_mime_email(
e_from: str,
- e_to: List[str],
+ e_to: list[str],
mime_msg: MIMEMultipart,
- config: Dict[str, Any],
+ config: dict[str, Any],
dryrun: bool = False,
) -> None:
smtp_host = config["SMTP_HOST"]
@@ -1035,8 +1019,8 @@ def send_mime_email(
smtp.quit()
-def get_email_address_list(address_string: str) -> List[str]:
- address_string_list: List[str] = []
+def get_email_address_list(address_string: str) -> list[str]:
+ address_string_list: list[str] = []
if isinstance(address_string, str):
address_string_list = re.split(r",|\s|;", address_string)
return [x.strip() for x in address_string_list if x.strip()]
@@ -1049,12 +1033,12 @@ def get_email_address_str(address_string: str) -> str:
return address_list_str
-def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]:
+def choicify(values: Iterable[Any]) -> list[tuple[Any, Any]]:
"""Takes an iterable and makes an iterable of tuples with it"""
return [(v, v) for v in values]
-def zlib_compress(data: Union[bytes, str]) -> bytes:
+def zlib_compress(data: bytes | str) -> bytes:
"""
Compress things in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
@@ -1065,7 +1049,7 @@ def zlib_compress(data: Union[bytes, str]) -> bytes:
return zlib.compress(data)
-def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, str]:
+def zlib_decompress(blob: bytes, decode: bool | None = True) -> bytes | str:
"""
Decompress things to a string in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
@@ -1094,12 +1078,12 @@ def simple_filter_to_adhoc(
}
if filter_clause.get("isExtra"):
result["isExtra"] = True
- result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result))
+ result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result))
return result
-def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterClause:
+def form_data_to_adhoc(form_data: dict[str, Any], clause: str) -> AdhocFilterClause:
if clause not in ("where", "having"):
raise ValueError(__("Unsupported clause type: %(clause)s", clause=clause))
result: AdhocFilterClause = {
@@ -1107,19 +1091,19 @@ def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterCla
"expressionType": "SQL",
"sqlExpression": form_data.get(clause),
}
- result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result))
+ result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result))
return result
-def merge_extra_form_data(form_data: Dict[str, Any]) -> None:
+def merge_extra_form_data(form_data: dict[str, Any]) -> None:
"""
Merge extra form data (appends and overrides) into the main payload
and add applied time extras to the payload.
"""
filter_keys = ["filters", "adhoc_filters"]
extra_form_data = form_data.pop("extra_form_data", {})
- append_filters: List[QueryObjectFilterClause] = extra_form_data.get("filters", None)
+ append_filters: list[QueryObjectFilterClause] = extra_form_data.get("filters", None)
# merge append extras
for key in [key for key in EXTRA_FORM_DATA_APPEND_KEYS if key not in filter_keys]:
@@ -1144,9 +1128,9 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None:
if extras:
form_data["extras"] = extras
- adhoc_filters: List[AdhocFilterClause] = form_data.get("adhoc_filters", [])
+ adhoc_filters: list[AdhocFilterClause] = form_data.get("adhoc_filters", [])
form_data["adhoc_filters"] = adhoc_filters
- append_adhoc_filters: List[AdhocFilterClause] = extra_form_data.get(
+ append_adhoc_filters: list[AdhocFilterClause] = extra_form_data.get(
"adhoc_filters", []
)
adhoc_filters.extend(
@@ -1170,7 +1154,7 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None:
adhoc_filter["comparator"] = form_data["time_range"]
-def merge_extra_filters(form_data: Dict[str, Any]) -> None:
+def merge_extra_filters(form_data: dict[str, Any]) -> None:
# extra_filters are temporary/contextual filters (using the legacy constructs)
# that are external to the slice definition. We use those for dynamic
# interactive filters like the ones emitted by the "Filter Box" visualization.
@@ -1193,7 +1177,7 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None:
# Grab list of existing filters 'keyed' on the column and operator
- def get_filter_key(f: Dict[str, Any]) -> str:
+ def get_filter_key(f: dict[str, Any]) -> str:
if "expressionType" in f:
return "{}__{}".format(f["subject"], f["operator"])
@@ -1244,7 +1228,7 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None:
del form_data["extra_filters"]
-def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> None:
+def merge_request_params(form_data: dict[str, Any], params: dict[str, Any]) -> None:
"""
Merge request parameters to the key `url_params` in form_data. Only updates
or appends parameters to `form_data` that are defined in `params; pre-existing
@@ -1261,7 +1245,7 @@ def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> N
form_data["url_params"] = url_params
-def user_label(user: User) -> Optional[str]:
+def user_label(user: User) -> str | None:
"""Given a user ORM FAB object, returns a label"""
if user:
if user.first_name and user.last_name:
@@ -1272,7 +1256,7 @@ def user_label(user: User) -> Optional[str]:
return None
-def get_example_default_schema() -> Optional[str]:
+def get_example_default_schema() -> str | None:
"""
Return the default schema of the examples database, if any.
"""
@@ -1295,7 +1279,7 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]:
)
-def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]:
+def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]:
axis_cols = [
col
for col in columns or []
@@ -1304,14 +1288,12 @@ def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]:
return tuple(get_column_name(col) for col in axis_cols)
-def get_xaxis_label(columns: Optional[List[Column]]) -> Optional[str]:
+def get_xaxis_label(columns: list[Column] | None) -> str | None:
labels = get_base_axis_labels(columns)
return labels[0] if labels else None
-def get_column_name(
- column: Column, verbose_map: Optional[Dict[str, Any]] = None
-) -> str:
+def get_column_name(column: Column, verbose_map: dict[str, Any] | None = None) -> str:
"""
Extract label from column
@@ -1336,9 +1318,7 @@ def get_column_name(
raise ValueError("Missing label")
-def get_metric_name(
- metric: Metric, verbose_map: Optional[Dict[str, Any]] = None
-) -> str:
+def get_metric_name(metric: Metric, verbose_map: dict[str, Any] | None = None) -> str:
"""
Extract label from metric
@@ -1374,9 +1354,9 @@ def get_metric_name(
def get_column_names(
- columns: Optional[Sequence[Column]],
- verbose_map: Optional[Dict[str, Any]] = None,
-) -> List[str]:
+ columns: Sequence[Column] | None,
+ verbose_map: dict[str, Any] | None = None,
+) -> list[str]:
return [
column
for column in [get_column_name(column, verbose_map) for column in columns or []]
@@ -1385,9 +1365,9 @@ def get_column_names(
def get_metric_names(
- metrics: Optional[Sequence[Metric]],
- verbose_map: Optional[Dict[str, Any]] = None,
-) -> List[str]:
+ metrics: Sequence[Metric] | None,
+ verbose_map: dict[str, Any] | None = None,
+) -> list[str]:
return [
metric
for metric in [get_metric_name(metric, verbose_map) for metric in metrics or []]
@@ -1396,9 +1376,9 @@ def get_metric_names(
def get_first_metric_name(
- metrics: Optional[Sequence[Metric]],
- verbose_map: Optional[Dict[str, Any]] = None,
-) -> Optional[str]:
+ metrics: Sequence[Metric] | None,
+ verbose_map: dict[str, Any] | None = None,
+) -> str | None:
metric_labels = get_metric_names(metrics, verbose_map)
return metric_labels[0] if metric_labels else None
@@ -1417,7 +1397,7 @@ def convert_legacy_filters_into_adhoc( # pylint: disable=invalid-name
mapping = {"having": "having_filters", "where": "filters"}
if not form_data.get("adhoc_filters"):
- adhoc_filters: List[AdhocFilterClause] = []
+ adhoc_filters: list[AdhocFilterClause] = []
form_data["adhoc_filters"] = adhoc_filters
for clause, filters in mapping.items():
@@ -1475,17 +1455,13 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name
sql_where_filters.append(sql_expression)
elif clause == "HAVING":
sql_having_filters.append(sql_expression)
- form_data["where"] = " AND ".join(
- ["({})".format(sql) for sql in sql_where_filters]
- )
- form_data["having"] = " AND ".join(
- ["({})".format(sql) for sql in sql_having_filters]
- )
+ form_data["where"] = " AND ".join([f"({sql})" for sql in sql_where_filters])
+ form_data["having"] = " AND ".join([f"({sql})" for sql in sql_having_filters])
form_data["having_filters"] = simple_having_filters
form_data["filters"] = simple_where_filters
-def get_username() -> Optional[str]:
+def get_username() -> str | None:
"""
Get username (if defined) associated with the current user.
@@ -1498,7 +1474,7 @@ def get_username() -> Optional[str]:
return None
-def get_user_id() -> Optional[int]:
+def get_user_id() -> int | None:
"""
Get the user identifier (if defined) associated with the current user.
@@ -1517,7 +1493,7 @@ def get_user_id() -> Optional[int]:
@contextmanager
-def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]:
+def override_user(user: User | None, force: bool = True) -> Iterator[Any]:
"""
Temporarily override the current user per `flask.g` with the specified user.
@@ -1583,7 +1559,7 @@ def create_ssl_cert_file(certificate: str) -> str:
def time_function(
func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any
-) -> Tuple[float, Any]:
+) -> tuple[float, Any]:
"""
Measures the amount of time a function takes to execute in ms
@@ -1603,7 +1579,7 @@ def MediumText() -> Variant: # pylint:disable=invalid-name
def shortid() -> str:
- return "{}".format(uuid.uuid4())[-12:]
+ return f"{uuid.uuid4()}"[-12:]
class DatasourceName(NamedTuple):
@@ -1611,7 +1587,7 @@ class DatasourceName(NamedTuple):
schema: str
-def get_stacktrace() -> Optional[str]:
+def get_stacktrace() -> str | None:
if current_app.config["SHOW_STACKTRACE"]:
return traceback.format_exc()
return None
@@ -1649,7 +1625,7 @@ def split(
yield string[i:]
-def get_iterable(x: Any) -> List[Any]:
+def get_iterable(x: Any) -> list[Any]:
"""
Get an iterable (list) representation of the object.
@@ -1659,7 +1635,7 @@ def get_iterable(x: Any) -> List[Any]:
return x if isinstance(x, list) else [x]
-def get_form_data_token(form_data: Dict[str, Any]) -> str:
+def get_form_data_token(form_data: dict[str, Any]) -> str:
"""
Return the token contained within form data or generate a new one.
@@ -1669,7 +1645,7 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str:
return form_data.get("token") or "token_" + uuid.uuid4().hex[:8]
-def get_column_name_from_column(column: Column) -> Optional[str]:
+def get_column_name_from_column(column: Column) -> str | None:
"""
Extract the physical column that a column is referencing. If the column is
an adhoc column, always returns `None`.
@@ -1682,7 +1658,7 @@ def get_column_name_from_column(column: Column) -> Optional[str]:
return column # type: ignore
-def get_column_names_from_columns(columns: List[Column]) -> List[str]:
+def get_column_names_from_columns(columns: list[Column]) -> list[str]:
"""
Extract the physical columns that a list of columns are referencing. Ignore
adhoc columns
@@ -1693,7 +1669,7 @@ def get_column_names_from_columns(columns: List[Column]) -> List[str]:
return [col for col in map(get_column_name_from_column, columns) if col]
-def get_column_name_from_metric(metric: Metric) -> Optional[str]:
+def get_column_name_from_metric(metric: Metric) -> str | None:
"""
Extract the column that a metric is referencing. If the metric isn't
a simple metric, always returns `None`.
@@ -1704,11 +1680,11 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]:
if is_adhoc_metric(metric):
metric = cast(AdhocMetric, metric)
if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE:
- return cast(Dict[str, Any], metric["column"])["column_name"]
+ return cast(dict[str, Any], metric["column"])["column_name"]
return None
-def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
+def get_column_names_from_metrics(metrics: list[Metric]) -> list[str]:
"""
Extract the columns that a list of metrics are referencing. Excludes all
SQL metrics.
@@ -1721,12 +1697,12 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
def extract_dataframe_dtypes(
df: pd.DataFrame,
- datasource: Optional[Union[BaseDatasource, Query]] = None,
-) -> List[GenericDataType]:
+ datasource: BaseDatasource | Query | None = None,
+) -> list[GenericDataType]:
"""Serialize pandas/numpy dtypes to generic types"""
# omitting string types as those will be the default type
- inferred_type_map: Dict[str, GenericDataType] = {
+ inferred_type_map: dict[str, GenericDataType] = {
"floating": GenericDataType.NUMERIC,
"integer": GenericDataType.NUMERIC,
"mixed-integer-float": GenericDataType.NUMERIC,
@@ -1737,7 +1713,7 @@ def extract_dataframe_dtypes(
"date": GenericDataType.TEMPORAL,
}
- columns_by_name: Dict[str, Any] = {}
+ columns_by_name: dict[str, Any] = {}
if datasource:
for column in datasource.columns:
if isinstance(column, dict):
@@ -1745,7 +1721,7 @@ def extract_dataframe_dtypes(
else:
columns_by_name[column.column_name] = column
- generic_types: List[GenericDataType] = []
+ generic_types: list[GenericDataType] = []
for column in df.columns:
column_object = columns_by_name.get(column)
series = df[column]
@@ -1767,7 +1743,7 @@ def extract_dataframe_dtypes(
return generic_types
-def extract_column_dtype(col: "BaseColumn") -> GenericDataType:
+def extract_column_dtype(col: BaseColumn) -> GenericDataType:
if col.is_temporal:
return GenericDataType.TEMPORAL
if col.is_numeric:
@@ -1776,11 +1752,9 @@ def extract_column_dtype(col: "BaseColumn") -> GenericDataType:
return GenericDataType.STRING
-def indexed(
- items: List[Any], key: Union[str, Callable[[Any], Any]]
-) -> Dict[Any, List[Any]]:
+def indexed(items: list[Any], key: str | Callable[[Any], Any]) -> dict[Any, list[Any]]:
"""Build an index for a list of objects"""
- idx: Dict[Any, Any] = {}
+ idx: dict[Any, Any] = {}
for item in items:
key_ = getattr(item, key) if isinstance(key, str) else key(item)
idx.setdefault(key_, []).append(item)
@@ -1792,14 +1766,14 @@ def is_test() -> bool:
def get_time_filter_status(
- datasource: "BaseDatasource",
- applied_time_extras: Dict[str, str],
-) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
- temporal_columns: Set[Any] = {
+ datasource: BaseDatasource,
+ applied_time_extras: dict[str, str],
+) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
+ temporal_columns: set[Any] = {
col.column_name for col in datasource.columns if col.is_dttm
}
- applied: List[Dict[str, str]] = []
- rejected: List[Dict[str, str]] = []
+ applied: list[dict[str, str]] = []
+ rejected: list[dict[str, str]] = []
if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL):
if time_column in temporal_columns:
applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL})
@@ -1844,14 +1818,14 @@ def format_list(items: Sequence[str], sep: str = ", ", quote: str = '"') -> str:
return sep.join(f"{quote}{x.replace(quote, quote_escaped)}{quote}" for x in items)
-def find_duplicates(items: Iterable[InputType]) -> List[InputType]:
+def find_duplicates(items: Iterable[InputType]) -> list[InputType]:
"""Find duplicate items in an iterable."""
return [item for item, count in collections.Counter(items).items() if count > 1]
def remove_duplicates(
- items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None
-) -> List[InputType]:
+ items: Iterable[InputType], key: Callable[[InputType], Any] | None = None
+) -> list[InputType]:
"""Remove duplicate items in an iterable."""
if not key:
return list(dict.fromkeys(items).keys())
@@ -1868,9 +1842,9 @@ def remove_duplicates(
@dataclass
class DateColumn:
col_label: str
- timestamp_format: Optional[str] = None
- offset: Optional[int] = None
- time_shift: Optional[str] = None
+ timestamp_format: str | None = None
+ offset: int | None = None
+ time_shift: str | None = None
def __hash__(self) -> int:
return hash(self.col_label)
@@ -1881,9 +1855,9 @@ class DateColumn:
@classmethod
def get_legacy_time_column(
cls,
- timestamp_format: Optional[str],
- offset: Optional[int],
- time_shift: Optional[str],
+ timestamp_format: str | None,
+ offset: int | None,
+ time_shift: str | None,
) -> DateColumn:
return cls(
timestamp_format=timestamp_format,
@@ -1895,7 +1869,7 @@ class DateColumn:
def normalize_dttm_col(
df: pd.DataFrame,
- dttm_cols: Tuple[DateColumn, ...] = tuple(),
+ dttm_cols: tuple[DateColumn, ...] = tuple(),
) -> None:
for _col in dttm_cols:
if _col.col_label not in df.columns:
@@ -1925,7 +1899,7 @@ def normalize_dttm_col(
df[_col.col_label] += parse_human_timedelta(_col.time_shift)
-def parse_boolean_string(bool_str: Optional[str]) -> bool:
+def parse_boolean_string(bool_str: str | None) -> bool:
"""
Convert a string representation of a true/false value into a boolean
@@ -1956,7 +1930,7 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool:
def apply_max_row_limit(
limit: int,
- max_limit: Optional[int] = None,
+ max_limit: int | None = None,
) -> int:
"""
Override row limit if max global limit is defined
@@ -1979,7 +1953,7 @@ def apply_max_row_limit(
return max_limit
-def create_zip(files: Dict[str, Any]) -> BytesIO:
+def create_zip(files: dict[str, Any]) -> BytesIO:
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
for filename, contents in files.items():
@@ -1989,7 +1963,7 @@ def create_zip(files: Dict[str, Any]) -> BytesIO:
return buf
-def remove_extra_adhoc_filters(form_data: Dict[str, Any]) -> None:
+def remove_extra_adhoc_filters(form_data: dict[str, Any]) -> None:
"""
Remove filters from slice data that originate from a filter box or native filter
"""
diff --git a/superset/utils/csv.py b/superset/utils/csv.py
index a6c834b834..bab14058f2 100644
--- a/superset/utils/csv.py
+++ b/superset/utils/csv.py
@@ -17,7 +17,7 @@
import logging
import re
import urllib.request
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from urllib.error import URLError
import numpy as np
@@ -81,7 +81,7 @@ def df_to_escaped_csv(df: pd.DataFrame, **kwargs: Any) -> Any:
def get_chart_csv_data(
- chart_url: str, auth_cookies: Optional[Dict[str, str]] = None
+ chart_url: str, auth_cookies: Optional[dict[str, str]] = None
) -> Optional[bytes]:
content = None
if auth_cookies:
@@ -98,7 +98,7 @@ def get_chart_csv_data(
def get_chart_dataframe(
- chart_url: str, auth_cookies: Optional[Dict[str, str]] = None
+ chart_url: str, auth_cookies: Optional[dict[str, str]] = None
) -> Optional[pd.DataFrame]:
# Disable all the unnecessary-lambda violations in this function
# pylint: disable=unnecessary-lambda
diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py
index c0ee64370d..ce89b2a255 100644
--- a/superset/utils/dashboard_filter_scopes_converter.py
+++ b/superset/utils/dashboard_filter_scopes_converter.py
@@ -17,7 +17,7 @@
import json
import logging
from collections import defaultdict
-from typing import Any, Dict, List
+from typing import Any
from shortid import ShortId
@@ -27,11 +27,11 @@ logger = logging.getLogger(__name__)
def convert_filter_scopes(
- json_metadata: Dict[Any, Any], filter_boxes: List[Slice]
-) -> Dict[int, Dict[str, Dict[str, Any]]]:
+ json_metadata: dict[Any, Any], filter_boxes: list[Slice]
+) -> dict[int, dict[str, dict[str, Any]]]:
filter_scopes = {}
- immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or []
- immuned_by_column: Dict[str, List[int]] = defaultdict(list)
+ immuned_by_id: list[int] = json_metadata.get("filter_immune_slices") or []
+ immuned_by_column: dict[str, list[int]] = defaultdict(list)
for slice_id, columns in json_metadata.get(
"filter_immune_slice_fields", {}
).items():
@@ -39,7 +39,7 @@ def convert_filter_scopes(
immuned_by_column[column].append(int(slice_id))
def add_filter_scope(
- filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int
+ filter_fields: dict[str, dict[str, Any]], filter_field: str, filter_id: int
) -> None:
# in case filter field is invalid
if isinstance(filter_field, str):
@@ -54,7 +54,7 @@ def convert_filter_scopes(
logging.info("slice [%i] has invalid field: %s", filter_id, filter_field)
for filter_box in filter_boxes:
- filter_fields: Dict[str, Dict[str, Any]] = {}
+ filter_fields: dict[str, dict[str, Any]] = {}
filter_id = filter_box.id
slice_params = json.loads(filter_box.params or "{}")
configs = slice_params.get("filter_configs") or []
@@ -75,10 +75,10 @@ def convert_filter_scopes(
def copy_filter_scopes(
- old_to_new_slc_id_dict: Dict[int, int],
- old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]],
-) -> Dict[str, Dict[Any, Any]]:
- new_filter_scopes: Dict[str, Dict[Any, Any]] = {}
+ old_to_new_slc_id_dict: dict[int, int],
+ old_filter_scopes: dict[int, dict[str, dict[str, Any]]],
+) -> dict[str, dict[Any, Any]]:
+ new_filter_scopes: dict[str, dict[Any, Any]] = {}
for filter_id, scopes in old_filter_scopes.items():
new_filter_key = old_to_new_slc_id_dict.get(int(filter_id))
if new_filter_key:
@@ -93,10 +93,10 @@ def copy_filter_scopes(
def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements
- json_metadata: Dict[str, Any],
- position_json: Dict[str, Any],
- filter_boxes: List[Slice],
-) -> List[Dict[str, Any]]:
+ json_metadata: dict[str, Any],
+ position_json: dict[str, Any],
+ filter_boxes: list[Slice],
+) -> list[dict[str, Any]]:
"""
Convert the legacy filter scopes et al. to the native filter configuration.
@@ -121,11 +121,11 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too
filter_scopes = json_metadata.get("filter_scopes", {})
filter_box_ids = {filter_box.id for filter_box in filter_boxes}
- filter_scope_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(
+ filter_scope_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict(
dict
)
- filter_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)
+ filter_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict)
# Dense representation of filter scopes, falling back to chart level filter configs
# if the respective filter scope is not defined at the dashboard level.
@@ -150,7 +150,7 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too
for field, filter_scope in filter_scope_by_key_and_field[key].items():
default = default_filters.get(key, {}).get(field)
- fltr: Dict[str, Any] = {
+ fltr: dict[str, Any] = {
"cascadeParentIds": [],
"id": f"NATIVE_FILTER-{shortid.generate()}",
"scope": {
diff --git a/superset/utils/database.py b/superset/utils/database.py
index 750d873d1c..70730554f3 100644
--- a/superset/utils/database.py
+++ b/superset/utils/database.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import logging
-from typing import Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from flask import current_app
@@ -32,7 +32,7 @@ logger = logging.getLogger(__name__)
# TODO: duplicate code with DatabaseDao, below function should be moved or use dao
def get_or_create_db(
- database_name: str, sqlalchemy_uri: str, always_create: Optional[bool] = True
+ database_name: str, sqlalchemy_uri: str, always_create: bool | None = True
) -> Database:
# pylint: disable=import-outside-toplevel
from superset import db
diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py
index 7cdc23784a..438e379a96 100644
--- a/superset/utils/date_parser.py
+++ b/superset/utils/date_parser.py
@@ -20,7 +20,7 @@ import re
from datetime import datetime, timedelta
from functools import lru_cache
from time import struct_time
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
import pandas as pd
import parsedatetime
@@ -75,7 +75,7 @@ def parse_human_datetime(human_readable: str) -> datetime:
return dttm
-def normalize_time_delta(human_readable: str) -> Dict[str, int]:
+def normalize_time_delta(human_readable: str) -> dict[str, int]:
x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # pylint: disable=line-too-long,useless-suppression
matched = re.match(x_unit, human_readable, re.IGNORECASE)
if not matched:
@@ -149,7 +149,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m
time_shift: Optional[str] = None,
relative_start: Optional[str] = None,
relative_end: Optional[str] = None,
-) -> Tuple[Optional[datetime], Optional[datetime]]:
+) -> tuple[Optional[datetime], Optional[datetime]]:
"""Return `since` and `until` date time tuple from string representations of
time_range, since, until and time_shift.
@@ -227,7 +227,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m
]
since_and_until_partition = [_.strip() for _ in time_range.split(separator, 1)]
- since_and_until: List[Optional[str]] = []
+ since_and_until: list[Optional[str]] = []
for part in since_and_until_partition:
if not part:
# if since or until is "", set as None
diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py
index e77a559905..4ecd2eca98 100644
--- a/superset/utils/decorators.py
+++ b/superset/utils/decorators.py
@@ -17,9 +17,10 @@
from __future__ import annotations
import time
+from collections.abc import Iterator
from contextlib import contextmanager
from functools import wraps
-from typing import Any, Callable, Dict, Iterator, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, TYPE_CHECKING
from flask import current_app, Response
@@ -32,7 +33,7 @@ if TYPE_CHECKING:
from superset.stats_logger import BaseStatsLogger
-def statsd_gauge(metric_prefix: Optional[str] = None) -> Callable[..., Any]:
+def statsd_gauge(metric_prefix: str | None = None) -> Callable[..., Any]:
def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
"""
Handle sending statsd gauge metric from any method or function
@@ -83,13 +84,13 @@ def arghash(args: Any, kwargs: Any) -> int:
return hash(sorted_args)
-def debounce(duration: Union[float, int] = 0.1) -> Callable[..., Any]:
+def debounce(duration: float | int = 0.1) -> Callable[..., Any]:
"""Ensure a function called with the same arguments executes only once
per `duration` (default: 100ms).
"""
def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
- last: Dict[str, Any] = {"t": None, "input": None, "output": None}
+ last: dict[str, Any] = {"t": None, "input": None, "output": None}
def wrapped(*args: Any, **kwargs: Any) -> Any:
now = time.time()
diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py
index 93070732e7..f3fb1bbd6c 100644
--- a/superset/utils/dict_import_export.py
+++ b/superset/utils/dict_import_export.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import Any, Dict
+from typing import Any
from sqlalchemy.orm import Session
@@ -26,7 +26,7 @@ DATABASES_KEY = "databases"
logger = logging.getLogger(__name__)
-def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
+def export_schema_to_dict(back_references: bool) -> dict[str, Any]:
"""Exports the supported import/export schema to a dictionary"""
databases = [
Database.export_schema(recursive=True, include_parent_ref=back_references)
@@ -39,7 +39,7 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]:
def export_to_dict(
session: Session, recursive: bool, back_references: bool, include_defaults: bool
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
"""Exports databases and druid clusters to a dictionary"""
logger.info("Starting export")
dbs = session.query(Database)
diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py
index 52b784bb23..c812581ac4 100644
--- a/superset/utils/encrypt.py
+++ b/superset/utils/encrypt.py
@@ -16,7 +16,7 @@
# under the License.
import logging
from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask import Flask
from flask_babel import lazy_gettext as _
@@ -31,9 +31,9 @@ class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-meth
@abstractmethod
def create(
self,
- app_config: Optional[Dict[str, Any]],
- *args: List[Any],
- **kwargs: Optional[Dict[str, Any]],
+ app_config: Optional[dict[str, Any]],
+ *args: list[Any],
+ **kwargs: Optional[dict[str, Any]],
) -> TypeDecorator:
pass
@@ -43,9 +43,9 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods
):
def create(
self,
- app_config: Optional[Dict[str, Any]],
- *args: List[Any],
- **kwargs: Optional[Dict[str, Any]],
+ app_config: Optional[dict[str, Any]],
+ *args: list[Any],
+ **kwargs: Optional[dict[str, Any]],
) -> TypeDecorator:
if app_config:
return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs)
@@ -56,7 +56,7 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods
class EncryptedFieldFactory:
def __init__(self) -> None:
self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None
- self._config: Optional[Dict[str, Any]] = None
+ self._config: Optional[dict[str, Any]] = None
def init_app(self, app: Flask) -> None:
self._config = app.config
@@ -65,7 +65,7 @@ class EncryptedFieldFactory:
]()
def create(
- self, *args: List[Any], **kwargs: Optional[Dict[str, Any]]
+ self, *args: list[Any], **kwargs: Optional[dict[str, Any]]
) -> TypeDecorator:
if self._concrete_type_adapter:
return self._concrete_type_adapter.create(self._config, *args, **kwargs)
@@ -81,14 +81,14 @@ class SecretsMigrator:
self._previous_secret_key = previous_secret_key
self._dialect: Dialect = db.engine.url.get_dialect()
- def discover_encrypted_fields(self) -> Dict[str, Dict[str, EncryptedType]]:
+ def discover_encrypted_fields(self) -> dict[str, dict[str, EncryptedType]]:
"""
Iterates over SqlAlchemy's metadata, looking for EncryptedType
columns along the way. Builds up a dict of
table_name -> dict of col_name: enc type instance
:return:
"""
- meta_info: Dict[str, Any] = {}
+ meta_info: dict[str, Any] = {}
for table_name, table in self._db.metadata.tables.items():
for col_name, col in table.columns.items():
@@ -120,7 +120,7 @@ class SecretsMigrator:
@staticmethod
def _select_columns_from_table(
- conn: Connection, column_names: List[str], table_name: str
+ conn: Connection, column_names: list[str], table_name: str
) -> Row:
return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}")
@@ -129,7 +129,7 @@ class SecretsMigrator:
conn: Connection,
row: Row,
table_name: str,
- columns: Dict[str, EncryptedType],
+ columns: dict[str, EncryptedType],
) -> None:
"""
Re encrypts all columns in a Row
diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py
index 9874656722..ea295c776c 100644
--- a/superset/utils/feature_flag_manager.py
+++ b/superset/utils/feature_flag_manager.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
from copy import deepcopy
-from typing import Dict
from flask import Flask
@@ -25,7 +24,7 @@ class FeatureFlagManager:
super().__init__()
self._get_feature_flags_func = None
self._is_feature_enabled_func = None
- self._feature_flags: Dict[str, bool] = {}
+ self._feature_flags: dict[str, bool] = {}
def init_app(self, app: Flask) -> None:
self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"]
@@ -33,7 +32,7 @@ class FeatureFlagManager:
self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"]
self._feature_flags.update(app.config["FEATURE_FLAGS"])
- def get_feature_flags(self) -> Dict[str, bool]:
+ def get_feature_flags(self) -> dict[str, bool]:
if self._get_feature_flags_func:
return self._get_feature_flags_func(deepcopy(self._feature_flags))
if callable(self._is_feature_enabled_func):
diff --git a/superset/utils/filters.py b/superset/utils/filters.py
index 4772f49ba0..88154a40b3 100644
--- a/superset/utils/filters.py
+++ b/superset/utils/filters.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Type
+from typing import Any
from flask_appbuilder import Model
from sqlalchemy import or_
@@ -22,7 +22,7 @@ from sqlalchemy.sql.elements import BooleanClauseList
def get_dataset_access_filters(
- base_model: Type[Model],
+ base_model: type[Model],
*args: Any,
) -> BooleanClauseList:
# pylint: disable=import-outside-toplevel
diff --git a/superset/utils/hashing.py b/superset/utils/hashing.py
index 66983582ca..fff654263e 100644
--- a/superset/utils/hashing.py
+++ b/superset/utils/hashing.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import hashlib
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Optional
import simplejson as json
@@ -25,7 +25,7 @@ def md5_sha_from_str(val: str) -> str:
def md5_sha_from_dict(
- obj: Dict[Any, Any],
+ obj: dict[Any, Any],
ignore_nan: bool = False,
default: Optional[Callable[[Any], Any]] = None,
) -> str:
diff --git a/superset/utils/log.py b/superset/utils/log.py
index f2379fe11c..5430accb43 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -22,25 +22,14 @@ import json
import logging
import textwrap
from abc import ABC, abstractmethod
+from collections.abc import Iterator
from contextlib import contextmanager
from datetime import datetime, timedelta
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- Iterator,
- Optional,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, Callable, cast, Literal, TYPE_CHECKING
from flask import current_app, g, request
from flask_appbuilder.const import API_URI_RIS_KEY
from sqlalchemy.exc import SQLAlchemyError
-from typing_extensions import Literal
from superset.extensions import stats_logger_manager
from superset.utils.core import get_user_id, LoggerLevel
@@ -51,12 +40,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-def collect_request_payload() -> Dict[str, Any]:
+def collect_request_payload() -> dict[str, Any]:
"""Collect log payload identifiable from request context"""
if not request:
return {}
- payload: Dict[str, Any] = {
+ payload: dict[str, Any] = {
"path": request.path,
**request.form.to_dict(),
# url search params can overwrite POST body
@@ -81,7 +70,7 @@ def collect_request_payload() -> Dict[str, Any]:
def get_logger_from_status(
status: int,
-) -> Tuple[Callable[..., None], str]:
+) -> tuple[Callable[..., None], str]:
"""
Return logger method by status of exception.
Maps logger level to status code level
@@ -101,10 +90,10 @@ class AbstractEventLogger(ABC):
def __call__(
self,
action: str,
- object_ref: Optional[str] = None,
+ object_ref: str | None = None,
log_to_statsd: bool = True,
- duration: Optional[timedelta] = None,
- **payload_override: Dict[str, Any],
+ duration: timedelta | None = None,
+ **payload_override: dict[str, Any],
) -> object:
# pylint: disable=W0201
self.action = action
@@ -130,12 +119,12 @@ class AbstractEventLogger(ABC):
@abstractmethod
def log( # pylint: disable=too-many-arguments
self,
- user_id: Optional[int],
+ user_id: int | None,
action: str,
- dashboard_id: Optional[int],
- duration_ms: Optional[int],
- slice_id: Optional[int],
- referrer: Optional[str],
+ dashboard_id: int | None,
+ duration_ms: int | None,
+ slice_id: int | None,
+ referrer: str | None,
*args: Any,
**kwargs: Any,
) -> None:
@@ -144,10 +133,10 @@ class AbstractEventLogger(ABC):
def log_with_context( # pylint: disable=too-many-locals
self,
action: str,
- duration: Optional[timedelta] = None,
- object_ref: Optional[str] = None,
+ duration: timedelta | None = None,
+ object_ref: str | None = None,
log_to_statsd: bool = True,
- **payload_override: Optional[Dict[str, Any]],
+ **payload_override: dict[str, Any] | None,
) -> None:
# pylint: disable=import-outside-toplevel
from superset.views.core import get_form_data
@@ -176,7 +165,7 @@ class AbstractEventLogger(ABC):
if payload_override:
payload.update(payload_override)
- dashboard_id: Optional[int] = None
+ dashboard_id: int | None = None
try:
dashboard_id = int(payload.get("dashboard_id")) # type: ignore
except (TypeError, ValueError):
@@ -218,7 +207,7 @@ class AbstractEventLogger(ABC):
def log_context(
self,
action: str,
- object_ref: Optional[str] = None,
+ object_ref: str | None = None,
log_to_statsd: bool = True,
) -> Iterator[Callable[..., None]]:
"""
@@ -242,9 +231,9 @@ class AbstractEventLogger(ABC):
def _wrapper(
self,
f: Callable[..., Any],
- action: Optional[Union[str, Callable[..., str]]] = None,
- object_ref: Optional[Union[str, Callable[..., str], Literal[False]]] = None,
- allow_extra_payload: Optional[bool] = False,
+ action: str | Callable[..., str] | None = None,
+ object_ref: str | Callable[..., str] | Literal[False] | None = None,
+ allow_extra_payload: bool | None = False,
**wrapper_kwargs: Any,
) -> Callable[..., Any]:
@functools.wraps(f)
@@ -314,7 +303,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger:
)
)
- event_logger_type = cast(Type[Any], cfg_value)
+ event_logger_type = cast(type[Any], cfg_value)
result = event_logger_type()
# Verify that we have a valid logger impl
@@ -333,12 +322,12 @@ class DBEventLogger(AbstractEventLogger):
def log( # pylint: disable=too-many-arguments,too-many-locals
self,
- user_id: Optional[int],
+ user_id: int | None,
action: str,
- dashboard_id: Optional[int],
- duration_ms: Optional[int],
- slice_id: Optional[int],
- referrer: Optional[str],
+ dashboard_id: int | None,
+ duration_ms: int | None,
+ slice_id: int | None,
+ referrer: str | None,
*args: Any,
**kwargs: Any,
) -> None:
@@ -348,7 +337,7 @@ class DBEventLogger(AbstractEventLogger):
records = kwargs.get("records", [])
logs = []
for record in records:
- json_string: Optional[str]
+ json_string: str | None
try:
json_string = json.dumps(record)
except Exception: # pylint: disable=broad-except
diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py
index 02c04abe6a..7e45fc0f31 100644
--- a/superset/utils/machine_auth.py
+++ b/superset/utils/machine_auth.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import importlib
import logging
-from typing import Callable, Dict, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING
from flask import current_app, Flask, request, Response, session
from flask_login import login_user
@@ -71,7 +71,7 @@ class MachineAuthProvider:
return driver
@staticmethod
- def get_auth_cookies(user: User) -> Dict[str, str]:
+ def get_auth_cookies(user: User) -> dict[str, str]:
# Login with the user specified to get the reports
with current_app.test_request_context("/login"):
login_user(user)
diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py
index 4b156cc10c..7462d14540 100644
--- a/superset/utils/mock_data.py
+++ b/superset/utils/mock_data.py
@@ -21,8 +21,9 @@ import os
import random
import string
import sys
+from collections.abc import Iterator
from datetime import date, datetime, time, timedelta
-from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Type
+from typing import Any, Callable, cast, Optional
from uuid import uuid4
import sqlalchemy.sql.sqltypes
@@ -39,17 +40,14 @@ from superset import db
logger = logging.getLogger(__name__)
-ColumnInfo = TypedDict(
- "ColumnInfo",
- {
- "name": str,
- "type": VisitableType,
- "nullable": bool,
- "default": Optional[Any],
- "autoincrement": str,
- "primary_key": int,
- },
-)
+
+class ColumnInfo(TypedDict):
+ name: str
+ type: VisitableType
+ nullable: bool
+ default: Optional[Any]
+ autoincrement: str
+ primary_key: int
example_column = {
@@ -167,7 +165,7 @@ def get_type_generator( # pylint: disable=too-many-return-statements,too-many-b
def add_data(
- columns: Optional[List[ColumnInfo]],
+ columns: Optional[list[ColumnInfo]],
num_rows: int,
table_name: str,
append: bool = True,
@@ -212,16 +210,16 @@ def add_data(
engine.execute(table.insert(), data)
-def get_column_objects(columns: List[ColumnInfo]) -> List[Column]:
+def get_column_objects(columns: list[ColumnInfo]) -> list[Column]:
out = []
for column in columns:
- kwargs = cast(Dict[str, Any], column.copy())
+ kwargs = cast(dict[str, Any], column.copy())
kwargs["type_"] = kwargs.pop("type")
out.append(Column(**kwargs))
return out
-def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, Any]]:
+def generate_data(columns: list[ColumnInfo], num_rows: int) -> list[dict[str, Any]]:
keys = [column["name"] for column in columns]
return [
dict(zip(keys, row))
@@ -229,13 +227,13 @@ def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, An
]
-def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]:
+def generate_column_data(column: ColumnInfo, num_rows: int) -> list[Any]:
gen = get_type_generator(column["type"])
return [gen() for _ in range(num_rows)]
def add_sample_rows(
- session: Session, model: Type[Model], count: int
+ session: Session, model: type[Model], count: int
) -> Iterator[Model]:
"""
Add entities of a given model.
diff --git a/superset/utils/network.py b/superset/utils/network.py
index 7a1aea5a71..fea3cfc6b2 100644
--- a/superset/utils/network.py
+++ b/superset/utils/network.py
@@ -32,10 +32,10 @@ def is_port_open(host: str, port: int) -> bool:
s = socket.socket(af, socket.SOCK_STREAM)
try:
s.settimeout(PORT_TIMEOUT)
- s.connect((sockaddr))
+ s.connect(sockaddr)
s.shutdown(socket.SHUT_RDWR)
return True
- except socket.error as _:
+ except OSError as _:
continue
finally:
s.close()
diff --git a/superset/utils/pandas_postprocessing/aggregate.py b/superset/utils/pandas_postprocessing/aggregate.py
index a863d260c5..1116e4ec70 100644
--- a/superset/utils/pandas_postprocessing/aggregate.py
+++ b/superset/utils/pandas_postprocessing/aggregate.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List
+from typing import Any
from pandas import DataFrame
@@ -26,7 +26,7 @@ from superset.utils.pandas_postprocessing.utils import (
@validate_column_args("groupby")
def aggregate(
- df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
+ df: DataFrame, groupby: list[str], aggregates: dict[str, dict[str, Any]]
) -> DataFrame:
"""
Apply aggregations to a DataFrame.
diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py
index 399cf569fb..f9fed40e59 100644
--- a/superset/utils/pandas_postprocessing/boxplot.py
+++ b/superset/utils/pandas_postprocessing/boxplot.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Any, Callable, Optional, Union
import numpy as np
from flask_babel import gettext as _
@@ -27,11 +27,11 @@ from superset.utils.pandas_postprocessing.aggregate import aggregate
def boxplot(
df: DataFrame,
- groupby: List[str],
- metrics: List[str],
+ groupby: list[str],
+ metrics: list[str],
whisker_type: PostProcessingBoxplotWhiskerType,
percentiles: Optional[
- Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]]
+ Union[list[Union[int, float]], tuple[Union[int, float], Union[int, float]]]
] = None,
) -> DataFrame:
"""
@@ -102,12 +102,12 @@ def boxplot(
whisker_high = np.max
whisker_low = np.min
- def outliers(series: Series) -> Set[float]:
+ def outliers(series: Series) -> set[float]:
above = series[series > whisker_high(series)]
below = series[series < whisker_low(series)]
return above.tolist() + below.tolist()
- operators: Dict[str, Callable[[Any], Any]] = {
+ operators: dict[str, Callable[[Any], Any]] = {
"mean": np.mean,
"median": np.median,
"max": whisker_high,
@@ -117,7 +117,7 @@ def boxplot(
"count": np.ma.count,
"outliers": outliers,
}
- aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = {
+ aggregates: dict[str, dict[str, Union[str, Callable[..., Any]]]] = {
f"{metric}__{operator_name}": {"column": metric, "operator": operator}
for operator_name, operator in operators.items()
for metric in metrics
diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py
index f7c8365508..b20682027f 100644
--- a/superset/utils/pandas_postprocessing/compare.py
+++ b/superset/utils/pandas_postprocessing/compare.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional
+from typing import Optional
import pandas as pd
from flask_babel import gettext as _
@@ -29,8 +29,8 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args
@validate_column_args("source_columns", "compare_columns")
def compare( # pylint: disable=too-many-arguments
df: DataFrame,
- source_columns: List[str],
- compare_columns: List[str],
+ source_columns: list[str],
+ compare_columns: list[str],
compare_type: PandasPostprocessingCompare,
drop_original_columns: Optional[bool] = False,
precision: Optional[int] = 4,
diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py
index f8519f39a9..d383312f75 100644
--- a/superset/utils/pandas_postprocessing/contribution.py
+++ b/superset/utils/pandas_postprocessing/contribution.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from decimal import Decimal
-from typing import List, Optional
+from typing import Optional
from flask_babel import gettext as _
from pandas import DataFrame
@@ -31,8 +31,8 @@ def contribution(
orientation: Optional[
PostProcessingContributionOrientation
] = PostProcessingContributionOrientation.COLUMN,
- columns: Optional[List[str]] = None,
- rename_columns: Optional[List[str]] = None,
+ columns: Optional[list[str]] = None,
+ rename_columns: Optional[list[str]] = None,
) -> DataFrame:
"""
Calculate cell contribution to row/column total for numeric columns.
diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py
index b94f048e5c..128fa970f5 100644
--- a/superset/utils/pandas_postprocessing/cum.py
+++ b/superset/utils/pandas_postprocessing/cum.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict
from flask_babel import gettext as _
from pandas import DataFrame
@@ -31,7 +30,7 @@ from superset.utils.pandas_postprocessing.utils import (
def cum(
df: DataFrame,
operator: str,
- columns: Dict[str, str],
+ columns: dict[str, str],
) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
diff --git a/superset/utils/pandas_postprocessing/diff.py b/superset/utils/pandas_postprocessing/diff.py
index 0cead2de8d..de68d39439 100644
--- a/superset/utils/pandas_postprocessing/diff.py
+++ b/superset/utils/pandas_postprocessing/diff.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict
from pandas import DataFrame
@@ -28,7 +27,7 @@ from superset.utils.pandas_postprocessing.utils import (
@validate_column_args("columns")
def diff(
df: DataFrame,
- columns: Dict[str, str],
+ columns: dict[str, str],
periods: int = 1,
axis: PandasAxis = PandasAxis.ROW,
) -> DataFrame:
diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py
index da9954ef11..40db86db06 100644
--- a/superset/utils/pandas_postprocessing/flatten.py
+++ b/superset/utils/pandas_postprocessing/flatten.py
@@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-from collections.abc import Iterable
-from typing import Any, Sequence, Union
+from collections.abc import Iterable, Sequence
+from typing import Any, Union
import pandas as pd
diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py
index 33a27c2df4..79046cb71a 100644
--- a/superset/utils/pandas_postprocessing/geography.py
+++ b/superset/utils/pandas_postprocessing/geography.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Tuple
+from typing import Optional
import geohash as geohash_lib
from flask_babel import gettext as _
@@ -95,7 +95,7 @@ def geodetic_parse(
:return: DataFrame with decoded longitudes and latitudes
"""
- def _parse_location(location: str) -> Tuple[float, float, float]:
+ def _parse_location(location: str) -> tuple[float, float, float]:
"""
Parse a string containing a geodetic point and return latitude, longitude
and altitude
diff --git a/superset/utils/pandas_postprocessing/pivot.py b/superset/utils/pandas_postprocessing/pivot.py
index df5fa7e37c..28e8ff380f 100644
--- a/superset/utils/pandas_postprocessing/pivot.py
+++ b/superset/utils/pandas_postprocessing/pivot.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from flask_babel import gettext as _
from pandas import DataFrame
@@ -30,9 +30,9 @@ from superset.utils.pandas_postprocessing.utils import (
@validate_column_args("index", "columns")
def pivot( # pylint: disable=too-many-arguments
df: DataFrame,
- index: List[str],
- aggregates: Dict[str, Dict[str, Any]],
- columns: Optional[List[str]] = None,
+ index: list[str],
+ aggregates: dict[str, dict[str, Any]],
+ columns: Optional[list[str]] = None,
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = NULL_STRING,
drop_missing_columns: Optional[bool] = True,
diff --git a/superset/utils/pandas_postprocessing/rename.py b/superset/utils/pandas_postprocessing/rename.py
index 0e35a651a8..4bcd19782c 100644
--- a/superset/utils/pandas_postprocessing/rename.py
+++ b/superset/utils/pandas_postprocessing/rename.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, Optional, Union
+from typing import Optional, Union
import pandas as pd
from flask_babel import gettext as _
@@ -27,7 +27,7 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args
@validate_column_args("columns")
def rename(
df: pd.DataFrame,
- columns: Dict[str, Union[str, None]],
+ columns: dict[str, Union[str, None]],
inplace: bool = False,
level: Optional[Level] = None,
) -> pd.DataFrame:
diff --git a/superset/utils/pandas_postprocessing/rolling.py b/superset/utils/pandas_postprocessing/rolling.py
index 885032eb17..f93a047be9 100644
--- a/superset/utils/pandas_postprocessing/rolling.py
+++ b/superset/utils/pandas_postprocessing/rolling.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional, Union
+from typing import Any, Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
@@ -31,9 +31,9 @@ from superset.utils.pandas_postprocessing.utils import (
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
rolling_type: str,
- columns: Dict[str, str],
+ columns: dict[str, str],
window: Optional[int] = None,
- rolling_type_options: Optional[Dict[str, Any]] = None,
+ rolling_type_options: Optional[dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
@@ -62,7 +62,7 @@ def rolling( # pylint: disable=too-many-arguments
rolling_type_options = rolling_type_options or {}
df_rolling = df.loc[:, columns.keys()]
- kwargs: Dict[str, Union[str, int]] = {}
+ kwargs: dict[str, Union[str, int]] = {}
if window is None:
raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
if window == 0:
diff --git a/superset/utils/pandas_postprocessing/select.py b/superset/utils/pandas_postprocessing/select.py
index 59fe886d4d..c4e02508df 100644
--- a/superset/utils/pandas_postprocessing/select.py
+++ b/superset/utils/pandas_postprocessing/select.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Dict, List, Optional
+from typing import Optional
from pandas import DataFrame
@@ -24,9 +24,9 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args
@validate_column_args("columns", "drop", "rename")
def select(
df: DataFrame,
- columns: Optional[List[str]] = None,
- exclude: Optional[List[str]] = None,
- rename: Optional[Dict[str, str]] = None,
+ columns: Optional[list[str]] = None,
+ exclude: Optional[list[str]] = None,
+ rename: Optional[dict[str, str]] = None,
) -> DataFrame:
"""
Only select a subset of columns in the original dataset. Can be useful for
diff --git a/superset/utils/pandas_postprocessing/sort.py b/superset/utils/pandas_postprocessing/sort.py
index 66041a7166..b6470c3546 100644
--- a/superset/utils/pandas_postprocessing/sort.py
+++ b/superset/utils/pandas_postprocessing/sort.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional, Union
+from typing import Optional, Union
from pandas import DataFrame
@@ -26,8 +26,8 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args
def sort(
df: DataFrame,
is_sort_index: bool = False,
- by: Optional[Union[List[str], str]] = None,
- ascending: Union[List[bool], bool] = True,
+ by: Optional[Union[list[str], str]] = None,
+ ascending: Union[list[bool], bool] = True,
) -> DataFrame:
"""
Sort a DataFrame.
diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py
index 2b754fbbef..37d53697cb 100644
--- a/superset/utils/pandas_postprocessing/utils.py
+++ b/superset/utils/pandas_postprocessing/utils.py
@@ -14,8 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+from collections.abc import Sequence
from functools import partial
-from typing import Any, Callable, Dict, Sequence
+from typing import Any, Callable
import numpy as np
import pandas as pd
@@ -24,7 +25,7 @@ from pandas import DataFrame, NamedAgg
from superset.exceptions import InvalidPostProcessingError
-NUMPY_FUNCTIONS: Dict[str, Callable[..., Any]] = {
+NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = {
"average": np.average,
"argmin": np.argmin,
"argmax": np.argmax,
@@ -133,8 +134,8 @@ def validate_column_args(*argnames: str) -> Callable[..., Any]:
def _get_aggregate_funcs(
df: DataFrame,
- aggregates: Dict[str, Dict[str, Any]],
-) -> Dict[str, NamedAgg]:
+ aggregates: dict[str, dict[str, Any]],
+) -> dict[str, NamedAgg]:
"""
Converts a set of aggregate config objects into functions that pandas can use as
aggregators. Currently only numpy aggregators are supported.
@@ -143,7 +144,7 @@ def _get_aggregate_funcs(
:param aggregates: Mapping from column name to aggregate config.
:return: Mapping from metric name to function that takes a single input argument.
"""
- agg_funcs: Dict[str, NamedAgg] = {}
+ agg_funcs: dict[str, NamedAgg] = {}
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
@@ -180,7 +181,7 @@ def _get_aggregate_funcs(
def _append_columns(
- base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str]
+ base_df: DataFrame, append_df: DataFrame, columns: dict[str, str]
) -> DataFrame:
"""
Function for adding columns from one DataFrame to another DataFrame. Calls the
diff --git a/superset/utils/retries.py b/superset/utils/retries.py
index d1c2947146..8a1e6b95ea 100644
--- a/superset/utils/retries.py
+++ b/superset/utils/retries.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Callable, Dict, Generator, List, Optional, Type
+from collections.abc import Generator
+from typing import Any, Callable, Optional
import backoff
@@ -24,9 +25,9 @@ def retry_call(
func: Callable[..., Any],
*args: Any,
strategy: Callable[..., Generator[int, None, None]] = backoff.constant,
- exception: Type[Exception] = Exception,
- fargs: Optional[List[Any]] = None,
- fkwargs: Optional[Dict[str, Any]] = None,
+ exception: type[Exception] = Exception,
+ fargs: Optional[list[Any]] = None,
+ fkwargs: Optional[dict[str, Any]] = None,
**kwargs: Any
) -> Any:
"""
diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py
index 88b97901b2..5c699e9e19 100644
--- a/superset/utils/screenshots.py
+++ b/superset/utils/screenshots.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import logging
from io import BytesIO
-from typing import Optional, TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
from flask import current_app
@@ -53,16 +53,16 @@ class BaseScreenshot:
def __init__(self, url: str, digest: str):
self.digest: str = digest
self.url = url
- self.screenshot: Optional[bytes] = None
+ self.screenshot: bytes | None = None
- def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy:
+ def driver(self, window_size: WindowSize | None = None) -> WebDriverProxy:
window_size = window_size or self.window_size
return WebDriverProxy(self.driver_type, window_size)
def cache_key(
self,
- window_size: Optional[Union[bool, WindowSize]] = None,
- thumb_size: Optional[Union[bool, WindowSize]] = None,
+ window_size: bool | WindowSize | None = None,
+ thumb_size: bool | WindowSize | None = None,
) -> str:
window_size = window_size or self.window_size
thumb_size = thumb_size or self.thumb_size
@@ -76,8 +76,8 @@ class BaseScreenshot:
return md5_sha_from_dict(args)
def get_screenshot(
- self, user: User, window_size: Optional[WindowSize] = None
- ) -> Optional[bytes]:
+ self, user: User, window_size: WindowSize | None = None
+ ) -> bytes | None:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
return self.screenshot
@@ -86,8 +86,8 @@ class BaseScreenshot:
self,
user: User = None,
cache: Cache = None,
- thumb_size: Optional[WindowSize] = None,
- ) -> Optional[BytesIO]:
+ thumb_size: WindowSize | None = None,
+ ) -> BytesIO | None:
"""
Get thumbnail screenshot has BytesIO from cache or fetch
@@ -95,7 +95,7 @@ class BaseScreenshot:
:param cache: The cache to use
:param thumb_size: Override thumbnail site
"""
- payload: Optional[bytes] = None
+ payload: bytes | None = None
cache_key = self.cache_key(self.window_size, thumb_size)
if cache:
payload = cache.get(cache_key)
@@ -112,14 +112,14 @@ class BaseScreenshot:
def get_from_cache(
self,
cache: Cache,
- window_size: Optional[WindowSize] = None,
- thumb_size: Optional[WindowSize] = None,
- ) -> Optional[BytesIO]:
+ window_size: WindowSize | None = None,
+ thumb_size: WindowSize | None = None,
+ ) -> BytesIO | None:
cache_key = self.cache_key(window_size, thumb_size)
return self.get_from_cache_key(cache, cache_key)
@staticmethod
- def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]:
+ def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None:
logger.info("Attempting to get from cache: %s", cache_key)
if payload := cache.get(cache_key):
return BytesIO(payload)
@@ -129,11 +129,11 @@ class BaseScreenshot:
def compute_and_cache( # pylint: disable=too-many-arguments
self,
user: User = None,
- window_size: Optional[WindowSize] = None,
- thumb_size: Optional[WindowSize] = None,
+ window_size: WindowSize | None = None,
+ thumb_size: WindowSize | None = None,
cache: Cache = None,
force: bool = True,
- ) -> Optional[bytes]:
+ ) -> bytes | None:
"""
Fetches the screenshot, computes the thumbnail and caches the result
@@ -178,7 +178,7 @@ class BaseScreenshot:
cls,
img_bytes: bytes,
output: str = "png",
- thumb_size: Optional[WindowSize] = None,
+ thumb_size: WindowSize | None = None,
crop: bool = True,
) -> bytes:
thumb_size = thumb_size or cls.thumb_size
@@ -207,8 +207,8 @@ class ChartScreenshot(BaseScreenshot):
self,
url: str,
digest: str,
- window_size: Optional[WindowSize] = None,
- thumb_size: Optional[WindowSize] = None,
+ window_size: WindowSize | None = None,
+ thumb_size: WindowSize | None = None,
):
# Chart reports are in standalone="true" mode
url = modify_url_query(
@@ -228,8 +228,8 @@ class DashboardScreenshot(BaseScreenshot):
self,
url: str,
digest: str,
- window_size: Optional[WindowSize] = None,
- thumb_size: Optional[WindowSize] = None,
+ window_size: WindowSize | None = None,
+ thumb_size: WindowSize | None = None,
):
# per the element above, dashboard screenshots
# should always capture in standalone
diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py
index 48ada98dcc..8421350f8c 100644
--- a/superset/utils/ssh_tunnel.py
+++ b/superset/utils/ssh_tunnel.py
@@ -15,13 +15,13 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from superset.constants import PASSWORD_MASK
from superset.databases.ssh_tunnel.models import SSHTunnel
-def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
+def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]:
if ssh_tunnel.pop("password", None) is not None:
ssh_tunnel["password"] = PASSWORD_MASK
if ssh_tunnel.pop("private_key", None) is not None:
@@ -32,8 +32,8 @@ def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]:
def unmask_password_info(
- ssh_tunnel: Dict[str, Any], model: SSHTunnel
-) -> Dict[str, Any]:
+ ssh_tunnel: dict[str, Any], model: SSHTunnel
+) -> dict[str, Any]:
if ssh_tunnel.get("password") == PASSWORD_MASK:
ssh_tunnel["password"] = model.password
if ssh_tunnel.get("private_key") == PASSWORD_MASK:
diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py
index fbd9c800b0..11e40267b3 100644
--- a/superset/utils/url_map_converters.py
+++ b/superset/utils/url_map_converters.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, List
+from typing import Any
from werkzeug.routing import BaseConverter, Map
@@ -22,7 +22,7 @@ from superset.tags.models import ObjectTypes
class RegexConverter(BaseConverter):
- def __init__(self, url_map: Map, *items: List[str]) -> None:
+ def __init__(self, url_map: Map, *items: list[str]) -> None:
super().__init__(url_map)
self.regex = items[0]
diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py
index 05dbee6740..c302ab8921 100644
--- a/superset/utils/webdriver.py
+++ b/superset/utils/webdriver.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import logging
from enum import Enum
from time import sleep
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from flask import current_app
from selenium.common.exceptions import (
@@ -37,7 +37,7 @@ from selenium.webdriver.support.ui import WebDriverWait
from superset.extensions import machine_auth_provider_factory
from superset.utils.retries import retry_call
-WindowSize = Tuple[int, int]
+WindowSize = tuple[int, int]
logger = logging.getLogger(__name__)
@@ -56,7 +56,7 @@ class ChartStandaloneMode(Enum):
SHOW_NAV = 0
-def find_unexpected_errors(driver: WebDriver) -> List[str]:
+def find_unexpected_errors(driver: WebDriver) -> list[str]:
error_messages = []
try:
@@ -111,7 +111,7 @@ def find_unexpected_errors(driver: WebDriver) -> List[str]:
class WebDriverProxy:
- def __init__(self, driver_type: str, window: Optional[WindowSize] = None):
+ def __init__(self, driver_type: str, window: WindowSize | None = None):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"]
@@ -124,7 +124,7 @@ class WebDriverProxy:
options = firefox.options.Options()
profile = FirefoxProfile()
profile.set_preference("layout.css.devPixelsPerPx", str(pixel_density))
- kwargs: Dict[Any, Any] = dict(options=options, firefox_profile=profile)
+ kwargs: dict[Any, Any] = dict(options=options, firefox_profile=profile)
elif self._driver_type == "chrome":
driver_class = chrome.webdriver.WebDriver
options = chrome.options.Options()
@@ -164,13 +164,11 @@ class WebDriverProxy:
except Exception: # pylint: disable=broad-except
pass
- def get_screenshot(
- self, url: str, element_name: str, user: User
- ) -> Optional[bytes]:
+ def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None:
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
- img: Optional[bytes] = None
+ img: bytes | None = None
selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
logger.debug("Sleeping for %i seconds", selenium_headstart)
sleep(selenium_headstart)
diff --git a/superset/views/__init__.py b/superset/views/__init__.py
index 5247f215c1..b5a21c77f0 100644
--- a/superset/views/__init__.py
+++ b/superset/views/__init__.py
@@ -21,8 +21,6 @@ from . import (
base,
core,
css_templates,
- dashboard,
- datasource,
dynamic_plugins,
health,
redirects,
diff --git a/superset/views/all_entities.py b/superset/views/all_entities.py
index 4031d81d21..3de53be461 100644
--- a/superset/views/all_entities.py
+++ b/superset/views/all_entities.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import absolute_import, division, print_function, unicode_literals
import logging
diff --git a/superset/views/base.py b/superset/views/base.py
index 97d8da1d69..3a72096ac2 100644
--- a/superset/views/base.py
+++ b/superset/views/base.py
@@ -20,7 +20,7 @@ import logging
import os
import traceback
from datetime import datetime
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Optional, Union
import simplejson as json
import yaml
@@ -140,11 +140,11 @@ def get_error_msg() -> str:
def json_error_response(
msg: Optional[str] = None,
status: int = 500,
- payload: Optional[Dict[str, Any]] = None,
+ payload: Optional[dict[str, Any]] = None,
link: Optional[str] = None,
) -> FlaskResponse:
if not payload:
- payload = {"error": "{}".format(msg)}
+ payload = {"error": f"{msg}"}
if link:
payload["link"] = link
@@ -156,9 +156,9 @@ def json_error_response(
def json_errors_response(
- errors: List[SupersetError],
+ errors: list[SupersetError],
status: int = 500,
- payload: Optional[Dict[str, Any]] = None,
+ payload: Optional[dict[str, Any]] = None,
) -> FlaskResponse:
if not payload:
payload = {}
@@ -182,7 +182,7 @@ def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskRe
def generate_download_headers(
extension: str, filename: Optional[str] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
filename = filename if filename else datetime.now().strftime("%Y%m%d_%H%M%S")
content_disp = f"attachment; filename={filename}.{extension}"
headers = {"Content-Disposition": content_disp}
@@ -332,7 +332,7 @@ class BaseSupersetView(BaseView):
)
-def menu_data(user: User) -> Dict[str, Any]:
+def menu_data(user: User) -> dict[str, Any]:
menu = appbuilder.menu.get_data()
languages = {}
@@ -396,7 +396,7 @@ def menu_data(user: User) -> Dict[str, Any]:
@cache_manager.cache.memoize(timeout=60)
-def cached_common_bootstrap_data(user: User) -> Dict[str, Any]:
+def cached_common_bootstrap_data(user: User) -> dict[str, Any]:
"""Common data always sent to the client
The function is memoized as the return value only changes when user permissions
@@ -439,7 +439,7 @@ def cached_common_bootstrap_data(user: User) -> Dict[str, Any]:
return bootstrap_data
-def common_bootstrap_payload(user: User) -> Dict[str, Any]:
+def common_bootstrap_payload(user: User) -> dict[str, Any]:
return {
**(cached_common_bootstrap_data(user)),
"flash_messages": get_flashed_messages(with_categories=True),
@@ -548,7 +548,7 @@ def show_unexpected_exception(ex: Exception) -> FlaskResponse:
@superset_app.context_processor
-def get_common_bootstrap_data() -> Dict[str, Any]:
+def get_common_bootstrap_data() -> dict[str, Any]:
def serialize_bootstrap_data() -> str:
return json.dumps(
{"common": common_bootstrap_payload(g.user)},
@@ -606,7 +606,7 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods
@action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download")
def yaml_export(
- self, items: Union[ImportExportMixin, List[ImportExportMixin]]
+ self, items: Union[ImportExportMixin, list[ImportExportMixin]]
) -> FlaskResponse:
if not isinstance(items, list):
items = [items]
@@ -663,7 +663,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods
@action(
"muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False
)
- def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse:
+ def muldelete(self: BaseView, items: list[Model]) -> FlaskResponse:
if not items:
abort(404)
for item in items:
@@ -709,7 +709,7 @@ class XlsxResponse(Response):
def bind_field(
- _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any]
+ _: Any, form: DynamicForm, unbound_field: UnboundField, options: dict[Any, Any]
) -> Field:
"""
Customize how fields are bound by stripping all whitespace.
diff --git a/superset/views/base_api.py b/superset/views/base_api.py
index 30d25382f3..dca7a96b1d 100644
--- a/superset/views/base_api.py
+++ b/superset/views/base_api.py
@@ -18,7 +18,7 @@ from __future__ import annotations
import functools
import logging
-from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, cast
from flask import request, Response
from flask_appbuilder import Model, ModelRestApi
@@ -87,7 +87,7 @@ def requires_json(f: Callable[..., Any]) -> Callable[..., Any]:
Require JSON-like formatted request to the REST API
"""
- def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response:
+ def wraps(self: BaseSupersetModelRestApi, *args: Any, **kwargs: Any) -> Response:
if not request.is_json:
raise InvalidPayloadFormatError(message="Request is not JSON")
return f(self, *args, **kwargs)
@@ -135,7 +135,7 @@ def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]:
class RelatedFieldFilter:
# data class to specify what filter to use on a /related endpoint
# pylint: disable=too-few-public-methods
- def __init__(self, field_name: str, filter_class: Type[BaseFilter]):
+ def __init__(self, field_name: str, filter_class: type[BaseFilter]):
self.field_name = field_name
self.filter_class = filter_class
@@ -150,7 +150,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods
arg_name = ""
class_name = ""
""" The FavStar class_name to user """
- model: Type[Union[Dashboard, Slice, SqllabQuery]] = Dashboard
+ model: type[Dashboard | Slice | SqllabQuery] = Dashboard
""" The SQLAlchemy model """
def apply(self, query: Query, value: Any) -> Query:
@@ -178,7 +178,7 @@ class BaseTagFilter(BaseFilter): # pylint: disable=too-few-public-methods
arg_name = ""
class_name = ""
""" The Tag class_name to user """
- model: Type[Union[Dashboard, Slice, SqllabQuery, SqlaTable]] = Dashboard
+ model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard
""" The SQLAlchemy model """
def apply(self, query: Query, value: Any) -> Query:
@@ -229,7 +229,7 @@ class BaseSupersetApiMixin:
)
def send_stats_metrics(
- self, response: Response, key: str, time_delta: Optional[float] = None
+ self, response: Response, key: str, time_delta: float | None = None
) -> None:
"""
Helper function to handle sending statsd metrics
@@ -280,7 +280,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
"viz_types": "list",
}
- order_rel_fields: Dict[str, Tuple[str, str]] = {}
+ order_rel_fields: dict[str, tuple[str, str]] = {}
"""
Impose ordering on related fields query::
@@ -290,7 +290,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
}
"""
- base_related_field_filters: Dict[str, BaseFilter] = {}
+ base_related_field_filters: dict[str, BaseFilter] = {}
"""
This is used to specify a base filter for related fields
when they are accessed through the '/related/' endpoint.
@@ -302,7 +302,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
}
"""
- related_field_filters: Dict[str, Union[RelatedFieldFilter, str]] = {}
+ related_field_filters: dict[str, RelatedFieldFilter | str] = {}
"""
Specify a filter for related fields when they are accessed
through the '/related/' endpoint.
@@ -313,10 +313,10 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
"": )
}
"""
- allowed_rel_fields: Set[str] = set()
+ allowed_rel_fields: set[str] = set()
# Declare a set of allowed related fields that the `related` endpoint supports.
- text_field_rel_fields: Dict[str, str] = {}
+ text_field_rel_fields: dict[str, str] = {}
"""
Declare an alternative for the human readable representation of the Model object::
@@ -325,7 +325,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
}
"""
- extra_fields_rel_fields: Dict[str, List[str]] = {"owners": ["email", "active"]}
+ extra_fields_rel_fields: dict[str, list[str]] = {"owners": ["email", "active"]}
"""
Declare extra fields for the representation of the Model object::
@@ -334,12 +334,12 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
}
"""
- allowed_distinct_fields: Set[str] = set()
+ allowed_distinct_fields: set[str] = set()
- add_columns: List[str]
- edit_columns: List[str]
- list_columns: List[str]
- show_columns: List[str]
+ add_columns: list[str]
+ edit_columns: list[str]
+ list_columns: list[str]
+ show_columns: list[str]
def __init__(self) -> None:
super().__init__()
@@ -347,8 +347,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
if self.apispec_parameter_schemas is None: # type: ignore
self.apispec_parameter_schemas = {}
self.apispec_parameter_schemas["get_related_schema"] = get_related_schema
- self.openapi_spec_component_schemas: Tuple[
- Type[Schema], ...
+ self.openapi_spec_component_schemas: tuple[
+ type[Schema], ...
] = self.openapi_spec_component_schemas + (
RelatedResponseSchema,
DistincResponseSchema,
@@ -409,7 +409,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
def _get_extra_field_for_model(
self, model: Model, column_name: str
- ) -> Dict[str, str]:
+ ) -> dict[str, str]:
ret = {}
if column_name in self.extra_fields_rel_fields:
model_column_names = self.extra_fields_rel_fields.get(column_name)
@@ -419,8 +419,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
return ret
def _get_result_from_rows(
- self, datamodel: SQLAInterface, rows: List[Model], column_name: str
- ) -> List[Dict[str, Any]]:
+ self, datamodel: SQLAInterface, rows: list[Model], column_name: str
+ ) -> list[dict[str, Any]]:
return [
{
"value": datamodel.get_pk_value(row),
@@ -434,8 +434,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin):
self,
datamodel: SQLAInterface,
column_name: str,
- ids: List[int],
- result: List[Dict[str, Any]],
+ ids: list[int],
+ result: list[dict[str, Any]],
) -> None:
if ids:
# Filter out already present values on the result
diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py
index 8f4ed7735c..2107558dc7 100644
--- a/superset/views/base_schemas.py
+++ b/superset/views/base_schemas.py
@@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Optional, Union
from flask import current_app, g
from flask_appbuilder import Model
@@ -54,7 +55,7 @@ class BaseSupersetSchema(Schema):
self,
data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
many: Optional[bool] = None,
- partial: Union[bool, Sequence[str], Set[str], None] = None,
+ partial: Union[bool, Sequence[str], set[str], None] = None,
instance: Optional[Model] = None,
**kwargs: Any,
) -> Any:
@@ -67,7 +68,7 @@ class BaseSupersetSchema(Schema):
@post_load
def make_object(
- self, data: Dict[Any, Any], discard: Optional[List[str]] = None
+ self, data: dict[Any, Any], discard: Optional[list[str]] = None
) -> Model:
"""
Creates a Model object from POST or PUT requests. PUT will use self.instance
@@ -95,7 +96,7 @@ class BaseOwnedSchema(BaseSupersetSchema):
@post_load
def make_object(
- self, data: Dict[str, Any], discard: Optional[List[str]] = None
+ self, data: dict[str, Any], discard: Optional[list[str]] = None
) -> Model:
discard = discard or []
discard.append(self.owners_field_name)
@@ -107,13 +108,13 @@ class BaseOwnedSchema(BaseSupersetSchema):
return instance
@pre_load
- def pre_load(self, data: Dict[Any, Any]) -> None:
+ def pre_load(self, data: dict[Any, Any]) -> None:
# if PUT request don't set owners to empty list
if not self.instance:
data[self.owners_field_name] = data.get(self.owners_field_name, [])
@staticmethod
- def set_owners(instance: Model, owners: List[int]) -> None:
+ def set_owners(instance: Model, owners: list[int]) -> None:
owner_objs = []
user_id = get_user_id()
if user_id and user_id not in owners:
diff --git a/superset/views/core.py b/superset/views/core.py
index 24bc16c310..3b63eb74d8 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -21,7 +21,7 @@ import logging
import re
from contextlib import closing
from datetime import datetime
-from typing import Any, Callable, cast, Dict, List, Optional, Union
+from typing import Any, Callable, cast, Optional
from urllib import parse
import backoff
@@ -202,7 +202,7 @@ PARAMETER_MISSING_ERR = __(
"your query again."
)
-SqlResults = Dict[str, Any]
+SqlResults = dict[str, Any]
class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@@ -300,13 +300,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
datasources.add(datasource)
has_access_ = all(
- (
- datasource and security_manager.can_access_datasource(datasource)
- for datasource in datasources
- )
+ datasource and security_manager.can_access_datasource(datasource)
+ for datasource in datasources
)
if has_access_:
- return redirect("/superset/dashboard/{}".format(dashboard_id))
+ return redirect(f"/superset/dashboard/{dashboard_id}")
if request.args.get("action") == "go":
for datasource in datasources:
@@ -483,7 +481,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return data_payload_response(*viz_obj.payload_json_and_has_error(payload))
def generate_json(
- self, viz_obj: BaseViz, response_type: Optional[str] = None
+ self, viz_obj: BaseViz, response_type: str | None = None
) -> FlaskResponse:
if response_type == ChartDataResultFormat.CSV:
return CsvResponse(
@@ -618,7 +616,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@check_resource_permissions(check_datasource_perms)
@deprecated(eol_version="3.0")
def explore_json(
- self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None
+ self, datasource_type: str | None = None, datasource_id: int | None = None
) -> FlaskResponse:
"""Serves all request that GET or POST form_data
@@ -631,7 +629,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
TODO: break into one endpoint for each return shape"""
response_type = ChartDataResultFormat.JSON.value
- responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list(
+ responses: list[ChartDataResultFormat | ChartDataResultType] = list(
ChartDataResultFormat
)
responses.extend(list(ChartDataResultType))
@@ -814,9 +812,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def explore(
self,
- datasource_type: Optional[str] = None,
- datasource_id: Optional[int] = None,
- key: Optional[str] = None,
+ datasource_type: str | None = None,
+ datasource_id: int | None = None,
+ key: str | None = None,
) -> FlaskResponse:
if request.method == "GET":
return redirect(Superset.get_redirect_url())
@@ -879,7 +877,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
# fallback unknown datasource to table type
datasource_type = SqlaTable.type
- datasource: Optional[BaseDatasource] = None
+ datasource: BaseDatasource | None = None
if datasource_id is not None:
try:
datasource = DatasourceDAO.get_datasource(
@@ -965,7 +963,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
standalone_mode = ReservedUrlParameters.is_standalone_mode()
force = request.args.get("force") in {"force", "1", "true"}
- dummy_datasource_data: Dict[str, Any] = {
+ dummy_datasource_data: dict[str, Any] = {
"type": datasource_type,
"name": datasource_name,
"columns": [],
@@ -1058,14 +1056,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod
def save_or_overwrite_slice(
# pylint: disable=too-many-arguments,too-many-locals
- slc: Optional[Slice],
+ slc: Slice | None,
slice_add_perm: bool,
slice_overwrite_perm: bool,
slice_download_perm: bool,
datasource_id: int,
datasource_type: str,
datasource_name: str,
- query_context: Optional[str] = None,
+ query_context: str | None = None,
) -> FlaskResponse:
"""Save or overwrite a slice"""
slice_name = request.args.get("slice_name")
@@ -1100,7 +1098,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
flash(msg, "success")
# Adding slice to a dashboard if requested
- dash: Optional[Dashboard] = None
+ dash: Dashboard | None = None
save_to_dashboard_id = request.args.get("save_to_dashboard_id")
new_dashboard_name = request.args.get("new_dashboard_name")
@@ -1293,7 +1291,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
dash.dashboard_title = data["dashboard_title"]
dash.css = data.get("css")
- old_to_new_slice_ids: Dict[int, int] = {}
+ old_to_new_slice_ids: dict[int, int] = {}
if data["duplicate_slices"]:
# Duplicating slices as well, mapping old ids to new ones
for slc in original_dash.slices:
@@ -1480,7 +1478,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
@staticmethod
- def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]:
+ def get_user_activity_access_error(user_id: int) -> FlaskResponse | None:
try:
security_manager.raise_for_user_activity_access(user_id)
except SupersetSecurityException as ex:
@@ -1567,7 +1565,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
if o.Dashboard.created_by:
user = o.Dashboard.created_by
dash["creator"] = str(user)
- dash["creator_url"] = "/superset/profile/{}/".format(user.username)
+ dash["creator_url"] = f"/superset/profile/{user.username}/"
payload.append(dash)
return json_success(json.dumps(payload, default=utils.json_int_dttm_ser))
@@ -1607,7 +1605,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/user_slices", methods=("GET",))
@expose("/user_slices//", methods=("GET",))
@deprecated(new_target="/api/v1/chart/")
- def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
+ def user_slices(self, user_id: int | None = None) -> FlaskResponse:
"""List of slices a user owns, created, modified or faved"""
if not user_id:
user_id = cast(int, get_user_id())
@@ -1660,7 +1658,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/created_slices", methods=("GET",))
@expose("/created_slices//", methods=("GET",))
@deprecated(new_target="api/v1/chart/")
- def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
+ def created_slices(self, user_id: int | None = None) -> FlaskResponse:
"""List of slices created by this user"""
if not user_id:
user_id = cast(int, get_user_id())
@@ -1691,7 +1689,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/fave_slices", methods=("GET",))
@expose("/fave_slices//", methods=("GET",))
@deprecated(new_target="api/v1/chart/")
- def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse:
+ def fave_slices(self, user_id: int | None = None) -> FlaskResponse:
"""Favorite slices for a user"""
if user_id is None:
user_id = cast(int, get_user_id())
@@ -1721,7 +1719,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
if o.Slice.created_by:
user = o.Slice.created_by
dash["creator"] = str(user)
- dash["creator_url"] = "/superset/profile/{}/".format(user.username)
+ dash["creator_url"] = f"/superset/profile/{user.username}/"
payload.append(dash)
return json_success(json.dumps(payload, default=utils.json_int_dttm_ser))
@@ -1745,7 +1743,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
table_name = request.args.get("table_name")
db_name = request.args.get("db_name")
extra_filters = request.args.get("extra_filters")
- slices: List[Slice] = []
+ slices: list[Slice] = []
if not slice_id and not (table_name and db_name):
return json_error_response(
@@ -1869,7 +1867,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
self,
dashboard_id_or_slug: str, # pylint: disable=unused-argument
add_extra_log_payload: Callable[..., None] = lambda **kwargs: None,
- dashboard: Optional[Dashboard] = None,
+ dashboard: Dashboard | None = None,
) -> FlaskResponse:
"""
Server side rendering for a dashboard
@@ -2112,7 +2110,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@event_logger.log_this
@deprecated(new_target="api/v1/sqllab/estimate/")
def estimate_query_cost( # pylint: disable=no-self-use
- self, database_id: int, schema: Optional[str] = None
+ self, database_id: int, schema: str | None = None
) -> FlaskResponse:
mydb = db.session.query(Database).get(database_id)
@@ -2135,7 +2133,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return json_error_response(utils.error_msg_from_exception(ex))
spec = mydb.db_engine_spec
- query_cost_formatters: Dict[str, Any] = app.config[
+ query_cost_formatters: dict[str, Any] = app.config[
"QUERY_COST_FORMATTERS_BY_ENGINE"
]
query_cost_formatter = query_cost_formatters.get(
@@ -2334,14 +2332,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
mydb = session.query(Database).filter_by(id=database_id).one_or_none()
if not mydb:
return json_error_response(
- "Database with id {} is missing.".format(database_id), status=400
+ f"Database with id {database_id} is missing.", status=400
)
spec = mydb.db_engine_spec
validators_by_engine = app.config["SQL_VALIDATORS_BY_ENGINE"]
if not validators_by_engine or spec.engine not in validators_by_engine:
return json_error_response(
- "no SQL validator is configured for {}".format(spec.engine), status=400
+ f"no SQL validator is configured for {spec.engine}", status=400
)
validator_name = validators_by_engine[spec.engine]
validator = get_validator_by_name(validator_name)
@@ -2403,7 +2401,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@staticmethod
def _create_sql_json_command(
- execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]]
+ execution_context: SqlJsonExecutionContext, log_params: dict[str, Any] | None
) -> ExecuteSqlCommand:
query_dao = QueryDAO()
sql_json_executor = Superset._create_sql_json_executor(
@@ -2556,7 +2554,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
@expose("/queries/")
@expose("/queries/")
@deprecated(new_target="api/v1/query/updated_since")
- def queries(self, last_updated_ms: Union[float, int]) -> FlaskResponse:
+ def queries(self, last_updated_ms: float | int) -> FlaskResponse:
"""
Get the updated queries.
@@ -2566,7 +2564,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
return self.queries_exec(last_updated_ms)
@staticmethod
- def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse:
+ def queries_exec(last_updated_ms: float | int) -> FlaskResponse:
stats_logger.incr("queries")
if not get_user_id():
return json_error_response(
@@ -2714,7 +2712,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
)
@staticmethod
- def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]:
+ def _get_sqllab_tabs(user_id: int | None) -> dict[str, Any]:
# send list of tab state ids
tabs_state = (
db.session.query(TabState.id, TabState.label)
@@ -2730,13 +2728,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
.first()
)
- databases: Dict[int, Any] = {}
+ databases: dict[int, Any] = {}
for database in DatabaseDAO.find_all():
databases[database.id] = {
k: v for k, v in database.to_json().items() if k in DATABASE_KEYS
}
databases[database.id]["backend"] = database.backend
- queries: Dict[str, Any] = {}
+ queries: dict[str, Any] = {}
# These are unnecessary if sqllab backend persistence is disabled
if is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"):
diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py
index 4f12206771..71ef212f6d 100644
--- a/superset/views/dashboard/views.py
+++ b/superset/views/dashboard/views.py
@@ -14,9 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import builtins
import json
import re
-from typing import Callable, List, Union
+from typing import Callable, Union
from flask import g, redirect, request, Response
from flask_appbuilder import expose
@@ -64,12 +65,13 @@ class DashboardModelView(
@action("mulexport", __("Export"), __("Export dashboards?"), "fa-database")
def mulexport( # pylint: disable=no-self-use
- self, items: Union["DashboardModelView", List["DashboardModelView"]]
+ self,
+ items: Union["DashboardModelView", builtins.list["DashboardModelView"]],
) -> FlaskResponse:
if not isinstance(items, list):
items = [items]
- ids = "".join("&id={}".format(d.id) for d in items)
- return redirect("/dashboard/export_dashboards_form?{}".format(ids[1:]))
+ ids = "".join(f"&id={d.id}" for d in items)
+ return redirect(f"/dashboard/export_dashboards_form?{ids[1:]}")
@event_logger.log_this
@has_access
diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py
index 5e2347528a..b906e5e70b 100644
--- a/superset/views/database/forms.py
+++ b/superset/views/database/forms.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Contains the logic to create cohesive forms on the explore view"""
-from typing import List
from flask_appbuilder.fields import QuerySelectField
from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
@@ -44,7 +43,7 @@ config = app.config
class UploadToDatabaseForm(DynamicForm):
@staticmethod
- def file_allowed_dbs() -> List[Database]:
+ def file_allowed_dbs() -> list[Database]:
file_enabled_dbs = (
db.session.query(Database).filter_by(allow_file_upload=True).all()
)
diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py
index efd0b6c6eb..deb1b88f1f 100644
--- a/superset/views/database/mixins.py
+++ b/superset/views/database/mixins.py
@@ -227,7 +227,7 @@ class DatabaseMixin:
Markup(
"Cannot delete a database that has tables attached. "
"Here's the list of associated tables: "
- + ", ".join("{}".format(table) for table in database.tables)
+ + ", ".join(f"{table}" for table in database.tables)
)
)
diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py
index 29d80611a2..2ee49c8210 100644
--- a/superset/views/database/validators.py
+++ b/superset/views/database/validators.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Optional, Type
+from typing import Optional
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
@@ -27,7 +27,7 @@ from superset.models.core import Database
def sqlalchemy_uri_validator(
- uri: str, exception: Type[ValidationError] = ValidationError
+ uri: str, exception: type[ValidationError] = ValidationError
) -> None:
"""
Check if a user has submitted a valid SQLAlchemy URI
diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py
index b71b3defa8..5b1700708a 100644
--- a/superset/views/datasource/schemas.py
+++ b/superset/views/datasource/schemas.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from marshmallow import fields, post_load, pre_load, Schema, validate
from typing_extensions import TypedDict
@@ -76,7 +76,7 @@ class SamplesPayloadSchema(Schema):
@pre_load
# pylint: disable=no-self-use, unused-argument
- def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
+ def handle_none(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
if data is None:
return {}
return data
diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py
index 42cddf4167..a4cf0c5e90 100644
--- a/superset/views/datasource/utils.py
+++ b/superset/views/datasource/utils.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from superset import app, db
from superset.common.chart_data import ChartDataResultType
@@ -27,7 +27,7 @@ from superset.utils.core import QueryStatus
from superset.views.datasource.schemas import SamplesPayloadSchema
-def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]:
+def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> dict[str, int]:
samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000)
limit = samples_row_limit
offset = 0
@@ -50,7 +50,7 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals
page: int = 1,
per_page: int = 1000,
payload: Optional[SamplesPayloadSchema] = None,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
datasource = DatasourceDAO.get_datasource(
session=db.session,
datasource_type=datasource_type,
diff --git a/superset/views/log/dao.py b/superset/views/log/dao.py
index 71d8a62348..87bc0817da 100644
--- a/superset/views/log/dao.py
+++ b/superset/views/log/dao.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime, timedelta
-from typing import Any, Dict, List
+from typing import Any
import humanize
from sqlalchemy import and_, or_
@@ -34,8 +34,8 @@ class LogDAO(BaseDAO):
@staticmethod
def get_recent_activity(
- user_id: int, actions: List[str], distinct: bool, page: int, page_size: int
- ) -> List[Dict[str, Any]]:
+ user_id: int, actions: list[str], distinct: bool, page: int, page_size: int
+ ) -> list[dict[str, Any]]:
has_subject_title = or_(
and_(
Dashboard.dashboard_title is not None,
diff --git a/superset/views/tags.py b/superset/views/tags.py
index bd4f43a0d9..4f9d55aed7 100644
--- a/superset/views/tags.py
+++ b/superset/views/tags.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import absolute_import, division, print_function, unicode_literals
import logging
diff --git a/superset/views/users/__init__.py b/superset/views/users/__init__.py
index fd9417fe5c..13a83393a9 100644
--- a/superset/views/users/__init__.py
+++ b/superset/views/users/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/superset/views/utils.py b/superset/views/utils.py
index a366ac683c..9b515edc26 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -17,7 +17,7 @@
import logging
from collections import defaultdict
from functools import wraps
-from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, DefaultDict, Optional, Union
from urllib import parse
import msgpack
@@ -55,12 +55,12 @@ from superset.viz import BaseViz
logger = logging.getLogger(__name__)
stats_logger = app.config["STATS_LOGGER"]
-REJECTED_FORM_DATA_KEYS: List[str] = []
+REJECTED_FORM_DATA_KEYS: list[str] = []
if not feature_flag_manager.is_feature_enabled("ENABLE_JAVASCRIPT_CONTROLS"):
REJECTED_FORM_DATA_KEYS = ["js_tooltip", "js_onclick_href", "js_data_mutator"]
-def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]:
+def sanitize_datasource_data(datasource_data: dict[str, Any]) -> dict[str, Any]:
if datasource_data:
datasource_database = datasource_data.get("database")
if datasource_database:
@@ -69,7 +69,7 @@ def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]:
return datasource_data
-def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]:
+def bootstrap_user_data(user: User, include_perms: bool = False) -> dict[str, Any]:
if user.is_anonymous:
payload = {}
user.roles = (security_manager.find_role("Public"),)
@@ -103,7 +103,7 @@ def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, An
def get_permissions(
user: User,
-) -> Tuple[Dict[str, List[Tuple[str]]], DefaultDict[str, List[str]]]:
+) -> tuple[dict[str, list[tuple[str]]], DefaultDict[str, list[str]]]:
if not user.roles:
raise AttributeError("User object does not have roles")
@@ -138,7 +138,7 @@ def get_viz(
return viz_obj
-def loads_request_json(request_json_data: str) -> Dict[Any, Any]:
+def loads_request_json(request_json_data: str) -> dict[Any, Any]:
try:
return json.loads(request_json_data)
except (TypeError, json.JSONDecodeError):
@@ -148,9 +148,9 @@ def loads_request_json(request_json_data: str) -> Dict[Any, Any]:
def get_form_data( # pylint: disable=too-many-locals
slice_id: Optional[int] = None,
use_slice_data: bool = False,
- initial_form_data: Optional[Dict[str, Any]] = None,
-) -> Tuple[Dict[str, Any], Optional[Slice]]:
- form_data: Dict[str, Any] = initial_form_data or {}
+ initial_form_data: Optional[dict[str, Any]] = None,
+) -> tuple[dict[str, Any], Optional[Slice]]:
+ form_data: dict[str, Any] = initial_form_data or {}
if has_request_context():
# chart data API requests are JSON
@@ -222,7 +222,7 @@ def get_form_data( # pylint: disable=too-many-locals
return form_data, slc
-def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any:
+def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any:
"""
SQLLab can include a "filters" attribute in the templateParams.
The filters attribute is a list of filters to include in the
@@ -244,7 +244,7 @@ def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any:
def get_datasource_info(
datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData
-) -> Tuple[int, Optional[str]]:
+) -> tuple[int, Optional[str]]:
"""
Compatibility layer for handling of datasource info
@@ -277,8 +277,8 @@ def get_datasource_info(
def apply_display_max_row_limit(
- sql_results: Dict[str, Any], rows: Optional[int] = None
-) -> Dict[str, Any]:
+ sql_results: dict[str, Any], rows: Optional[int] = None
+) -> dict[str, Any]:
"""
Given a `sql_results` nested structure, applies a limit to the number of rows
@@ -311,7 +311,7 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"]
def get_dashboard_extra_filters(
slice_id: int, dashboard_id: int
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
session = db.session()
dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none()
@@ -348,11 +348,11 @@ def get_dashboard_extra_filters(
def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-blocks
- layout: Dict[str, Dict[str, Any]],
- filter_scopes: Dict[str, Dict[str, Any]],
- default_filters: Dict[str, Dict[str, List[Any]]],
+ layout: dict[str, dict[str, Any]],
+ filter_scopes: dict[str, dict[str, Any]],
+ default_filters: dict[str, dict[str, list[Any]]],
slice_id: int,
-) -> List[Dict[str, Any]]:
+) -> list[dict[str, Any]]:
extra_filters = []
# do not apply filters if chart is not in filter's scope or chart is immune to the
@@ -360,7 +360,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc
for filter_id, columns in default_filters.items():
filter_slice = db.session.query(Slice).filter_by(id=filter_id).one_or_none()
- filter_configs: List[Dict[str, Any]] = []
+ filter_configs: list[dict[str, Any]] = []
if filter_slice:
filter_configs = (
json.loads(filter_slice.params or "{}").get("filter_configs") or []
@@ -403,7 +403,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc
def is_slice_in_container(
- layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int
+ layout: dict[str, dict[str, Any]], container_id: str, slice_id: int
) -> bool:
if container_id == "ROOT_ID":
return True
@@ -551,7 +551,7 @@ def check_slice_perms(_self: Any, slice_id: int) -> None:
def _deserialize_results_payload(
payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
logger.debug("Deserializing from msgpack: %r", use_msgpack)
if use_msgpack:
with stats_timing(
diff --git a/superset/viz.py b/superset/viz.py
index a7b4a8952a..3bb6204524 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -30,19 +30,7 @@ import re
from collections import defaultdict, OrderedDict
from datetime import date, datetime, timedelta
from itertools import product
-from typing import (
- Any,
- Callable,
- cast,
- Dict,
- List,
- Optional,
- Set,
- Tuple,
- Type,
- TYPE_CHECKING,
- Union,
-)
+from typing import Any, Callable, cast, Optional, TYPE_CHECKING
import geohash
import numpy as np
@@ -124,7 +112,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
"""All visualizations derive this base class"""
- viz_type: Optional[str] = None
+ viz_type: str | None = None
verbose_name = "Base Viz"
credits = ""
is_timeseries = False
@@ -134,8 +122,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@deprecated(deprecated_in="3.0")
def __init__(
self,
- datasource: "BaseDatasource",
- form_data: Dict[str, Any],
+ datasource: BaseDatasource,
+ form_data: dict[str, Any],
force: bool = False,
force_cached: bool = False,
) -> None:
@@ -150,25 +138,25 @@ class BaseViz: # pylint: disable=too-many-public-methods
self.query = ""
self.token = utils.get_form_data_token(form_data)
- self.groupby: List[Column] = self.form_data.get("groupby") or []
+ self.groupby: list[Column] = self.form_data.get("groupby") or []
self.time_shift = timedelta()
- self.status: Optional[str] = None
+ self.status: str | None = None
self.error_msg = ""
- self.results: Optional[QueryResult] = None
- self.applied_filter_columns: List[Column] = []
- self.rejected_filter_columns: List[Column] = []
- self.errors: List[Dict[str, Any]] = []
+ self.results: QueryResult | None = None
+ self.applied_filter_columns: list[Column] = []
+ self.rejected_filter_columns: list[Column] = []
+ self.errors: list[dict[str, Any]] = []
self.force = force
self._force_cached = force_cached
- self.from_dttm: Optional[datetime] = None
- self.to_dttm: Optional[datetime] = None
- self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = []
+ self.from_dttm: datetime | None = None
+ self.to_dttm: datetime | None = None
+ self._extra_chart_data: list[tuple[str, pd.DataFrame]] = []
self.process_metrics()
- self.applied_filters: List[Dict[str, str]] = []
- self.rejected_filters: List[Dict[str, str]] = []
+ self.applied_filters: list[dict[str, str]] = []
+ self.rejected_filters: list[dict[str, str]] = []
@property
@deprecated(deprecated_in="3.0")
@@ -196,8 +184,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@staticmethod
@deprecated(deprecated_in="3.0")
def handle_js_int_overflow(
- data: Dict[str, List[Dict[str, Any]]]
- ) -> Dict[str, List[Dict[str, Any]]]:
+ data: dict[str, list[dict[str, Any]]]
+ ) -> dict[str, list[dict[str, Any]]]:
for record in data.get("records", {}):
for k, v in list(record.items()):
if isinstance(v, int):
@@ -259,7 +247,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return df
@deprecated(deprecated_in="3.0")
- def get_samples(self) -> Dict[str, Any]:
+ def get_samples(self) -> dict[str, Any]:
query_obj = self.query_obj()
query_obj.update(
{
@@ -281,7 +269,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
}
@deprecated(deprecated_in="3.0")
- def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame:
+ def get_df(self, query_obj: QueryObjectDict | None = None) -> pd.DataFrame:
"""Returns a pandas dataframe based on the query object"""
if not query_obj:
query_obj = self.query_obj()
@@ -346,10 +334,10 @@ class BaseViz: # pylint: disable=too-many-public-methods
@staticmethod
@deprecated(deprecated_in="3.0")
- def dedup_columns(*columns_args: Optional[List[Column]]) -> List[Column]:
+ def dedup_columns(*columns_args: list[Column] | None) -> list[Column]:
# dedup groupby and columns while preserving order
- labels: List[str] = []
- deduped_columns: List[Column] = []
+ labels: list[str] = []
+ deduped_columns: list[Column] = []
for columns in columns_args:
for column in columns or []:
label = get_column_name(column)
@@ -492,7 +480,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return md5_sha_from_str(json_data)
@deprecated(deprecated_in="3.0")
- def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload:
+ def get_payload(self, query_obj: QueryObjectDict | None = None) -> VizPayload:
"""Returns a payload of metadata and data"""
try:
@@ -534,8 +522,8 @@ class BaseViz: # pylint: disable=too-many-public-methods
@deprecated(deprecated_in="3.0")
def get_df_payload( # pylint: disable=too-many-statements
- self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any
- ) -> Dict[str, Any]:
+ self, query_obj: QueryObjectDict | None = None, **kwargs: Any
+ ) -> dict[str, Any]:
"""Handles caching around the df payload retrieval"""
if not query_obj:
query_obj = self.query_obj()
@@ -587,7 +575,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
)
+ get_column_names_from_columns(query_obj.get("groupby") or [])
+ utils.get_column_names_from_metrics(
- cast(List[Metric], query_obj.get("metrics") or [])
+ cast(list[Metric], query_obj.get("metrics") or [])
)
if col not in self.datasource.column_names
]
@@ -676,12 +664,12 @@ class BaseViz: # pylint: disable=too-many-public-methods
)
@deprecated(deprecated_in="3.0")
- def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]:
+ def payload_json_and_has_error(self, payload: VizPayload) -> tuple[str, bool]:
return self.json_dumps(payload), self.has_error(payload)
@property
@deprecated(deprecated_in="3.0")
- def data(self) -> Dict[str, Any]:
+ def data(self) -> dict[str, Any]:
"""This is the data object serialized to the js layer"""
content = {
"form_data": self.form_data,
@@ -692,7 +680,7 @@ class BaseViz: # pylint: disable=too-many-public-methods
return content
@deprecated(deprecated_in="3.0")
- def get_csv(self) -> Optional[str]:
+ def get_csv(self) -> str | None:
df = self.get_df_payload()["df"] # leverage caching logic
include_index = not isinstance(df.index, pd.RangeIndex)
return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"])
@@ -766,8 +754,8 @@ class TableViz(BaseViz):
else QueryMode.AGGREGATE
)
- columns: List[str] # output columns sans time and percent_metric column
- percent_columns: List[str] = [] # percent columns that needs extra computation
+ columns: list[str] # output columns sans time and percent_metric column
+ percent_columns: list[str] = [] # percent columns that needs extra computation
if self.query_mode == QueryMode.RAW:
columns = get_metric_names(self.form_data.get("all_columns"))
@@ -906,7 +894,7 @@ class TimeTableViz(BaseViz):
return None
columns = None
- values: Union[List[str], str] = self.metric_labels
+ values: list[str] | str = self.metric_labels
if self.form_data.get("groupby"):
values = self.metric_labels[0]
columns = get_column_names(self.form_data.get("groupby"))
@@ -948,10 +936,8 @@ class PivotTableViz(BaseViz):
if transpose and not columns:
raise QueryObjectValidationError(
_(
- (
- "Please choose at least one 'Columns' field when "
- "select 'Transpose Pivot' option"
- )
+ "Please choose at least one 'Columns' field when "
+ "select 'Transpose Pivot' option"
)
)
if not metrics:
@@ -973,8 +959,8 @@ class PivotTableViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def get_aggfunc(
- metric: str, df: pd.DataFrame, form_data: Dict[str, Any]
- ) -> Union[str, Callable[[Any], Any]]:
+ metric: str, df: pd.DataFrame, form_data: dict[str, Any]
+ ) -> str | Callable[[Any], Any]:
aggfunc = form_data.get("pandas_aggfunc") or "sum"
if pd.api.types.is_numeric_dtype(df[metric]):
# Ensure that Pandas's sum function mimics that of SQL.
@@ -985,7 +971,7 @@ class PivotTableViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
- def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str:
+ def _format_datetime(value: pd.Timestamp | datetime | date | str) -> str:
"""
Format a timestamp in such a way that the viz will be able to apply
the correct formatting in the frontend.
@@ -994,7 +980,7 @@ class PivotTableViz(BaseViz):
:return: formatted timestamp if it is a valid timestamp, otherwise
the original value
"""
- tstamp: Optional[pd.Timestamp] = None
+ tstamp: pd.Timestamp | None = None
if isinstance(value, pd.Timestamp):
tstamp = value
if isinstance(value, (date, datetime)):
@@ -1018,7 +1004,7 @@ class PivotTableViz(BaseViz):
del df[DTTM_ALIAS]
metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]]
- aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {}
+ aggfuncs: dict[str, str | Callable[[Any], Any]] = {}
for metric in metrics:
aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data)
@@ -1088,7 +1074,7 @@ class TreemapViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
- def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]:
+ def _nest(self, metric: str, df: pd.DataFrame) -> list[dict[str, Any]]:
nlevels = df.index.nlevels
if nlevels == 1:
result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])]
@@ -1200,7 +1186,7 @@ class NVD3Viz(BaseViz):
"""Base class for all nvd3 vizs"""
credits = 'NVD3.org'
- viz_type: Optional[str] = None
+ viz_type: str | None = None
verbose_name = "Base NVD3 Viz"
is_timeseries = False
@@ -1249,7 +1235,7 @@ class BubbleViz(NVD3Viz):
df["shape"] = "circle"
df["group"] = df[[get_column_name(self.series)]] # type: ignore
- series: Dict[Any, List[Any]] = defaultdict(list)
+ series: dict[Any, list[Any]] = defaultdict(list)
for row in df.to_dict(orient="records"):
series[row["group"]].append(row)
chart_data = []
@@ -1357,7 +1343,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
verbose_name = _("Time Series - Line Chart")
sort_series = False
is_timeseries = True
- pivot_fill_value: Optional[int] = None
+ pivot_fill_value: int | None = None
@deprecated(deprecated_in="3.0")
def query_obj(self) -> QueryObjectDict:
@@ -1376,7 +1362,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
@deprecated(deprecated_in="3.0")
def to_series( # pylint: disable=too-many-branches
self, df: pd.DataFrame, classed: str = "", title_suffix: str = ""
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1393,7 +1379,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
ys = series[name]
if df[name].dtype.kind not in "biufc":
continue
- series_title: Union[List[str], str, Tuple[str, ...]]
+ series_title: list[str] | str | tuple[str, ...]
if isinstance(name, list):
series_title = [str(title) for title in name]
elif isinstance(name, tuple):
@@ -1510,7 +1496,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
dttm_series = df2[DTTM_ALIAS] + delta
df2 = df2.drop(DTTM_ALIAS, axis=1)
df2 = pd.concat([dttm_series, df2], axis=1)
- label = "{} offset".format(option)
+ label = f"{option} offset"
df2 = self.process_data(df2)
self._extra_chart_data.append((label, df2))
@@ -1524,9 +1510,7 @@ class NVD3TimeSeriesViz(NVD3Viz):
for i, (label, df2) in enumerate(self._extra_chart_data):
chart_data.extend(
- self.to_series(
- df2, classed="time-shift-{}".format(i), title_suffix=label
- )
+ self.to_series(df2, classed=f"time-shift-{i}", title_suffix=label)
)
else:
chart_data = []
@@ -1547,16 +1531,14 @@ class NVD3TimeSeriesViz(NVD3Viz):
diff = df / df2
else:
raise QueryObjectValidationError(
- "Invalid `comparison_type`: {0}".format(comparison_type)
+ f"Invalid `comparison_type`: {comparison_type}"
)
# remove leading/trailing NaNs from the time shift difference
diff = diff[diff.first_valid_index() : diff.last_valid_index()]
chart_data.extend(
- self.to_series(
- diff, classed="time-shift-{}".format(i), title_suffix=label
- )
+ self.to_series(diff, classed=f"time-shift-{i}", title_suffix=label)
)
if not self.sort_series:
@@ -1670,7 +1652,7 @@ class NVD3DualLineViz(NVD3Viz):
return query_obj
@deprecated(deprecated_in="3.0")
- def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]:
+ def to_series(self, df: pd.DataFrame, classed: str = "") -> list[dict[str, Any]]:
cols = []
for col in df.columns:
if col == "":
@@ -1823,7 +1805,7 @@ class HistogramViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
- def labelify(self, keys: Union[List[str], str], column: str) -> str:
+ def labelify(self, keys: list[str] | str, column: str) -> str:
if isinstance(keys, str):
keys = [keys]
# removing undesirable characters
@@ -2033,17 +2015,17 @@ class SankeyViz(BaseViz):
df["target"] = df["target"].astype(str)
recs = df.to_dict(orient="records")
- hierarchy: Dict[str, Set[str]] = defaultdict(set)
+ hierarchy: dict[str, set[str]] = defaultdict(set)
for row in recs:
hierarchy[row["source"]].add(row["target"])
@deprecated(deprecated_in="3.0")
- def find_cycle(graph: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]:
+ def find_cycle(graph: dict[str, set[str]]) -> tuple[str, str] | None:
"""Whether there's a cycle in a directed graph"""
path = set()
@deprecated(deprecated_in="3.0")
- def visit(vertex: str) -> Optional[Tuple[str, str]]:
+ def visit(vertex: str) -> tuple[str, str] | None:
path.add(vertex)
for neighbour in graph.get(vertex, ()):
if neighbour in path or visit(neighbour):
@@ -2214,7 +2196,7 @@ class FilterBoxViz(BaseViz):
"""A multi filter, multi-choice filter box to make dashboards interactive"""
- query_context_factory: Optional[QueryContextFactory] = None
+ query_context_factory: QueryContextFactory | None = None
viz_type = "filter_box"
verbose_name = _("Filters")
is_timeseries = False
@@ -2581,20 +2563,20 @@ class BaseDeckGLViz(BaseViz):
is_timeseries = False
credits = 'deck.gl'
- spatial_control_keys: List[str] = []
+ spatial_control_keys: list[str] = []
@deprecated(deprecated_in="3.0")
- def get_metrics(self) -> List[str]:
+ def get_metrics(self) -> list[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = self.form_data.get("size")
return [self.metric] if self.metric else []
@deprecated(deprecated_in="3.0")
- def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None:
+ def process_spatial_query_obj(self, key: str, group_by: list[str]) -> None:
group_by.extend(self.get_spatial_columns(key))
@deprecated(deprecated_in="3.0")
- def get_spatial_columns(self, key: str) -> List[str]:
+ def get_spatial_columns(self, key: str) -> list[str]:
spatial = self.form_data.get(key)
if spatial is None:
raise ValueError(_("Bad spatial key"))
@@ -2611,7 +2593,7 @@ class BaseDeckGLViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
- def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]:
+ def parse_coordinates(latlog: Any) -> tuple[float, float] | None:
if not latlog:
return None
try:
@@ -2624,7 +2606,7 @@ class BaseDeckGLViz(BaseViz):
@staticmethod
@deprecated(deprecated_in="3.0")
- def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]:
+ def reverse_geohash_decode(geohash_code: str) -> tuple[str, str]:
lat, lng = geohash.decode(geohash_code)
return (lng, lat)
@@ -2692,7 +2674,7 @@ class BaseDeckGLViz(BaseViz):
self.add_null_filters()
query_obj = super().query_obj()
- group_by: List[str] = []
+ group_by: list[str] = []
for key in self.spatial_control_keys:
self.process_spatial_query_obj(key, group_by)
@@ -2720,7 +2702,7 @@ class BaseDeckGLViz(BaseViz):
return query_obj
@deprecated(deprecated_in="3.0")
- def get_js_columns(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_js_columns(self, data: dict[str, Any]) -> dict[str, Any]:
cols = self.form_data.get("js_columns") or []
return {col: data.get(col) for col in cols}
@@ -2748,7 +2730,7 @@ class BaseDeckGLViz(BaseViz):
}
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
raise NotImplementedError()
@@ -2774,7 +2756,7 @@ class DeckScatterViz(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
- def get_metrics(self) -> List[str]:
+ def get_metrics(self) -> list[str]:
# pylint: disable=attribute-defined-outside-init
self.metric = None
if self.point_radius_fixed.get("type") == "metric":
@@ -2783,7 +2765,7 @@ class DeckScatterViz(BaseDeckGLViz):
return []
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"metric": data.get(self.metric_label) if self.metric_label else None,
"radius": self.fixed_value
@@ -2825,7 +2807,7 @@ class DeckScreengrid(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2849,7 +2831,7 @@ class DeckGrid(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2864,7 +2846,7 @@ class DeckGrid(BaseDeckGLViz):
@deprecated(deprecated_in="3.0")
-def geohash_to_json(geohash_code: str) -> List[List[float]]:
+def geohash_to_json(geohash_code: str) -> list[list[float]]:
bbox = geohash.bbox(geohash_code)
return [
[bbox.get("w"), bbox.get("n")],
@@ -2907,7 +2889,7 @@ class DeckPathViz(BaseDeckGLViz):
return query_obj
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
line_type = self.form_data["line_type"]
deser = self.deser_map[line_type]
line_column = self.form_data["line_column"]
@@ -2946,14 +2928,14 @@ class DeckPolygon(DeckPathViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
- def get_metrics(self) -> List[str]:
+ def get_metrics(self) -> list[str]:
metrics = [self.form_data.get("metric")]
if self.elevation.get("type") == "metric":
metrics.append(self.elevation.get("value"))
return [metric for metric in metrics if metric]
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
super().get_properties(data)
elevation = self.form_data["point_radius_fixed"]["value"]
type_ = self.form_data["point_radius_fixed"]["type"]
@@ -2974,7 +2956,7 @@ class DeckHex(BaseDeckGLViz):
spatial_control_keys = ["spatial"]
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -2996,7 +2978,7 @@ class DeckHeatmap(BaseDeckGLViz):
verbose_name = _("Deck.gl - Heatmap")
spatial_control_keys = ["spatial"]
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
return {
"position": data.get("spatial"),
"weight": (data.get(self.metric_label) if self.metric_label else None) or 1,
@@ -3025,7 +3007,7 @@ class DeckGeoJson(BaseDeckGLViz):
return query_obj
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
geojson = data[get_column_name(self.form_data["geojson"])]
return json.loads(geojson)
@@ -3047,7 +3029,7 @@ class DeckArc(BaseDeckGLViz):
return super().query_obj()
@deprecated(deprecated_in="3.0")
- def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ def get_properties(self, data: dict[str, Any]) -> dict[str, Any]:
dim = self.form_data.get("dimension")
return {
"sourcePosition": data.get("start_spatial"),
@@ -3153,7 +3135,7 @@ class PairedTTestViz(BaseViz):
else:
cols.append(col)
df.columns = cols
- data: Dict[str, List[Dict[str, Any]]] = {}
+ data: dict[str, list[dict[str, Any]]] = {}
series = df.to_dict("series")
for name_set in df.columns:
# If no groups are defined, nameSet will be the metric name
@@ -3188,7 +3170,7 @@ class RoseViz(NVD3TimeSeriesViz):
return None
data = super().get_data(df)
- result: Dict[str, List[Dict[str, str]]] = {}
+ result: dict[str, list[dict[str, str]]] = {}
for datum in data:
key = datum["key"]
for val in datum["values"]:
@@ -3227,8 +3209,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def levels_for(
- time_op: str, groups: List[str], df: pd.DataFrame
- ) -> Dict[int, pd.Series]:
+ time_op: str, groups: list[str], df: pd.DataFrame
+ ) -> dict[int, pd.Series]:
"""
Compute the partition at each `level` from the dataframe.
"""
@@ -3245,8 +3227,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@staticmethod
@deprecated(deprecated_in="3.0")
def levels_for_diff(
- time_op: str, groups: List[str], df: pd.DataFrame
- ) -> Dict[int, pd.DataFrame]:
+ time_op: str, groups: list[str], df: pd.DataFrame
+ ) -> dict[int, pd.DataFrame]:
# Obtain a unique list of the time grains
times = list(set(df[DTTM_ALIAS]))
times.sort()
@@ -3282,8 +3264,8 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def levels_for_time(
- self, groups: List[str], df: pd.DataFrame
- ) -> Dict[int, VizData]:
+ self, groups: list[str], df: pd.DataFrame
+ ) -> dict[int, VizData]:
procs = {}
for i in range(0, len(groups) + 1):
self.form_data["groupby"] = groups[:i]
@@ -3295,11 +3277,11 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def nest_values(
self,
- levels: Dict[int, pd.DataFrame],
+ levels: dict[int, pd.DataFrame],
level: int = 0,
- metric: Optional[str] = None,
- dims: Optional[List[str]] = None,
- ) -> List[Dict[str, Any]]:
+ metric: str | None = None,
+ dims: list[str] | None = None,
+ ) -> list[dict[str, Any]]:
"""
Nest values at each level on the back-end with
access and setting, instead of summing from the bottom.
@@ -3340,11 +3322,11 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
def nest_procs(
self,
- procs: Dict[int, pd.DataFrame],
+ procs: dict[int, pd.DataFrame],
level: int = -1,
- dims: Optional[Tuple[str, ...]] = None,
+ dims: tuple[str, ...] | None = None,
time: Any = None,
- ) -> List[Dict[str, Any]]:
+ ) -> list[dict[str, Any]]:
if dims is None:
dims = ()
if level == -1:
@@ -3395,7 +3377,7 @@ class PartitionViz(NVD3TimeSeriesViz):
@deprecated(deprecated_in="3.0")
-def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]:
+def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]:
return set(cls.__subclasses__()).union(
[sc for c in cls.__subclasses__() for sc in get_subclasses(c)]
)
diff --git a/tests/common/logger_utils.py b/tests/common/logger_utils.py
index 98471342b7..8cb443cac8 100644
--- a/tests/common/logger_utils.py
+++ b/tests/common/logger_utils.py
@@ -29,7 +29,7 @@ from inspect import (
Signature,
)
from logging import Logger
-from typing import Any, Callable, cast, Optional, Type, Union
+from typing import Any, Callable, cast, Union
_DEFAULT_ENTER_MSG_PREFIX = "enter to "
_DEFAULT_ENTER_MSG_SUFFIX = ""
@@ -48,11 +48,11 @@ empty_and_none = {Signature.empty, "None"}
Function = Callable[..., Any]
-Decorated = Union[Type[Any], Function]
+Decorated = Union[type[Any], Function]
def log(
- decorated: Optional[Decorated] = None,
+ decorated: Decorated | None = None,
*,
prefix_enter_msg: str = _DEFAULT_ENTER_MSG_PREFIX,
suffix_enter_msg: str = _DEFAULT_ENTER_MSG_SUFFIX,
@@ -85,11 +85,11 @@ def _make_decorator(
def decorator(decorated: Decorated):
decorated_logger = _get_logger(decorated)
- def decorator_class(clazz: Type[Any]) -> Type[Any]:
+ def decorator_class(clazz: type[Any]) -> type[Any]:
_decorate_class_members_with_logs(clazz)
return clazz
- def _decorate_class_members_with_logs(clazz: Type[Any]) -> None:
+ def _decorate_class_members_with_logs(clazz: type[Any]) -> None:
members = getmembers(
clazz, predicate=lambda val: ismethod(val) or isfunction(val)
)
@@ -160,7 +160,7 @@ def _make_decorator(
return _wrapper_func
if isclass(decorated):
- return decorator_class(cast(Type[Any], decorated))
+ return decorator_class(cast(type[Any], decorated))
return decorator_func(cast(Function, decorated))
return decorator
diff --git a/tests/common/query_context_generator.py b/tests/common/query_context_generator.py
index 15b013dc84..32b4063974 100644
--- a/tests/common/query_context_generator.py
+++ b/tests/common/query_context_generator.py
@@ -16,7 +16,7 @@
# under the License.
import copy
import dataclasses
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from superset.common.chart_data import ChartDataResultType
from superset.utils.core import AnnotationType, DTTM_ALIAS
@@ -42,7 +42,7 @@ query_birth_names = {
"where": "",
}
-QUERY_OBJECTS: Dict[str, Dict[str, object]] = {
+QUERY_OBJECTS: dict[str, dict[str, object]] = {
"birth_names": query_birth_names,
# `:suffix` are overrides only
"birth_names:include_time": {
@@ -205,7 +205,7 @@ def get_query_object(
query_name: str,
add_postprocessing_operations: bool,
add_time_offsets: bool,
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
if query_name not in QUERY_OBJECTS:
raise Exception(f"QueryObject fixture not defined for datasource: {query_name}")
obj = QUERY_OBJECTS[query_name]
@@ -227,7 +227,7 @@ def get_query_object(
return query_object
-def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]:
+def _get_postprocessing_operation(query_name: str) -> list[dict[str, Any]]:
if query_name not in QUERY_OBJECTS:
raise Exception(
f"Post-processing fixture not defined for datasource: {query_name}"
@@ -250,8 +250,8 @@ class QueryContextGenerator:
add_time_offsets: bool = False,
table_id=1,
table_type="table",
- form_data: Optional[Dict[str, Any]] = None,
- ) -> Dict[str, Any]:
+ form_data: Optional[dict[str, Any]] = None,
+ ) -> dict[str, Any]:
form_data = form_data or {}
table_name = query_name.split(":")[0]
table = self.get_table(table_name, table_id, table_type)
diff --git a/tests/example_data/data_generator/base_generator.py b/tests/example_data/data_generator/base_generator.py
index 023b929091..38ab2e5413 100644
--- a/tests/example_data/data_generator/base_generator.py
+++ b/tests/example_data/data_generator/base_generator.py
@@ -15,10 +15,11 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
-from typing import Any, Dict, Iterable
+from collections.abc import Iterable
+from typing import Any
class ExampleDataGenerator(ABC):
@abstractmethod
- def generate(self) -> Iterable[Dict[Any, Any]]:
+ def generate(self) -> Iterable[dict[Any, Any]]:
...
diff --git a/tests/example_data/data_generator/birth_names/birth_names_generator.py b/tests/example_data/data_generator/birth_names/birth_names_generator.py
index 2b68abbd4f..a8e8c45e28 100644
--- a/tests/example_data/data_generator/birth_names/birth_names_generator.py
+++ b/tests/example_data/data_generator/birth_names/birth_names_generator.py
@@ -16,9 +16,10 @@
# under the License.
from __future__ import annotations
+from collections.abc import Iterable
from datetime import datetime
from random import choice, randint
-from typing import Any, Dict, Iterable, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from tests.consts.birth_names import (
BOY,
@@ -58,7 +59,7 @@ class BirthNamesGenerator(ExampleDataGenerator):
self._until_not_include_year = start_year + years_amount
self._rows_per_year = rows_per_year
- def generate(self) -> Iterable[Dict[Any, Any]]:
+ def generate(self) -> Iterable[dict[Any, Any]]:
for year in range(self._start_year, self._until_not_include_year):
ds = self._make_year(year)
for _ in range(self._rows_per_year):
@@ -67,7 +68,7 @@ class BirthNamesGenerator(ExampleDataGenerator):
def _make_year(self, year: int):
return datetime(year, 1, 1, 0, 0, 0)
- def generate_row(self, dt: datetime) -> Dict[Any, Any]:
+ def generate_row(self, dt: datetime) -> dict[Any, Any]:
gender = choice([BOY, GIRL])
num = randint(1, 100000)
return {
diff --git a/tests/example_data/data_loading/data_definitions/types.py b/tests/example_data/data_loading/data_definitions/types.py
index e393019e01..a1ed104348 100644
--- a/tests/example_data/data_loading/data_definitions/types.py
+++ b/tests/example_data/data_loading/data_definitions/types.py
@@ -24,8 +24,9 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
+from collections.abc import Iterable
from dataclasses import dataclass
-from typing import Any, Dict, Iterable, Optional
+from typing import Any, Optional
from sqlalchemy.types import TypeEngine
@@ -33,14 +34,14 @@ from sqlalchemy.types import TypeEngine
@dataclass
class TableMetaData:
table_name: str
- types: Optional[Dict[str, TypeEngine]]
+ types: Optional[dict[str, TypeEngine]]
@dataclass
class Table:
table_name: str
table_metadata: TableMetaData
- data: Iterable[Dict[Any, Any]]
+ data: Iterable[dict[Any, Any]]
class TableMetaDataFactory(ABC):
@@ -48,6 +49,6 @@ class TableMetaDataFactory(ABC):
def make(self) -> TableMetaData:
...
- def make_table(self, data: Iterable[Dict[Any, Any]]) -> Table:
+ def make_table(self, data: Iterable[dict[Any, Any]]) -> Table:
metadata = self.make()
return Table(metadata.table_name, metadata, data)
diff --git a/tests/example_data/data_loading/pandas/pandas_data_loader.py b/tests/example_data/data_loading/pandas/pandas_data_loader.py
index 7f41602054..49dcf3b2db 100644
--- a/tests/example_data/data_loading/pandas/pandas_data_loader.py
+++ b/tests/example_data/data_loading/pandas/pandas_data_loader.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
-from typing import Dict, Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from pandas import DataFrame
from sqlalchemy.inspection import inspect
@@ -63,10 +63,10 @@ class PandasDataLoader(DataLoader):
schema=self._detect_schema_name(),
)
- def _detect_schema_name(self) -> Optional[str]:
+ def _detect_schema_name(self) -> str | None:
return inspect(self._db_engine).default_schema_name
- def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]:
+ def _take_data_types(self, table: Table) -> dict[str, str] | None:
if metadata_table := table.table_metadata:
types = metadata_table.types
if types:
diff --git a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py
index 1c43adc931..8de12b39ef 100644
--- a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py
+++ b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict
+from typing import Any
default_pandas_data_loader_config = {
"if_exists": "replace",
@@ -54,7 +54,7 @@ class PandasLoaderConfigurations:
self.support_datetime_type = support_datetime_type
@classmethod
- def make_from_dict(cls, _dict: Dict[str, Any]) -> PandasLoaderConfigurations:
+ def make_from_dict(cls, _dict: dict[str, Any]) -> PandasLoaderConfigurations:
copy_dict = default_pandas_data_loader_config.copy()
copy_dict.update(_dict)
return PandasLoaderConfigurations(**copy_dict) # type: ignore
diff --git a/tests/example_data/data_loading/pandas/table_df_convertor.py b/tests/example_data/data_loading/pandas/table_df_convertor.py
index e801c8464e..aad1077ce5 100644
--- a/tests/example_data/data_loading/pandas/table_df_convertor.py
+++ b/tests/example_data/data_loading/pandas/table_df_convertor.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Optional, TYPE_CHECKING
+from typing import TYPE_CHECKING
from pandas import DataFrame
@@ -30,10 +30,10 @@ if TYPE_CHECKING:
@log
class TableToDfConvertorImpl(TableToDfConvertor):
convert_datetime_to_str: bool
- _time_format: Optional[str]
+ _time_format: str | None
def __init__(
- self, convert_ds_to_datetime: bool, time_format: Optional[str] = None
+ self, convert_ds_to_datetime: bool, time_format: str | None = None
) -> None:
self.convert_datetime_to_str = convert_ds_to_datetime
self._time_format = time_format
diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py
index 38fd105240..79fdff6346 100644
--- a/tests/integration_tests/access_tests.py
+++ b/tests/integration_tests/access_tests.py
@@ -516,7 +516,7 @@ class TestRequestAccess(SupersetTestCase):
)
self.assertEqual(
access_request3.roles_with_datasource,
- "".format(approve_link_3),
+ f"",
)
# cleanup
diff --git a/tests/integration_tests/advanced_data_type/api_tests.py b/tests/integration_tests/advanced_data_type/api_tests.py
index 5bfe308e16..e865069462 100644
--- a/tests/integration_tests/advanced_data_type/api_tests.py
+++ b/tests/integration_tests/advanced_data_type/api_tests.py
@@ -24,7 +24,7 @@ from superset.utils.core import get_example_default_schema
from tests.integration_tests.utils.get_dashboards import get_dashboards_ids
from unittest import mock
from sqlalchemy import Column
-from typing import Any, List
+from typing import Any
from superset.advanced_data_type.types import (
AdvancedDataType,
AdvancedDataTypeRequest,
@@ -52,7 +52,7 @@ def translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:
return target_resp
-def translate_filter_func(col: Column, op: FilterOperator, values: List[Any]):
+def translate_filter_func(col: Column, op: FilterOperator, values: list[Any]):
pass
diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py
index f70f0f63bd..fec66f88d2 100644
--- a/tests/integration_tests/base_tests.py
+++ b/tests/integration_tests/base_tests.py
@@ -20,7 +20,7 @@ from datetime import datetime
import imp
import json
from contextlib import contextmanager
-from typing import Any, Dict, Union, List, Optional
+from typing import Any, Union, Optional
from unittest.mock import Mock, patch, MagicMock
import pandas as pd
@@ -67,12 +67,12 @@ def get_resp(
else:
resp = client.get(url, follow_redirects=follow_redirects)
if raise_on_error and resp.status_code > 400:
- raise Exception("http request failed with code {}".format(resp.status_code))
+ raise Exception(f"http request failed with code {resp.status_code}")
return resp.data.decode("utf-8")
def post_assert_metric(
- client: Any, uri: str, data: Dict[str, Any], func_name: str
+ client: Any, uri: str, data: dict[str, Any], func_name: str
) -> Response:
"""
Simple client post with an extra assertion for statsd metrics
@@ -121,7 +121,7 @@ class SupersetTestCase(TestCase):
@staticmethod
def create_user_with_roles(
- username: str, roles: List[str], should_create_roles: bool = False
+ username: str, roles: list[str], should_create_roles: bool = False
):
user_to_create = security_manager.find_user(username)
if not user_to_create:
@@ -485,12 +485,12 @@ class SupersetTestCase(TestCase):
return rv
def post_assert_metric(
- self, uri: str, data: Dict[str, Any], func_name: str
+ self, uri: str, data: dict[str, Any], func_name: str
) -> Response:
return post_assert_metric(self.client, uri, data, func_name)
def put_assert_metric(
- self, uri: str, data: Dict[str, Any], func_name: str
+ self, uri: str, data: dict[str, Any], func_name: str
) -> Response:
"""
Simple client put with an extra assertion for statsd metrics
diff --git a/tests/integration_tests/cachekeys/api_tests.py b/tests/integration_tests/cachekeys/api_tests.py
index d3552bfc8d..c867ce7f51 100644
--- a/tests/integration_tests/cachekeys/api_tests.py
+++ b/tests/integration_tests/cachekeys/api_tests.py
@@ -16,7 +16,7 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
-from typing import Dict, Any
+from typing import Any
import pytest
@@ -31,7 +31,7 @@ from tests.integration_tests.base_tests import (
@pytest.fixture
def invalidate(test_client, login_as_admin):
- def _invalidate(params: Dict[str, Any]):
+ def _invalidate(params: dict[str, Any]):
return post_assert_metric(
test_client, "api/v1/cachekey/invalidate", params, "invalidate"
)
diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py
index 99b3275281..fa09e56675 100644
--- a/tests/integration_tests/charts/api_tests.py
+++ b/tests/integration_tests/charts/api_tests.py
@@ -17,7 +17,6 @@
# isort:skip_file
"""Unit tests for Superset"""
import json
-import logging
from io import BytesIO
from zipfile import is_zipfile, ZipFile
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index b02ccb5b96..f9e6b5e3b1 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -21,7 +21,7 @@ import unittest
import copy
from datetime import datetime
from io import BytesIO
-from typing import Any, Dict, Optional, List
+from typing import Any, Optional
from unittest import mock
from zipfile import ZipFile
@@ -740,11 +740,11 @@ class TestPostChartDataApi(BaseTestChartDataApi):
data = rv.json["result"][0]["data"]
- unique_names = set(row["name"] for row in data)
+ unique_names = {row["name"] for row in data}
self.maxDiff = None
self.assertEqual(len(unique_names), SERIES_LIMIT)
self.assertEqual(
- set(column for column in data[0].keys()), {"state", "name", "sum__num"}
+ {column for column in data[0].keys()}, {"state", "name", "sum__num"}
)
@pytest.mark.usefixtures(
@@ -1124,7 +1124,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
@pytest.fixture()
-def physical_query_context(physical_dataset) -> Dict[str, Any]:
+def physical_query_context(physical_dataset) -> dict[str, Any]:
return {
"datasource": {
"type": physical_dataset.type,
@@ -1218,7 +1218,7 @@ def test_data_cache_default_timeout(
def test_chart_cache_timeout(
- load_energy_table_with_slice: List[Slice],
+ load_energy_table_with_slice: list[Slice],
test_client,
login_as_admin,
physical_query_context,
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 0ea5bb5106..28da7b7913 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -19,7 +19,7 @@ from __future__ import annotations
import contextlib
import functools
import os
-from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
+from typing import Any, Callable, TYPE_CHECKING
from unittest.mock import patch
import pytest
@@ -55,7 +55,7 @@ def test_client(app_context: AppContext):
@pytest.fixture
-def login_as(test_client: "FlaskClient[Any]"):
+def login_as(test_client: FlaskClient[Any]):
"""Fixture with app context and logged in admin user."""
def _login_as(username: str, password: str = "general"):
@@ -160,7 +160,7 @@ def drop_from_schema(engine: Engine, schema_name: str):
@pytest.fixture(scope="session")
def example_db_provider() -> Callable[[], Database]: # type: ignore
class _example_db_provider:
- _db: Optional[Database] = None
+ _db: Database | None = None
def __call__(self) -> Database:
with app.app_context():
@@ -257,7 +257,7 @@ def with_feature_flags(**mock_feature_flags):
return decorate
-def with_config(override_config: Dict[str, Any]):
+def with_config(override_config: dict[str, Any]):
"""
Use this decorator to mock specific config keys.
diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py
index 2e9e287620..f0c72b0680 100644
--- a/tests/integration_tests/core_tests.py
+++ b/tests/integration_tests/core_tests.py
@@ -23,7 +23,6 @@ import html
import io
import json
import logging
-from typing import Dict, List
from urllib.parse import quote
import superset.utils.database
@@ -37,7 +36,6 @@ from sqlalchemy import Table
import pytest
import pytz
import random
-import re
import unittest
from unittest import mock
@@ -496,7 +494,7 @@ class TestCore(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
response_body = json.loads(response.data.decode("utf-8"))
expected_body = {"error": "Could not load database driver: broken"}
- assert response_body == expected_body, "%s != %s" % (
+ assert response_body == expected_body, "{} != {}".format(
response_body,
expected_body,
)
@@ -515,7 +513,7 @@ class TestCore(SupersetTestCase):
assert response.headers["Content-Type"] == "application/json"
response_body = json.loads(response.data.decode("utf-8"))
expected_body = {"error": "Could not load database driver: mssql+pymssql"}
- assert response_body == expected_body, "%s != %s" % (
+ assert response_body == expected_body, "{} != {}".format(
response_body,
expected_body,
)
@@ -563,7 +561,7 @@ class TestCore(SupersetTestCase):
self.login(username=username)
database = superset.utils.database.get_example_database()
sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted
- url = "databaseview/edit/{}".format(database.id)
+ url = f"databaseview/edit/{database.id}"
data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns}
data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri()
self.client.post(url, data=data)
@@ -582,7 +580,7 @@ class TestCore(SupersetTestCase):
def test_warm_up_cache(self):
self.login()
slc = self.get_slice("Girls", db.session)
- data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id))
+ data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}")
self.assertEqual(
data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}]
)
@@ -609,7 +607,7 @@ class TestCore(SupersetTestCase):
store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"]
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True
girls_slice = self.get_slice("Girls", db.session)
- self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id))
+ self.get_json_resp(f"/superset/warm_up_cache?slice_id={girls_slice.id}")
ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first()
assert ck.datasource_uid == f"{girls_slice.table.id}__table"
app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys
@@ -650,7 +648,7 @@ class TestCore(SupersetTestCase):
kv_value = kv.value
self.assertEqual(json.loads(value), json.loads(kv_value))
- resp = self.client.get("/kv/{}/".format(kv.id))
+ resp = self.client.get(f"/kv/{kv.id}/")
self.assertEqual(resp.status_code, 200)
self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8")))
@@ -662,7 +660,7 @@ class TestCore(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_csv_endpoint(self):
self.login()
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
get_name_sql = """
SELECT name
FROM birth_names
@@ -676,17 +674,17 @@ class TestCore(SupersetTestCase):
WHERE name = '{name}'
LIMIT 1
"""
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
self.run_sql(sql, client_id, raise_on_error=True)
- resp = self.get_resp("/superset/csv/{}".format(client_id))
+ resp = self.get_resp(f"/superset/csv/{client_id}")
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO(f"name\n{name}\n"))
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
self.run_sql(sql, client_id, raise_on_error=True)
- resp = self.get_resp("/superset/csv/{}".format(client_id))
+ resp = self.get_resp(f"/superset/csv/{client_id}")
data = csv.reader(io.StringIO(resp))
expected_data = csv.reader(io.StringIO(f"name\n{name}\n"))
@@ -704,7 +702,7 @@ class TestCore(SupersetTestCase):
def test_required_params_in_sql_json(self):
self.login()
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
data = {"client_id": client_id}
rv = self.client.post(
@@ -876,12 +874,12 @@ class TestCore(SupersetTestCase):
self.get_resp(slc.slice_url)
self.assertEqual(1, qry.count())
- def create_sample_csvfile(self, filename: str, content: List[str]) -> None:
+ def create_sample_csvfile(self, filename: str, content: list[str]) -> None:
with open(filename, "w+") as test_file:
for l in content:
test_file.write(f"{l}\n")
- def create_sample_excelfile(self, filename: str, content: Dict[str, str]) -> None:
+ def create_sample_excelfile(self, filename: str, content: dict[str, str]) -> None:
pd.DataFrame(content).to_excel(filename)
def enable_csv_upload(self, database: models.Database) -> None:
diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py
index 91a76f97cf..9bc204ff06 100644
--- a/tests/integration_tests/csv_upload_tests.py
+++ b/tests/integration_tests/csv_upload_tests.py
@@ -20,7 +20,7 @@ import json
import logging
import os
import shutil
-from typing import Dict, Optional, Union
+from typing import Optional, Union
from unittest import mock
@@ -132,7 +132,7 @@ def get_upload_db():
def upload_csv(
filename: str,
table_name: str,
- extra: Optional[Dict[str, str]] = None,
+ extra: Optional[dict[str, str]] = None,
dtype: Union[str, None] = None,
):
csv_upload_db_id = get_upload_db().id
@@ -155,7 +155,7 @@ def upload_csv(
def upload_excel(
- filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
+ filename: str, table_name: str, extra: Optional[dict[str, str]] = None
):
excel_upload_db_id = get_upload_db().id
form_data = {
@@ -175,7 +175,7 @@ def upload_excel(
def upload_columnar(
- filename: str, table_name: str, extra: Optional[Dict[str, str]] = None
+ filename: str, table_name: str, extra: Optional[dict[str, str]] = None
):
columnar_upload_db_id = get_upload_db().id
form_data = {
@@ -218,7 +218,7 @@ def mock_upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str:
def escaped_double_quotes(text):
- return f"\"{text}\""
+ return rf"\"{text}\""
def escaped_parquet(text):
diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py
index d54151db83..669bc93693 100644
--- a/tests/integration_tests/dashboard_tests.py
+++ b/tests/integration_tests/dashboard_tests.py
@@ -115,7 +115,7 @@ class TestDashboard(SupersetTestCase):
def get_mock_positions(self, dash):
positions = {"DASHBOARD_VERSION_KEY": "v2"}
for i, slc in enumerate(dash.slices):
- id = "DASHBOARD_CHART_TYPE-{}".format(i)
+ id = f"DASHBOARD_CHART_TYPE-{i}"
d = {
"type": "CHART",
"id": id,
@@ -167,7 +167,7 @@ class TestDashboard(SupersetTestCase):
# set a further modified_time for unit test
"last_modified_time": datetime.now().timestamp() + 1000,
}
- url = "/superset/save_dash/{}/".format(dash.id)
+ url = f"/superset/save_dash/{dash.id}/"
resp = self.get_resp(url, data=dict(data=json.dumps(data)))
self.assertIn("SUCCESS", resp)
@@ -189,7 +189,7 @@ class TestDashboard(SupersetTestCase):
"last_modified_time": datetime.now().timestamp() + 1000,
}
- url = "/superset/save_dash/{}/".format(dash.id)
+ url = f"/superset/save_dash/{dash.id}/"
resp = self.get_resp(url, data=dict(data=json.dumps(data)))
self.assertIn("SUCCESS", resp)
@@ -217,7 +217,7 @@ class TestDashboard(SupersetTestCase):
"last_modified_time": datetime.now().timestamp() + 1000,
}
- url = "/superset/save_dash/{}/".format(dash.id)
+ url = f"/superset/save_dash/{dash.id}/"
resp = self.get_resp(url, data=dict(data=json.dumps(data)))
self.assertIn("SUCCESS", resp)
@@ -239,7 +239,7 @@ class TestDashboard(SupersetTestCase):
# set a further modified_time for unit test
"last_modified_time": datetime.now().timestamp() + 1000,
}
- url = "/superset/save_dash/{}/".format(dash.id)
+ url = f"/superset/save_dash/{dash.id}/"
self.get_resp(url, data=dict(data=json.dumps(data)))
updatedDash = db.session.query(Dashboard).filter_by(slug="births").first()
self.assertEqual(updatedDash.dashboard_title, "new title")
@@ -264,7 +264,7 @@ class TestDashboard(SupersetTestCase):
# set a further modified_time for unit test
"last_modified_time": datetime.now().timestamp() + 1000,
}
- url = "/superset/save_dash/{}/".format(dash.id)
+ url = f"/superset/save_dash/{dash.id}/"
self.get_resp(url, data=dict(data=json.dumps(data)))
updatedDash = db.session.query(Dashboard).filter_by(slug="births").first()
self.assertIn("color_namespace", updatedDash.json_metadata)
@@ -301,13 +301,13 @@ class TestDashboard(SupersetTestCase):
# Save changes to Births dashboard and retrieve updated dash
dash_id = dash.id
- url = "/superset/save_dash/{}/".format(dash_id)
+ url = f"/superset/save_dash/{dash_id}/"
self.client.post(url, data=dict(data=json.dumps(data)))
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
orig_json_data = dash.data
# Verify that copy matches original
- url = "/superset/copy_dash/{}/".format(dash_id)
+ url = f"/superset/copy_dash/{dash_id}/"
resp = self.get_json_resp(url, data=dict(data=json.dumps(data)))
self.assertEqual(resp["dashboard_title"], "Copy Of Births")
self.assertEqual(resp["position_json"], orig_json_data["position_json"])
@@ -334,7 +334,7 @@ class TestDashboard(SupersetTestCase):
data = {
"slice_ids": [new_slice.data["slice_id"], existing_slice.data["slice_id"]]
}
- url = "/superset/add_slices/{}/".format(dash.id)
+ url = f"/superset/add_slices/{dash.id}/"
resp = self.client.post(url, data=dict(data=json.dumps(data)))
assert "SLICES ADDED" in resp.data.decode("utf-8")
@@ -375,7 +375,7 @@ class TestDashboard(SupersetTestCase):
# save dash
dash_id = dash.id
- url = "/superset/save_dash/{}/".format(dash_id)
+ url = f"/superset/save_dash/{dash_id}/"
self.client.post(url, data=dict(data=json.dumps(data)))
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py
index bea724dafc..6c3d000051 100644
--- a/tests/integration_tests/dashboard_utils.py
+++ b/tests/integration_tests/dashboard_utils.py
@@ -17,7 +17,7 @@
"""Utils to provide dashboards for tests"""
import json
-from typing import Any, Dict, List, Optional
+from typing import Optional
from pandas import DataFrame
@@ -65,7 +65,7 @@ def create_table_metadata(
def create_slice(
- title: str, viz_type: str, table: SqlaTable, slices_dict: Dict[str, str]
+ title: str, viz_type: str, table: SqlaTable, slices_dict: dict[str, str]
) -> Slice:
return Slice(
slice_name=title,
@@ -77,7 +77,7 @@ def create_slice(
def create_dashboard(
- slug: str, title: str, position: str, slices: List[Slice]
+ slug: str, title: str, position: str, slices: list[Slice]
) -> Dashboard:
dash = db.session.query(Dashboard).filter_by(slug=slug).one_or_none()
if dash:
diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py
index f57afd95a6..49a6bbecbc 100644
--- a/tests/integration_tests/dashboards/api_tests.py
+++ b/tests/integration_tests/dashboards/api_tests.py
@@ -19,7 +19,7 @@
import json
from io import BytesIO
from time import sleep
-from typing import List, Optional
+from typing import Optional
from unittest.mock import ANY, patch
from zipfile import is_zipfile, ZipFile
@@ -66,7 +66,7 @@ DASHBOARDS_FIXTURE_COUNT = 10
class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin):
resource_name = "dashboard"
- dashboards: List[Dashboard] = []
+ dashboards: list[Dashboard] = []
dashboard_data = {
"dashboard_title": "title1_changed",
"slug": "slug1_changed",
@@ -80,10 +80,10 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi
self,
dashboard_title: str,
slug: Optional[str],
- owners: List[int],
- roles: List[int] = [],
+ owners: list[int],
+ roles: list[int] = [],
created_by=None,
- slices: Optional[List[Slice]] = None,
+ slices: Optional[list[Slice]] = None,
position_json: str = "",
css: str = "",
json_metadata: str = "",
@@ -211,9 +211,9 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi
self.assertEqual(response.status_code, 200)
data = json.loads(response.data.decode("utf-8"))
dashboard = Dashboard.get("world_health")
- expected_dataset_ids = set([s.datasource_id for s in dashboard.slices])
+ expected_dataset_ids = {s.datasource_id for s in dashboard.slices}
result = data["result"]
- actual_dataset_ids = set([dataset["id"] for dataset in result])
+ actual_dataset_ids = {dataset["id"] for dataset in result}
self.assertEqual(actual_dataset_ids, expected_dataset_ids)
expected_values = [0, 1] if backend() == "presto" else [0, 1, 2]
self.assertEqual(result[0]["column_types"], expected_values)
@@ -927,7 +927,7 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi
buf = BytesIO()
with ZipFile(buf, "w") as bundle:
with bundle.open("sql/dump.sql", "w") as fp:
- fp.write("CREATE TABLE foo (bar INT)".encode())
+ fp.write(b"CREATE TABLE foo (bar INT)")
buf.seek(0)
return buf
diff --git a/tests/integration_tests/dashboards/base_case.py b/tests/integration_tests/dashboards/base_case.py
index a0a1ff630f..db85cd6409 100644
--- a/tests/integration_tests/dashboards/base_case.py
+++ b/tests/integration_tests/dashboards/base_case.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
-from typing import Any, Dict, Union
+from typing import Any, Union
import prison
from flask import Response
@@ -49,13 +49,13 @@ class DashboardTestCase(SupersetTestCase):
return self.client.get(DASHBOARDS_API_URL)
def save_dashboard_via_view(
- self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any]
+ self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any]
) -> Response:
save_dash_url = SAVE_DASHBOARD_URL_FORMAT.format(dashboard_id)
return self.get_resp(save_dash_url, data=dict(data=json.dumps(dashboard_data)))
def save_dashboard(
- self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any]
+ self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any]
) -> Response:
return self.save_dashboard_via_view(dashboard_id, dashboard_data)
diff --git a/tests/integration_tests/dashboards/dashboard_test_utils.py b/tests/integration_tests/dashboards/dashboard_test_utils.py
index df2687fba9..ee8001cdba 100644
--- a/tests/integration_tests/dashboards/dashboard_test_utils.py
+++ b/tests/integration_tests/dashboards/dashboard_test_utils.py
@@ -17,7 +17,7 @@
import logging
import random
import string
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Optional
from sqlalchemy import func
@@ -32,10 +32,10 @@ logger = logging.getLogger(__name__)
session = appbuilder.get_session
-def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]:
+def get_mock_positions(dashboard: Dashboard) -> dict[str, Any]:
positions = {"DASHBOARD_VERSION_KEY": "v2"}
for i, slc in enumerate(dashboard.slices):
- id_ = "DASHBOARD_CHART_TYPE-{}".format(i)
+ id_ = f"DASHBOARD_CHART_TYPE-{i}"
position_data: Any = {
"type": "CHART",
"id": id_,
@@ -48,7 +48,7 @@ def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]:
def build_save_dash_parts(
dashboard_slug: Optional[str] = None, dashboard_to_edit: Optional[Dashboard] = None
-) -> Tuple[Dashboard, Dict[str, Any], Dict[str, Any]]:
+) -> tuple[Dashboard, dict[str, Any], dict[str, Any]]:
if not dashboard_to_edit:
dashboard_slug = (
dashboard_slug if dashboard_slug else DEFAULT_DASHBOARD_SLUG_TO_TEST
@@ -68,7 +68,7 @@ def build_save_dash_parts(
return dashboard_to_edit, data_before_change, data_after_change
-def get_all_dashboards() -> List[Dashboard]:
+def get_all_dashboards() -> list[Dashboard]:
return db.session.query(Dashboard).all()
diff --git a/tests/integration_tests/dashboards/filter_sets/conftest.py b/tests/integration_tests/dashboards/filter_sets/conftest.py
index b7a28273b0..b19e929f9d 100644
--- a/tests/integration_tests/dashboards/filter_sets/conftest.py
+++ b/tests/integration_tests/dashboards/filter_sets/conftest.py
@@ -17,7 +17,8 @@
from __future__ import annotations
import json
-from typing import Any, Dict, Generator, List, TYPE_CHECKING
+from collections.abc import Generator
+from typing import Any, TYPE_CHECKING
import pytest
@@ -67,7 +68,7 @@ security_manager: BaseSecurityManager = sm
@pytest.fixture(autouse=True, scope="module")
-def test_users() -> Generator[Dict[str, int], None, None]:
+def test_users() -> Generator[dict[str, int], None, None]:
usernames = [
ADMIN_USERNAME_FOR_TEST,
DASHBOARD_OWNER_USERNAME,
@@ -82,16 +83,16 @@ def test_users() -> Generator[Dict[str, int], None, None]:
delete_users(usernames_to_ids)
-def delete_users(usernames_to_ids: Dict[str, int]) -> None:
+def delete_users(usernames_to_ids: dict[str, int]) -> None:
for username in usernames_to_ids.keys():
db.session.delete(security_manager.find_user(username))
db.session.commit()
def create_test_users(
- admin_role: Role, filter_set_role: Role, usernames: List[str]
-) -> Dict[str, int]:
- users: List[User] = []
+ admin_role: Role, filter_set_role: Role, usernames: list[str]
+) -> dict[str, int]:
+ users: list[User] = []
for username in usernames:
user = build_user(username, filter_set_role, admin_role)
users.append(user)
@@ -108,7 +109,7 @@ def build_user(username: str, filter_set_role: Role, admin_role: Role) -> User:
if not user:
user = security_manager.find_user(username)
if user is None:
- raise Exception("Failed to build the user {}".format(username))
+ raise Exception(f"Failed to build the user {username}")
return user
@@ -118,7 +119,7 @@ def build_filter_set_role() -> Role:
all_datasource_view_name: ViewMenu = security_manager.find_view_menu(
"all_datasource_access"
)
- pvms: List[PermissionView] = security_manager.find_permissions_view_menu(
+ pvms: list[PermissionView] = security_manager.find_permissions_view_menu(
filterset_view_name
) + security_manager.find_permissions_view_menu(all_datasource_view_name)
for pvm in pvms:
@@ -167,8 +168,8 @@ def dashboard_id(dashboard: Dashboard) -> Generator[int, None, None]:
@pytest.fixture
def filtersets(
- dashboard_id: int, test_users: Dict[str, int], dumped_valid_json_metadata: str
-) -> Generator[Dict[str, List[FilterSet]], None, None]:
+ dashboard_id: int, test_users: dict[str, int], dumped_valid_json_metadata: str
+) -> Generator[dict[str, list[FilterSet]], None, None]:
first_filter_set = FilterSet(
name="filter_set_1_of_" + str(dashboard_id),
dashboard_id=dashboard_id,
@@ -216,17 +217,17 @@ def filtersets(
@pytest.fixture
-def filterset_id(filtersets: Dict[str, List[FilterSet]]) -> int:
+def filterset_id(filtersets: dict[str, list[FilterSet]]) -> int:
return filtersets["Dashboard"][0].id
@pytest.fixture
-def valid_json_metadata() -> Dict[str, Any]:
+def valid_json_metadata() -> dict[str, Any]:
return {"nativeFilters": {}}
@pytest.fixture
-def dumped_valid_json_metadata(valid_json_metadata: Dict[str, Any]) -> str:
+def dumped_valid_json_metadata(valid_json_metadata: dict[str, Any]) -> str:
return json.dumps(valid_json_metadata)
@@ -238,7 +239,7 @@ def exists_user_id() -> int:
@pytest.fixture
def valid_filter_set_data_for_create(
dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
name = "test_filter_set_of_dashboard_" + str(dashboard_id)
return {
NAME_FIELD: name,
@@ -252,7 +253,7 @@ def valid_filter_set_data_for_create(
@pytest.fixture
def valid_filter_set_data_for_update(
dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
name = "name_changed_test_filter_set_of_dashboard_" + str(dashboard_id)
return {
NAME_FIELD: name,
@@ -273,13 +274,13 @@ def not_exists_user_id() -> int:
@pytest.fixture()
def dashboard_based_filter_set_dict(
- filtersets: Dict[str, List[FilterSet]]
-) -> Dict[str, Any]:
+ filtersets: dict[str, list[FilterSet]]
+) -> dict[str, Any]:
return filtersets["Dashboard"][0].to_dict()
@pytest.fixture()
def user_based_filter_set_dict(
- filtersets: Dict[str, List[FilterSet]]
-) -> Dict[str, Any]:
+ filtersets: dict[str, list[FilterSet]]
+) -> dict[str, Any]:
return filtersets[FILTER_SET_OWNER_USERNAME][0].to_dict()
diff --git a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py
index b5d1919dd4..9891266101 100644
--- a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py
+++ b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict
+from typing import Any
from flask.testing import FlaskClient
@@ -42,11 +42,11 @@ from tests.integration_tests.dashboards.filter_sets.utils import (
from tests.integration_tests.test_app import login
-def assert_filterset_was_not_created(filter_set_data: Dict[str, Any]) -> None:
+def assert_filterset_was_not_created(filter_set_data: dict[str, Any]) -> None:
assert get_filter_set_by_name(str(filter_set_data["name"])) is None
-def assert_filterset_was_created(filter_set_data: Dict[str, Any]) -> None:
+def assert_filterset_was_created(filter_set_data: dict[str, Any]) -> None:
assert get_filter_set_by_name(filter_set_data["name"]) is not None
@@ -54,7 +54,7 @@ class TestCreateFilterSetsApi:
def test_with_extra_field__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -74,7 +74,7 @@ class TestCreateFilterSetsApi:
def test_with_id_field__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -94,7 +94,7 @@ class TestCreateFilterSetsApi:
def test_with_dashboard_not_exists__404(
self,
not_exists_dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# act
@@ -110,7 +110,7 @@ class TestCreateFilterSetsApi:
def test_without_name__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -129,7 +129,7 @@ class TestCreateFilterSetsApi:
def test_with_none_name__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -148,7 +148,7 @@ class TestCreateFilterSetsApi:
def test_with_int_as_name__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -167,7 +167,7 @@ class TestCreateFilterSetsApi:
def test_without_description__201(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -186,7 +186,7 @@ class TestCreateFilterSetsApi:
def test_with_none_description__201(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -205,7 +205,7 @@ class TestCreateFilterSetsApi:
def test_with_int_as_description__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -224,7 +224,7 @@ class TestCreateFilterSetsApi:
def test_without_json_metadata__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -243,7 +243,7 @@ class TestCreateFilterSetsApi:
def test_with_invalid_json_metadata__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -262,7 +262,7 @@ class TestCreateFilterSetsApi:
def test_without_owner_type__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -281,7 +281,7 @@ class TestCreateFilterSetsApi:
def test_with_invalid_owner_type__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -300,7 +300,7 @@ class TestCreateFilterSetsApi:
def test_without_owner_id_when_owner_type_is_user__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -320,7 +320,7 @@ class TestCreateFilterSetsApi:
def test_without_owner_id_when_owner_type_is_dashboard__201(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -340,7 +340,7 @@ class TestCreateFilterSetsApi:
def test_with_not_exists_owner__400(
self,
dashboard_id: int,
- valid_filter_set_data_for_create: Dict[str, Any],
+ valid_filter_set_data_for_create: dict[str, Any],
not_exists_user_id: int,
client: FlaskClient[Any],
):
@@ -361,8 +361,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_admin_and_owner_is_admin__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -384,8 +384,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_admin_and_owner_is_dashboard_owner__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -407,8 +407,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_admin_and_owner_is_regular_user__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -430,8 +430,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_admin_and_owner_type_is_dashboard__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -451,8 +451,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_dashboard_owner_and_owner_is_admin__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -474,8 +474,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_dashboard_owner_and_owner_is_dashboard_owner__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -497,8 +497,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_dashboard_owner_and_owner_is_regular_user__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -520,8 +520,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -541,8 +541,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_regular_user_and_owner_is_admin__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -564,8 +564,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_regular_user_and_owner_is_dashboard_owner__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -587,8 +587,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_regular_user_and_owner_is_regular_user__201(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -610,8 +610,8 @@ class TestCreateFilterSetsApi:
def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403(
self,
dashboard_id: int,
- test_users: Dict[str, int],
- valid_filter_set_data_for_create: Dict[str, Any],
+ test_users: dict[str, int],
+ valid_filter_set_data_for_create: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
diff --git a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py
index 7011cb5781..41d7ea59f7 100644
--- a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py
+++ b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from tests.integration_tests.dashboards.filter_sets.consts import (
DASHBOARD_OWNER_USERNAME,
@@ -36,11 +36,11 @@ if TYPE_CHECKING:
from superset.models.filter_set import FilterSet
-def assert_filterset_was_not_deleted(filter_set_dict: Dict[str, Any]) -> None:
+def assert_filterset_was_not_deleted(filter_set_dict: dict[str, Any]) -> None:
assert get_filter_set_by_name(filter_set_dict["name"]) is not None
-def assert_filterset_deleted(filter_set_dict: Dict[str, Any]) -> None:
+def assert_filterset_deleted(filter_set_dict: dict[str, Any]) -> None:
assert get_filter_set_by_name(filter_set_dict["name"]) is None
@@ -48,7 +48,7 @@ class TestDeleteFilterSet:
def test_with_dashboard_exists_filterset_not_exists__200(
self,
dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -62,7 +62,7 @@ class TestDeleteFilterSet:
def test_with_dashboard_not_exists_filterset_not_exists__404(
self,
not_exists_dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -78,7 +78,7 @@ class TestDeleteFilterSet:
def test_with_dashboard_not_exists_filterset_exists__404(
self,
not_exists_dashboard_id: int,
- dashboard_based_filter_set_dict: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -94,9 +94,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_admin_and_owner_type_is_user__200(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -110,9 +110,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_admin_and_owner_type_is_dashboard__200(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -126,9 +126,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -143,9 +143,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -160,9 +160,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_filterset_owner__200(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -177,9 +177,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_regular_user_and_owner_type_is_user__403(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -194,9 +194,9 @@ class TestDeleteFilterSet:
def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py
index ad40d0e33c..71c985310d 100644
--- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py
+++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, List, Set, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from tests.integration_tests.dashboards.filter_sets.consts import (
DASHBOARD_OWNER_USERNAME,
@@ -66,12 +66,12 @@ class TestGetFilterSetsApi:
def test_when_caller_admin__200(
self,
dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
login(client, "admin")
- expected_ids: Set[int] = collect_all_ids(filtersets)
+ expected_ids: set[int] = collect_all_ids(filtersets)
# act
response = call_get_filter_sets(client, dashboard_id)
@@ -83,7 +83,7 @@ class TestGetFilterSetsApi:
def test_when_caller_dashboard_owner__200(
self,
dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -100,7 +100,7 @@ class TestGetFilterSetsApi:
def test_when_caller_filterset_owner__200(
self,
dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -117,12 +117,12 @@ class TestGetFilterSetsApi:
def test_when_caller_regular_user__200(
self,
dashboard_id: int,
- filtersets: Dict[str, List[int]],
+ filtersets: dict[str, list[int]],
client: FlaskClient[Any],
):
# arrange
login(client, REGULAR_USER)
- expected_ids: Set[int] = set()
+ expected_ids: set[int] = set()
# act
response = call_get_filter_sets(client, dashboard_id)
diff --git a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py
index 07db98f617..a6e895a460 100644
--- a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py
+++ b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py
@@ -17,7 +17,7 @@
from __future__ import annotations
import json
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from superset.dashboards.filter_sets.consts import (
DESCRIPTION_FIELD,
@@ -45,8 +45,8 @@ if TYPE_CHECKING:
def merge_two_filter_set_dict(
- first: Dict[Any, Any], second: Dict[Any, Any]
-) -> Dict[Any, Any]:
+ first: dict[Any, Any], second: dict[Any, Any]
+) -> dict[Any, Any]:
for d in [first, second]:
if JSON_METADATA_FIELD in d:
if PARAMS_PROPERTY not in d:
@@ -55,12 +55,12 @@ def merge_two_filter_set_dict(
return {**first, **second}
-def assert_filterset_was_not_updated(filter_set_dict: Dict[str, Any]) -> None:
+def assert_filterset_was_not_updated(filter_set_dict: dict[str, Any]) -> None:
assert filter_set_dict == get_filter_set_by_name(filter_set_dict["name"]).to_dict()
def assert_filterset_updated(
- filter_set_dict_before: Dict[str, Any], data_updated: Dict[str, Any]
+ filter_set_dict_before: dict[str, Any], data_updated: dict[str, Any]
) -> None:
expected_data = merge_two_filter_set_dict(filter_set_dict_before, data_updated)
assert expected_data == get_filter_set_by_name(expected_data["name"]).to_dict()
@@ -70,7 +70,7 @@ class TestUpdateFilterSet:
def test_with_dashboard_exists_filterset_not_exists__404(
self,
dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -86,7 +86,7 @@ class TestUpdateFilterSet:
def test_with_dashboard_not_exists_filterset_not_exists__404(
self,
not_exists_dashboard_id: int,
- filtersets: Dict[str, List[FilterSet]],
+ filtersets: dict[str, list[FilterSet]],
client: FlaskClient[Any],
):
# arrange
@@ -102,7 +102,7 @@ class TestUpdateFilterSet:
def test_with_dashboard_not_exists_filterset_exists__404(
self,
not_exists_dashboard_id: int,
- dashboard_based_filter_set_dict: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -118,8 +118,8 @@ class TestUpdateFilterSet:
def test_with_extra_field__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -138,8 +138,8 @@ class TestUpdateFilterSet:
def test_with_id_field__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -158,8 +158,8 @@ class TestUpdateFilterSet:
def test_with_none_name__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -177,8 +177,8 @@ class TestUpdateFilterSet:
def test_with_int_as_name__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -196,8 +196,8 @@ class TestUpdateFilterSet:
def test_without_name__200(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -217,8 +217,8 @@ class TestUpdateFilterSet:
def test_with_none_description__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -236,8 +236,8 @@ class TestUpdateFilterSet:
def test_with_int_as_description__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -255,8 +255,8 @@ class TestUpdateFilterSet:
def test_without_description__200(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -276,8 +276,8 @@ class TestUpdateFilterSet:
def test_with_invalid_json_metadata__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -295,9 +295,9 @@ class TestUpdateFilterSet:
def test_with_json_metadata__200(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
- valid_json_metadata: Dict[Any, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
+ valid_json_metadata: dict[Any, Any],
client: FlaskClient[Any],
):
# arrange
@@ -320,8 +320,8 @@ class TestUpdateFilterSet:
def test_with_invalid_owner_type__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -339,8 +339,8 @@ class TestUpdateFilterSet:
def test_with_user_owner_type__400(
self,
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -358,8 +358,8 @@ class TestUpdateFilterSet:
def test_with_dashboard_owner_type__200(
self,
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -382,9 +382,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_admin_and_owner_type_is_user__200(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -402,9 +402,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_admin_and_owner_type_is_dashboard__200(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -422,9 +422,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -441,9 +441,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -462,9 +462,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_filterset_owner__200(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -483,9 +483,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_regular_user_and_owner_type_is_user__403(
self,
- test_users: Dict[str, int],
- user_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ user_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
@@ -502,9 +502,9 @@ class TestUpdateFilterSet:
def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403(
self,
- test_users: Dict[str, int],
- dashboard_based_filter_set_dict: Dict[str, Any],
- valid_filter_set_data_for_update: Dict[str, Any],
+ test_users: dict[str, int],
+ dashboard_based_filter_set_dict: dict[str, Any],
+ valid_filter_set_data_for_update: dict[str, Any],
client: FlaskClient[Any],
):
# arrange
diff --git a/tests/integration_tests/dashboards/filter_sets/utils.py b/tests/integration_tests/dashboards/filter_sets/utils.py
index a63e4164d8..d728bf6fc3 100644
--- a/tests/integration_tests/dashboards/filter_sets/utils.py
+++ b/tests/integration_tests/dashboards/filter_sets/utils.py
@@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
-from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING
from superset.models.filter_set import FilterSet
from tests.integration_tests.dashboards.filter_sets.consts import FILTER_SET_URI
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
def call_create_filter_set(
- client: FlaskClient[Any], dashboard_id: int, data: Dict[str, Any]
+ client: FlaskClient[Any], dashboard_id: int, data: dict[str, Any]
) -> Response:
uri = FILTER_SET_URI.format(dashboard_id=dashboard_id)
return client.post(uri, json=data)
@@ -41,8 +41,8 @@ def call_get_filter_sets(client: FlaskClient[Any], dashboard_id: int) -> Respons
def call_delete_filter_set(
client: FlaskClient[Any],
- filter_set_dict_to_update: Dict[str, Any],
- dashboard_id: Optional[int] = None,
+ filter_set_dict_to_update: dict[str, Any],
+ dashboard_id: int | None = None,
) -> Response:
dashboard_id = (
dashboard_id
@@ -58,9 +58,9 @@ def call_delete_filter_set(
def call_update_filter_set(
client: FlaskClient[Any],
- filter_set_dict_to_update: Dict[str, Any],
- data: Dict[str, Any],
- dashboard_id: Optional[int] = None,
+ filter_set_dict_to_update: dict[str, Any],
+ data: dict[str, Any],
+ dashboard_id: int | None = None,
) -> Response:
dashboard_id = (
dashboard_id
@@ -90,12 +90,12 @@ def get_filter_set_by_dashboard_id(dashboard_id: int) -> FilterSet:
def collect_all_ids(
- filtersets: Union[Dict[str, List[FilterSet]], List[FilterSet]]
-) -> Set[int]:
+ filtersets: dict[str, list[FilterSet]] | list[FilterSet]
+) -> set[int]:
if isinstance(filtersets, dict):
- filtersets_lists: List[List[FilterSet]] = list(filtersets.values())
- ids: Set[int] = set()
- lst: List[FilterSet]
+ filtersets_lists: list[list[FilterSet]] = list(filtersets.values())
+ ids: set[int] = set()
+ lst: list[FilterSet]
for lst in filtersets_lists:
ids.update(set(map(lambda fs: fs.id, lst)))
return ids
diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py
index b20112334d..3c560a4469 100644
--- a/tests/integration_tests/dashboards/permalink/api_tests.py
+++ b/tests/integration_tests/dashboards/permalink/api_tests.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
-from typing import Iterator
+from collections.abc import Iterator
from unittest.mock import patch
from uuid import uuid3
diff --git a/tests/integration_tests/dashboards/security/base_case.py b/tests/integration_tests/dashboards/security/base_case.py
index bbb5fad831..e60fa96d44 100644
--- a/tests/integration_tests/dashboards/security/base_case.py
+++ b/tests/integration_tests/dashboards/security/base_case.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional
+from typing import Optional
import pytest
from flask import escape, Response
@@ -37,8 +37,8 @@ class BaseTestDashboardSecurity(DashboardTestCase):
self,
response: Response,
expected_counts: int,
- expected_dashboards: Optional[List[Dashboard]] = None,
- not_expected_dashboards: Optional[List[Dashboard]] = None,
+ expected_dashboards: Optional[list[Dashboard]] = None,
+ not_expected_dashboards: Optional[list[Dashboard]] = None,
) -> None:
self.assert200(response)
response_data = response.json
diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py
index b160a56a33..88495b03b4 100644
--- a/tests/integration_tests/dashboards/superset_factory_util.py
+++ b/tests/integration_tests/dashboards/superset_factory_util.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
-from typing import List, Optional
+from typing import Optional
from flask_appbuilder import Model
from flask_appbuilder.security.sqla.models import User
@@ -50,8 +50,8 @@ def create_dashboard_to_db(
dashboard_title: Optional[str] = None,
slug: Optional[str] = None,
published: bool = False,
- owners: Optional[List[User]] = None,
- slices: Optional[List[Slice]] = None,
+ owners: Optional[list[User]] = None,
+ slices: Optional[list[Slice]] = None,
css: str = "",
json_metadata: str = "",
position_json: str = "",
@@ -76,8 +76,8 @@ def create_dashboard(
dashboard_title: Optional[str] = None,
slug: Optional[str] = None,
published: bool = False,
- owners: Optional[List[User]] = None,
- slices: Optional[List[Slice]] = None,
+ owners: Optional[list[User]] = None,
+ slices: Optional[list[Slice]] = None,
css: str = "",
json_metadata: str = "",
position_json: str = "",
@@ -107,7 +107,7 @@ def insert_model(dashboard: Model) -> None:
def create_slice_to_db(
name: Optional[str] = None,
datasource_id: Optional[int] = None,
- owners: Optional[List[User]] = None,
+ owners: Optional[list[User]] = None,
) -> Slice:
slice_ = create_slice(datasource_id, name=name, owners=owners)
insert_model(slice_)
@@ -119,7 +119,7 @@ def create_slice(
datasource_id: Optional[int] = None,
datasource: Optional[SqlaTable] = None,
name: Optional[str] = None,
- owners: Optional[List[User]] = None,
+ owners: Optional[list[User]] = None,
) -> Slice:
name = name if name is not None else random_str()
owners = owners if owners is not None else []
@@ -149,7 +149,7 @@ def create_slice(
def create_datasource_table_to_db(
name: Optional[str] = None,
db_id: Optional[int] = None,
- owners: Optional[List[User]] = None,
+ owners: Optional[list[User]] = None,
) -> SqlaTable:
sqltable = create_datasource_table(name, db_id, owners=owners)
insert_model(sqltable)
@@ -161,7 +161,7 @@ def create_datasource_table(
name: Optional[str] = None,
db_id: Optional[int] = None,
database: Optional[Database] = None,
- owners: Optional[List[User]] = None,
+ owners: Optional[list[User]] = None,
) -> SqlaTable:
name = name if name is not None else random_str()
owners = owners if owners is not None else []
@@ -192,7 +192,7 @@ def delete_all_inserted_objects() -> None:
def delete_all_inserted_dashboards():
try:
- dashboards_to_delete: List[Dashboard] = (
+ dashboards_to_delete: list[Dashboard] = (
session.query(Dashboard)
.filter(Dashboard.id.in_(inserted_dashboards_ids))
.all()
@@ -241,7 +241,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None:
def delete_all_inserted_slices():
try:
- slices_to_delete: List[Slice] = (
+ slices_to_delete: list[Slice] = (
session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all()
)
for slice in slices_to_delete:
@@ -272,7 +272,7 @@ def delete_slice_users_associations(slice_: Slice) -> None:
def delete_all_inserted_tables():
try:
- tables_to_delete: List[SqlaTable] = (
+ tables_to_delete: list[SqlaTable] = (
session.query(SqlaTable)
.filter(SqlaTable.id.in_(inserted_sqltables_ids))
.all()
@@ -307,7 +307,7 @@ def delete_table_users_associations(table: SqlaTable) -> None:
def delete_all_inserted_dbs():
try:
- dbs_to_delete: List[Database] = (
+ dbs_to_delete: list[Database] = (
session.query(Database)
.filter(Database.id.in_(inserted_databases_ids))
.all()
diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py
index a30a951884..6fa1288067 100644
--- a/tests/integration_tests/databases/api_tests.py
+++ b/tests/integration_tests/databases/api_tests.py
@@ -23,7 +23,6 @@ from io import BytesIO
from unittest import mock
from unittest.mock import patch, MagicMock
from zipfile import is_zipfile, ZipFile
-from operator import itemgetter
import prison
import pytest
@@ -52,7 +51,6 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
load_birth_names_data,
)
-from tests.integration_tests.fixtures.certificates import ssl_certificate
from tests.integration_tests.fixtures.energy_dashboard import (
load_energy_table_with_slice,
load_energy_table_data,
@@ -1805,10 +1803,10 @@ class TestDatabaseApi(SupersetTestCase):
schemas = [
s[0] for s in database.get_all_table_names_in_schema(schema_name)
]
- self.assertEquals(response["count"], len(schemas))
+ self.assertEqual(response["count"], len(schemas))
for option in response["result"]:
- self.assertEquals(option["extra"], None)
- self.assertEquals(option["type"], "table")
+ self.assertEqual(option["extra"], None)
+ self.assertEqual(option["type"], "table")
self.assertTrue(option["value"] in schemas)
def test_database_tables_not_found(self):
diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py
index 553fae4fbf..b47d3d89fe 100644
--- a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from unittest import mock, skip
+from unittest import skip
from unittest.mock import patch
import pytest
diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
index 86c280b9bb..64bc0d8572 100644
--- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
+++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from unittest import mock, skip
+from unittest import mock
from unittest.mock import patch
import pytest
diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py
index 2c358d7114..6c99efd358 100644
--- a/tests/integration_tests/datasets/api_tests.py
+++ b/tests/integration_tests/datasets/api_tests.py
@@ -18,7 +18,7 @@
import json
import unittest
from io import BytesIO
-from typing import List, Optional
+from typing import Optional
from unittest.mock import ANY, patch
from zipfile import is_zipfile, ZipFile
@@ -68,7 +68,7 @@ class TestDatasetApi(SupersetTestCase):
@staticmethod
def insert_dataset(
table_name: str,
- owners: List[int],
+ owners: list[int],
database: Database,
sql: Optional[str] = None,
schema: Optional[str] = None,
@@ -94,7 +94,7 @@ class TestDatasetApi(SupersetTestCase):
"ab_permission", [self.get_user("admin").id], get_main_database()
)
- def get_fixture_datasets(self) -> List[SqlaTable]:
+ def get_fixture_datasets(self) -> list[SqlaTable]:
return (
db.session.query(SqlaTable)
.options(joinedload(SqlaTable.database))
@@ -102,7 +102,7 @@ class TestDatasetApi(SupersetTestCase):
.all()
)
- def get_fixture_virtual_datasets(self) -> List[SqlaTable]:
+ def get_fixture_virtual_datasets(self) -> list[SqlaTable]:
return (
db.session.query(SqlaTable)
.filter(SqlaTable.table_name.in_(self.fixture_virtual_table_names))
@@ -410,13 +410,11 @@ class TestDatasetApi(SupersetTestCase):
)
all_datasets = db.session.query(SqlaTable).all()
schema_values = sorted(
- set(
- [
- dataset.schema
- for dataset in all_datasets
- if dataset.schema is not None
- ]
- )
+ {
+ dataset.schema
+ for dataset in all_datasets
+ if dataset.schema is not None
+ }
)
expected_response = {
"count": len(schema_values),
diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py
index 8753b1d273..953c34059f 100644
--- a/tests/integration_tests/datasets/commands_tests.py
+++ b/tests/integration_tests/datasets/commands_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from operator import itemgetter
-from typing import Any, List
+from typing import Any
from unittest.mock import patch
import pytest
@@ -312,7 +312,7 @@ class TestImportDatasetsCommand(SupersetTestCase):
assert len(dataset.metrics) == 2
assert dataset.main_dttm_col == "ds"
assert dataset.filter_select_enabled
- assert set(col.column_name for col in dataset.columns) == {
+ assert {col.column_name for col in dataset.columns} == {
"num_california",
"ds",
"state",
@@ -526,7 +526,7 @@ class TestImportDatasetsCommand(SupersetTestCase):
db.session.commit()
-def _get_table_from_list_by_name(name: str, tables: List[Any]):
+def _get_table_from_list_by_name(name: str, tables: list[Any]):
for table in tables:
if table.table_name == name:
return table
diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py
index e20ea35ae4..2d4f72c4f4 100644
--- a/tests/integration_tests/db_engine_specs/base_tests.py
+++ b/tests/integration_tests/db_engine_specs/base_tests.py
@@ -15,8 +15,6 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
-from datetime import datetime
-from typing import Tuple, Type
from tests.integration_tests.test_app import app
from tests.integration_tests.base_tests import SupersetTestCase
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index 2f4f1c70cc..c4f04584fa 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import sys
import unittest.mock as mock
import pytest
@@ -95,7 +94,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec):
"""
# Mock a google.cloud.bigquery.table.Row
- class Row(object):
+ class Row:
def __init__(self, value):
self._value = value
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index b63f64ab03..341b494927 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -15,9 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
-from datetime import datetime
from unittest import mock
-from typing import List
import pytest
import pandas as pd
@@ -377,7 +375,7 @@ def test_where_latest_partition_no_columns_no_values(mock_method):
def test__latest_partition_from_df():
- def is_correct_result(data: List, result: List) -> bool:
+ def is_correct_result(data: list, result: list) -> bool:
df = pd.DataFrame({"partition": data})
return HiveEngineSpec._latest_partition_from_df(df) == result
diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py
index de0aa83262..6018e59a92 100644
--- a/tests/integration_tests/dict_import_export_tests.py
+++ b/tests/integration_tests/dict_import_export_tests.py
@@ -61,7 +61,7 @@ class TestDictImportExport(SupersetTestCase):
self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[]
):
database_name = "main"
- name = "{0}{1}".format(NAME_PREFIX, name)
+ name = f"{NAME_PREFIX}{name}"
params = {DBREF: id, "database_name": database_name}
if cols_uuids is None:
@@ -100,12 +100,12 @@ class TestDictImportExport(SupersetTestCase):
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
- set([c.column_name for c in expected_ds.columns]),
- set([c.column_name for c in actual_ds.columns]),
+ {c.column_name for c in expected_ds.columns},
+ {c.column_name for c in actual_ds.columns},
)
self.assertEqual(
- set([m.metric_name for m in expected_ds.metrics]),
- set([m.metric_name for m in actual_ds.metrics]),
+ {m.metric_name for m in expected_ds.metrics},
+ {m.metric_name for m in actual_ds.metrics},
)
def assert_datasource_equals(self, expected_ds, actual_ds):
@@ -114,12 +114,12 @@ class TestDictImportExport(SupersetTestCase):
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
- set([c.column_name for c in expected_ds.columns]),
- set([c.column_name for c in actual_ds.columns]),
+ {c.column_name for c in expected_ds.columns},
+ {c.column_name for c in actual_ds.columns},
)
self.assertEqual(
- set([m.metric_name for m in expected_ds.metrics]),
- set([m.metric_name for m in actual_ds.metrics]),
+ {m.metric_name for m in expected_ds.metrics},
+ {m.metric_name for m in actual_ds.metrics},
)
def test_import_table_no_metadata(self):
diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py
index 381b8cda1b..7c7cc16830 100644
--- a/tests/integration_tests/email_tests.py
+++ b/tests/integration_tests/email_tests.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py
index fa965ebd7d..3b20f6a918 100644
--- a/tests/integration_tests/event_logger_tests.py
+++ b/tests/integration_tests/event_logger_tests.py
@@ -17,8 +17,8 @@
import logging
import time
import unittest
-from datetime import datetime, timedelta
-from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union
+from datetime import timedelta
+from typing import Any, Optional
from unittest.mock import patch
from flask import current_app
diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py
index b9b1bfd0fb..81be2f0de8 100644
--- a/tests/integration_tests/explore/permalink/api_tests.py
+++ b/tests/integration_tests/explore/permalink/api_tests.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
import json
-from typing import Any, Dict, Iterator
+from collections.abc import Iterator
+from typing import Any
from uuid import uuid3
import pytest
@@ -43,7 +44,7 @@ def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice:
@pytest.fixture
-def form_data(chart) -> Dict[str, Any]:
+def form_data(chart) -> dict[str, Any]:
datasource = f"{chart.datasource.id}__{chart.datasource.type}"
return {
"chart_id": chart.id,
@@ -68,7 +69,7 @@ def permalink_salt() -> Iterator[str]:
def test_post(
- form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin
+ form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin
):
resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data})
assert resp.status_code == 201
@@ -125,7 +126,7 @@ def test_post_invalid_schema(test_client, login_as_admin) -> None:
def test_get(
- form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin
+ form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin
) -> None:
resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data})
data = json.loads(resp.data.decode("utf-8"))
diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py
index 63ed02cd7b..eace978d78 100644
--- a/tests/integration_tests/explore/permalink/commands_tests.py
+++ b/tests/integration_tests/explore/permalink/commands_tests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
from unittest.mock import patch
import pytest
diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py
index be680a720d..d9a4a5d9e0 100644
--- a/tests/integration_tests/fixtures/birth_names_dashboard.py
+++ b/tests/integration_tests/fixtures/birth_names_dashboard.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Callable, List, Optional
+from typing import Callable, Optional
import pytest
@@ -93,7 +93,7 @@ def _create_table(
return table
-def _cleanup(dash_id: int, slice_ids: List[int]) -> None:
+def _cleanup(dash_id: int, slice_ids: list[int]) -> None:
schema = get_example_default_schema()
for datasource in db.session.query(SqlaTable).filter_by(
table_name="birth_names", schema=schema
diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py
index f394d68a0e..279b67eda0 100644
--- a/tests/integration_tests/fixtures/datasource.py
+++ b/tests/integration_tests/fixtures/datasource.py
@@ -15,7 +15,8 @@
# specific language governing permissions and limitations
# under the License.
"""Fixtures for test_datasource.py"""
-from typing import Any, Dict, Generator
+from collections.abc import Generator
+from typing import Any
import pytest
from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table
@@ -31,7 +32,7 @@ from superset.utils.database import get_example_database
from tests.integration_tests.test_app import app
-def get_datasource_post() -> Dict[str, Any]:
+def get_datasource_post() -> dict[str, Any]:
schema = get_example_default_schema()
return {
diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py
index effe59a755..8b597bf3be 100644
--- a/tests/integration_tests/fixtures/energy_dashboard.py
+++ b/tests/integration_tests/fixtures/energy_dashboard.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import random
-from typing import Dict, List, Set
import pandas as pd
import pytest
@@ -29,7 +28,7 @@ from superset.utils.database import get_example_database
from tests.integration_tests.dashboard_utils import create_slice, create_table_metadata
from tests.integration_tests.test_app import app
-misc_dash_slices: Set[str] = set()
+misc_dash_slices: set[str] = set()
ENERGY_USAGE_TBL_NAME = "energy_usage"
@@ -70,7 +69,7 @@ def _get_dataframe():
return pd.DataFrame.from_dict(data)
-def _create_energy_table() -> List[Slice]:
+def _create_energy_table() -> list[Slice]:
table = create_table_metadata(
table_name=ENERGY_USAGE_TBL_NAME,
database=get_example_database(),
@@ -100,7 +99,7 @@ def _create_energy_table() -> List[Slice]:
def _create_and_commit_energy_slice(
- table: SqlaTable, title: str, viz_type: str, param: Dict[str, str]
+ table: SqlaTable, title: str, viz_type: str, param: dict[str, str]
):
slice = create_slice(title, viz_type, table, param)
existing_slice = (
diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py
index d5c898eba2..5fddb071e2 100644
--- a/tests/integration_tests/fixtures/importexport.py
+++ b/tests/integration_tests/fixtures/importexport.py
@@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List
+from typing import Any
# example V0 import/export format
-dataset_ui_export: List[Dict[str, Any]] = [
+dataset_ui_export: list[dict[str, Any]] = [
{
"columns": [
{
@@ -48,7 +48,7 @@ dataset_ui_export: List[Dict[str, Any]] = [
}
]
-dataset_cli_export: Dict[str, Any] = {
+dataset_cli_export: dict[str, Any] = {
"databases": [
{
"allow_run_async": True,
@@ -59,7 +59,7 @@ dataset_cli_export: Dict[str, Any] = {
]
}
-dashboard_export: Dict[str, Any] = {
+dashboard_export: dict[str, Any] = {
"dashboards": [
{
"__Dashboard__": {
@@ -318,35 +318,35 @@ dashboard_export: Dict[str, Any] = {
}
# example V1 import/export format
-database_metadata_config: Dict[str, Any] = {
+database_metadata_config: dict[str, Any] = {
"version": "1.0.0",
"type": "Database",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}
-dataset_metadata_config: Dict[str, Any] = {
+dataset_metadata_config: dict[str, Any] = {
"version": "1.0.0",
"type": "SqlaTable",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}
-chart_metadata_config: Dict[str, Any] = {
+chart_metadata_config: dict[str, Any] = {
"version": "1.0.0",
"type": "Slice",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}
-dashboard_metadata_config: Dict[str, Any] = {
+dashboard_metadata_config: dict[str, Any] = {
"version": "1.0.0",
"type": "Dashboard",
"timestamp": "2020-11-04T21:27:44.423819+00:00",
}
-saved_queries_metadata_config: Dict[str, Any] = {
+saved_queries_metadata_config: dict[str, Any] = {
"version": "1.0.0",
"type": "SavedQuery",
"timestamp": "2021-03-30T20:37:54.791187+00:00",
}
-database_config: Dict[str, Any] = {
+database_config: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -361,7 +361,7 @@ database_config: Dict[str, Any] = {
"version": "1.0.0",
}
-database_with_ssh_tunnel_config_private_key: Dict[str, Any] = {
+database_with_ssh_tunnel_config_private_key: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -383,7 +383,7 @@ database_with_ssh_tunnel_config_private_key: Dict[str, Any] = {
"version": "1.0.0",
}
-database_with_ssh_tunnel_config_password: Dict[str, Any] = {
+database_with_ssh_tunnel_config_password: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -404,7 +404,7 @@ database_with_ssh_tunnel_config_password: Dict[str, Any] = {
"version": "1.0.0",
}
-database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = {
+database_with_ssh_tunnel_config_no_credentials: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -424,7 +424,7 @@ database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = {
"version": "1.0.0",
}
-database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = {
+database_with_ssh_tunnel_config_mix_credentials: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -446,7 +446,7 @@ database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = {
"version": "1.0.0",
}
-database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = {
+database_with_ssh_tunnel_config_private_pass_only: dict[str, Any] = {
"allow_csv_upload": True,
"allow_ctas": True,
"allow_cvas": True,
@@ -468,7 +468,7 @@ database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = {
}
-dataset_config: Dict[str, Any] = {
+dataset_config: dict[str, Any] = {
"table_name": "imported_dataset",
"main_dttm_col": None,
"description": "This is a dataset that was exported",
@@ -513,7 +513,7 @@ dataset_config: Dict[str, Any] = {
"database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89",
}
-chart_config: Dict[str, Any] = {
+chart_config: dict[str, Any] = {
"slice_name": "Deck Path",
"viz_type": "deck_path",
"params": {
@@ -557,7 +557,7 @@ chart_config: Dict[str, Any] = {
"dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb",
}
-dashboard_config: Dict[str, Any] = {
+dashboard_config: dict[str, Any] = {
"dashboard_title": "Test dash",
"description": None,
"css": "",
diff --git a/tests/integration_tests/fixtures/query_context.py b/tests/integration_tests/fixtures/query_context.py
index 00a3036e01..9efa589ba8 100644
--- a/tests/integration_tests/fixtures/query_context.py
+++ b/tests/integration_tests/fixtures/query_context.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from tests.common.query_context_generator import QueryContextGenerator
from tests.integration_tests.base_tests import SupersetTestCase
@@ -29,8 +29,8 @@ def get_query_context(
query_name: str,
add_postprocessing_operations: bool = False,
add_time_offsets: bool = False,
- form_data: Optional[Dict[str, Any]] = None,
-) -> Dict[str, Any]:
+ form_data: Optional[dict[str, Any]] = None,
+) -> dict[str, Any]:
"""
Create a request payload for retrieving a QueryContext object via the
`api/v1/chart/data` endpoint. By default returns a payload corresponding to one
diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py
index 561bbe10b2..18ceba9af2 100644
--- a/tests/integration_tests/fixtures/world_bank_dashboard.py
+++ b/tests/integration_tests/fixtures/world_bank_dashboard.py
@@ -17,7 +17,7 @@
import json
import string
from random import choice, randint, random, uniform
-from typing import Any, Dict, List
+from typing import Any
import pandas as pd
import pytest
@@ -94,7 +94,7 @@ def create_dashboard_for_loaded_data():
return dash_id_to_delete, slices_ids_to_delete
-def _create_world_bank_slices(table: SqlaTable) -> List[Slice]:
+def _create_world_bank_slices(table: SqlaTable) -> list[Slice]:
from superset.examples.world_bank import create_slices
slices = create_slices(table)
@@ -102,7 +102,7 @@ def _create_world_bank_slices(table: SqlaTable) -> List[Slice]:
return slices
-def _commit_slices(slices: List[Slice]):
+def _commit_slices(slices: list[Slice]):
for slice in slices:
o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none()
if o:
@@ -128,7 +128,7 @@ def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard:
return dash
-def _cleanup(dash_id: int, slices_ids: List[int]) -> None:
+def _cleanup(dash_id: int, slices_ids: list[int]) -> None:
dash = db.session.query(Dashboard).filter_by(id=dash_id).first()
db.session.delete(dash)
for slice_id in slices_ids:
@@ -148,7 +148,7 @@ def _get_dataframe(database: Database) -> DataFrame:
return df
-def _get_world_bank_data() -> List[Dict[Any, Any]]:
+def _get_world_bank_data() -> list[dict[Any, Any]]:
data = []
for _ in range(100):
data.append(
diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py
index 5bbc985a36..d44745377f 100644
--- a/tests/integration_tests/import_export_tests.py
+++ b/tests/integration_tests/import_export_tests.py
@@ -115,7 +115,7 @@ class TestImportExport(SupersetTestCase):
dashboard_title=title,
slices=slcs,
position_json='{"size_y": 2, "size_x": 2}',
- slug="{}_imported".format(title.lower()),
+ slug=f"{title.lower()}_imported",
json_metadata=json.dumps(json_metadata),
)
@@ -160,12 +160,12 @@ class TestImportExport(SupersetTestCase):
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
- set([c.column_name for c in expected_ds.columns]),
- set([c.column_name for c in actual_ds.columns]),
+ {c.column_name for c in expected_ds.columns},
+ {c.column_name for c in actual_ds.columns},
)
self.assertEqual(
- set([m.metric_name for m in expected_ds.metrics]),
- set([m.metric_name for m in actual_ds.metrics]),
+ {m.metric_name for m in expected_ds.metrics},
+ {m.metric_name for m in actual_ds.metrics},
)
def assert_datasource_equals(self, expected_ds, actual_ds):
@@ -174,12 +174,12 @@ class TestImportExport(SupersetTestCase):
self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics))
self.assertEqual(len(expected_ds.columns), len(actual_ds.columns))
self.assertEqual(
- set([c.column_name for c in expected_ds.columns]),
- set([c.column_name for c in actual_ds.columns]),
+ {c.column_name for c in expected_ds.columns},
+ {c.column_name for c in actual_ds.columns},
)
self.assertEqual(
- set([m.metric_name for m in expected_ds.metrics]),
- set([m.metric_name for m in actual_ds.metrics]),
+ {m.metric_name for m in expected_ds.metrics},
+ {m.metric_name for m in actual_ds.metrics},
)
def assert_slice_equals(self, expected_slc, actual_slc):
@@ -404,8 +404,8 @@ class TestImportExport(SupersetTestCase):
{
"remote_id": 10003,
"expanded_slices": {
- "{}".format(e_slc.id): True,
- "{}".format(b_slc.id): False,
+ f"{e_slc.id}": True,
+ f"{b_slc.id}": False,
},
# mocked filter_scope metadata
"filter_scopes": {
@@ -437,8 +437,8 @@ class TestImportExport(SupersetTestCase):
}
},
"expanded_slices": {
- "{}".format(i_e_slc.id): True,
- "{}".format(i_b_slc.id): False,
+ f"{i_e_slc.id}": True,
+ f"{i_b_slc.id}": False,
},
}
self.assertEqual(
diff --git a/tests/integration_tests/insert_chart_mixin.py b/tests/integration_tests/insert_chart_mixin.py
index da05d0c49d..722e387a54 100644
--- a/tests/integration_tests/insert_chart_mixin.py
+++ b/tests/integration_tests/insert_chart_mixin.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List, Optional
+from typing import Optional
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
@@ -29,7 +29,7 @@ class InsertChartMixin:
def insert_chart(
self,
slice_name: str,
- owners: List[int],
+ owners: list[int],
datasource_id: int,
created_by=None,
datasource_type: str = "table",
diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py
index 66aea8a4ed..ac33d003e0 100644
--- a/tests/integration_tests/key_value/commands/fixtures.py
+++ b/tests/integration_tests/key_value/commands/fixtures.py
@@ -18,7 +18,8 @@
from __future__ import annotations
import json
-from typing import Generator, TYPE_CHECKING
+from collections.abc import Generator
+from typing import TYPE_CHECKING
from uuid import UUID
import pytest
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index d5684b1b62..c4bc7aa89b 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -417,14 +417,14 @@ class TestSqlaTableModel(SupersetTestCase):
assert str(sqla_literal.compile()) == "ds"
sqla_literal = ds_col.get_timestamp_expression("P1D")
- compiled = "{}".format(sqla_literal.compile())
+ compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
assert compiled == "DATE(ds)"
prev_ds_expr = ds_col.expression
ds_col.expression = "DATE_ADD(ds, 1)"
sqla_literal = ds_col.get_timestamp_expression("P1D")
- compiled = "{}".format(sqla_literal.compile())
+ compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
assert compiled == "DATE(DATE_ADD(ds, 1))"
ds_col.expression = prev_ds_expr
@@ -437,20 +437,20 @@ class TestSqlaTableModel(SupersetTestCase):
ds_col.expression = None
ds_col.python_date_format = "epoch_s"
sqla_literal = ds_col.get_timestamp_expression(None)
- compiled = "{}".format(sqla_literal.compile())
+ compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "from_unixtime(ds)")
ds_col.python_date_format = "epoch_s"
sqla_literal = ds_col.get_timestamp_expression("P1D")
- compiled = "{}".format(sqla_literal.compile())
+ compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(from_unixtime(ds))")
prev_ds_expr = ds_col.expression
ds_col.expression = "DATE_ADD(ds, 1)"
sqla_literal = ds_col.get_timestamp_expression("P1D")
- compiled = "{}".format(sqla_literal.compile())
+ compiled = f"{sqla_literal.compile()}"
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))")
ds_col.expression = prev_ds_expr
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 5e5beae345..7a3d4e4a1e 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -16,7 +16,7 @@
# under the License.
import re
import time
-from typing import Any, Dict
+from typing import Any
import numpy as np
import pandas as pd
@@ -49,7 +49,7 @@ from tests.integration_tests.fixtures.birth_names_dashboard import (
from tests.integration_tests.fixtures.query_context import get_query_context
-def get_sql_text(payload: Dict[str, Any]) -> str:
+def get_sql_text(payload: dict[str, Any]) -> str:
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
diff --git a/tests/integration_tests/reports/alert_tests.py b/tests/integration_tests/reports/alert_tests.py
index 32cc2dcefb..4920a96283 100644
--- a/tests/integration_tests/reports/alert_tests.py
+++ b/tests/integration_tests/reports/alert_tests.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel
from contextlib import nullcontext
-from typing import List, Optional, Tuple, Union
+from typing import Optional, Union
import pandas as pd
import pytest
@@ -56,10 +56,10 @@ from tests.integration_tests.test_app import app
],
)
def test_execute_query_as_report_executor(
- owner_names: List[str],
+ owner_names: list[str],
creator_name: Optional[str],
- config: List[ExecutorType],
- expected_result: Union[Tuple[ExecutorType, str], Exception],
+ config: list[ExecutorType],
+ expected_result: Union[tuple[ExecutorType, str], Exception],
mocker: MockFixture,
app_context: None,
get_user,
diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py
index a81bc6fa66..db80079d77 100644
--- a/tests/integration_tests/reports/commands_tests.py
+++ b/tests/integration_tests/reports/commands_tests.py
@@ -17,7 +17,7 @@
import json
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
-from typing import List, Optional
+from typing import Optional
from unittest.mock import call, Mock, patch
from uuid import uuid4
@@ -100,7 +100,7 @@ pytestmark = pytest.mark.usefixtures(
)
-def get_target_from_report_schedule(report_schedule: ReportSchedule) -> List[str]:
+def get_target_from_report_schedule(report_schedule: ReportSchedule) -> list[str]:
return [
json.loads(recipient.recipient_config_json)["target"]
for recipient in report_schedule.recipients
@@ -1976,9 +1976,7 @@ def test__send_with_client_errors(notification_mock, logger_mock):
assert excinfo.errisinstance(SupersetException)
logger_mock.warning.assert_called_with(
- (
- "SupersetError(message='', error_type=, level=, extra=None)"
- )
+ "SupersetError(message='', error_type=, level=, extra=None)"
)
@@ -2021,7 +2019,5 @@ def test__send_with_server_errors(notification_mock, logger_mock):
assert excinfo.errisinstance(SupersetException)
# it logs the error
logger_mock.warning.assert_called_with(
- (
- "SupersetError(message='', error_type=, level=, extra=None)"
- )
+ "SupersetError(message='', error_type=, level=, extra=None)"
)
diff --git a/tests/integration_tests/reports/scheduler_tests.py b/tests/integration_tests/reports/scheduler_tests.py
index 4b8968592b..3284ee9772 100644
--- a/tests/integration_tests/reports/scheduler_tests.py
+++ b/tests/integration_tests/reports/scheduler_tests.py
@@ -16,7 +16,6 @@
# under the License.
from random import randint
-from typing import List
from unittest.mock import patch
import pytest
@@ -32,7 +31,7 @@ from tests.integration_tests.test_app import app
@pytest.fixture
-def owners(get_user) -> List[User]:
+def owners(get_user) -> list[User]:
return [get_user("admin")]
diff --git a/tests/integration_tests/reports/utils.py b/tests/integration_tests/reports/utils.py
index 3801beb1a3..7672c5c940 100644
--- a/tests/integration_tests/reports/utils.py
+++ b/tests/integration_tests/reports/utils.py
@@ -17,7 +17,7 @@
import json
from contextlib import contextmanager
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from uuid import uuid4
from flask_appbuilder.security.sqla.models import User
@@ -49,7 +49,7 @@ def insert_report_schedule(
type: str,
name: str,
crontab: str,
- owners: List[User],
+ owners: list[User],
timezone: Optional[str] = None,
sql: Optional[str] = None,
description: Optional[str] = None,
@@ -61,10 +61,10 @@ def insert_report_schedule(
log_retention: Optional[int] = None,
last_state: Optional[ReportState] = None,
grace_period: Optional[int] = None,
- recipients: Optional[List[ReportRecipients]] = None,
+ recipients: Optional[list[ReportRecipients]] = None,
report_format: Optional[ReportDataFormat] = None,
- logs: Optional[List[ReportExecutionLog]] = None,
- extra: Optional[Dict[Any, Any]] = None,
+ logs: Optional[list[ReportExecutionLog]] = None,
+ extra: Optional[dict[Any, Any]] = None,
force_screenshot: bool = False,
) -> ReportSchedule:
owners = owners or []
@@ -113,9 +113,9 @@ def create_report_notification(
grace_period: Optional[int] = None,
report_format: Optional[ReportDataFormat] = None,
name: Optional[str] = None,
- extra: Optional[Dict[str, Any]] = None,
+ extra: Optional[dict[str, Any]] = None,
force_screenshot: bool = False,
- owners: Optional[List[User]] = None,
+ owners: Optional[list[User]] = None,
) -> ReportSchedule:
if not owners:
owners = [
diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py
index a541f00952..ae89fea068 100644
--- a/tests/integration_tests/security/migrate_roles_tests.py
+++ b/tests/integration_tests/security/migrate_roles_tests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for alerting in Superset"""
-import json
import logging
from contextlib import contextmanager
from unittest.mock import patch
diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py
index 51aa76ee27..2a28089c3e 100644
--- a/tests/integration_tests/security/row_level_security_tests.py
+++ b/tests/integration_tests/security/row_level_security_tests.py
@@ -16,7 +16,7 @@
# under the License.
# isort:skip_file
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from unittest import mock
import pytest
@@ -55,7 +55,7 @@ class TestRowLevelSecurity(SupersetTestCase):
"""
rls_entry = None
- query_obj: Dict[str, Any] = dict(
+ query_obj: dict[str, Any] = dict(
groupby=[],
metrics=None,
filter=[],
@@ -542,8 +542,8 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
db_tables = db.session.query(SqlaTable).all()
- db_table_names = set([t.name for t in db_tables])
- received_tables = set([table["text"] for table in result])
+ db_table_names = {t.name for t in db_tables}
+ received_tables = {table["text"] for table in result}
assert data["count"] == len(db_tables)
assert len(result) == len(db_tables)
@@ -558,8 +558,8 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
- db_role_names = set([r.name for r in security_manager.get_all_roles()])
- received_roles = set([role["text"] for role in result])
+ db_role_names = {r.name for r in security_manager.get_all_roles()}
+ received_roles = {role["text"] for role in result}
assert data["count"] == len(db_role_names)
assert len(result) == len(db_role_names)
@@ -580,7 +580,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase):
self.assertEqual(rv.status_code, 200)
data = json.loads(rv.data.decode("utf-8"))
result = data["result"]
- received_tables = set([table["text"].split(".")[-1] for table in result])
+ received_tables = {table["text"].split(".")[-1] for table in result}
assert data["count"] == 1
assert len(result) == 1
@@ -615,7 +615,7 @@ RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)")
EMBEDDED_SUPERSET=True,
)
class GuestTokenRowLevelSecurityTests(SupersetTestCase):
- query_obj: Dict[str, Any] = dict(
+ query_obj: dict[str, Any] = dict(
groupby=[],
metrics=None,
filter=[],
@@ -633,7 +633,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase):
"clause": "name = 'Alice'",
}
- def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser:
+ def guest_user_with_rls(self, rules: Optional[list[Any]] = None) -> GuestUser:
if rules is None:
rules = [self.default_rls_rule()]
return security_manager.get_guest_user_from_token(
diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py
index a57d24c3e4..89aefdfd09 100644
--- a/tests/integration_tests/sql_lab/api_tests.py
+++ b/tests/integration_tests/sql_lab/api_tests.py
@@ -112,7 +112,7 @@ class TestSqlLabApi(SupersetTestCase):
@mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False)
def test_execute_required_params(self):
self.login()
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
data = {"client_id": client_id}
rv = self.client.post(
@@ -157,7 +157,7 @@ class TestSqlLabApi(SupersetTestCase):
core.results_backend.get.return_value = {}
self.login()
- client_id = "{}".format(random.getrandbits(64))[:10]
+ client_id = f"{random.getrandbits(64)}"[:10]
data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id}
rv = self.client.post(
diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py
index 3d505ee2f5..d76924a8fb 100644
--- a/tests/integration_tests/sql_lab/commands_tests.py
+++ b/tests/integration_tests/sql_lab/commands_tests.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from unittest import mock, skip
+from unittest import mock
from unittest.mock import Mock, patch
import pandas as pd
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index 4003913516..854a0c9be0 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -17,7 +17,8 @@
# isort:skip_file
import re
from datetime import datetime
-from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
+from typing import Any, NamedTuple, Optional, Union
+from re import Pattern
from unittest.mock import patch
import pytest
@@ -50,7 +51,7 @@ from tests.integration_tests.test_app import app
from .base_tests import SupersetTestCase
from .conftest import only_postgresql
-VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
+VIRTUAL_TABLE_INT_TYPES: dict[str, Pattern[str]] = {
"hive": re.compile(r"^INT_TYPE$"),
"mysql": re.compile("^LONGLONG$"),
"postgresql": re.compile(r"^INTEGER$"),
@@ -58,7 +59,7 @@ VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = {
"sqlite": re.compile(r"^INT$"),
}
-VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = {
+VIRTUAL_TABLE_STRING_TYPES: dict[str, Pattern[str]] = {
"hive": re.compile(r"^STRING_TYPE$"),
"mysql": re.compile(r"^VAR_STRING$"),
"postgresql": re.compile(r"^STRING$"),
@@ -70,8 +71,8 @@ VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = {
class FilterTestCase(NamedTuple):
column: str
operator: str
- value: Union[float, int, List[Any], str]
- expected: Union[str, List[str]]
+ value: Union[float, int, list[Any], str]
+ expected: Union[str, list[str]]
class TestDatabaseModel(SupersetTestCase):
@@ -101,7 +102,7 @@ class TestDatabaseModel(SupersetTestCase):
assert col.is_temporal is True
def test_db_column_types(self):
- test_cases: Dict[str, GenericDataType] = {
+ test_cases: dict[str, GenericDataType] = {
# string
"CHAR": GenericDataType.STRING,
"VARCHAR": GenericDataType.STRING,
@@ -291,7 +292,7 @@ class TestDatabaseModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_where_operators(self):
- filters: Tuple[FilterTestCase, ...] = (
+ filters: tuple[FilterTestCase, ...] = (
FilterTestCase("num", FilterOperator.IS_NULL, "", "IS NULL"),
FilterTestCase("num", FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"),
# Some db backends translate true/false to 1/0
@@ -493,7 +494,7 @@ class TestDatabaseModel(SupersetTestCase):
"mycase",
"expr",
}
- cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns}
+ cols: dict[str, TableColumn] = {col.column_name: col for col in table.columns}
# assert that the type for intcol has been updated (asserting CI types)
backend = table.database.backend
assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type)
@@ -802,7 +803,7 @@ def test__normalize_prequery_result_type(
result: Any,
) -> None:
def _convert_dttm(
- target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
+ target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None
) -> Optional[str]:
if target_type.upper() == "TIMESTAMP":
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index 16cc16d264..e9892b1d36 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -359,7 +359,7 @@ class TestSqlLab(SupersetTestCase):
db.session.commit()
data = self.get_json_resp(
- "/superset/queries/{}".format(float(datetime_to_epoch(now)) - 1000)
+ f"/superset/queries/{float(datetime_to_epoch(now)) - 1000}"
)
self.assertEqual(1, len(data))
@@ -391,13 +391,13 @@ class TestSqlLab(SupersetTestCase):
# Test search queries on user Id
user_id = security_manager.find_user("admin").id
- data = self.get_json_resp("/superset/search_queries?user_id={}".format(user_id))
+ data = self.get_json_resp(f"/superset/search_queries?user_id={user_id}")
self.assertEqual(2, len(data))
user_ids = {k["userId"] for k in data}
- self.assertEqual(set([user_id]), user_ids)
+ self.assertEqual({user_id}, user_ids)
user_id = security_manager.find_user("gamma_sqllab").id
- resp = self.get_resp("/superset/search_queries?user_id={}".format(user_id))
+ resp = self.get_resp(f"/superset/search_queries?user_id={user_id}")
data = json.loads(resp)
self.assertEqual(1, len(data))
self.assertEqual(data[0]["userId"], user_id)
@@ -451,7 +451,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(1, len(data))
user_ids = {k["userId"] for k in data}
- self.assertEqual(set([user_id]), user_ids)
+ self.assertEqual({user_id}, user_ids)
def test_alias_duplicate(self):
self.run_sql(
@@ -593,7 +593,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(len(data["data"]), test_limit)
data = self.run_sql(
- "SELECT * FROM birth_names LIMIT {}".format(test_limit),
+ f"SELECT * FROM birth_names LIMIT {test_limit}",
client_id="sql_limit_3",
query_limit=test_limit + 1,
)
@@ -601,7 +601,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY)
data = self.run_sql(
- "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1),
+ f"SELECT * FROM birth_names LIMIT {test_limit + 1}",
client_id="sql_limit_4",
query_limit=test_limit,
)
@@ -609,7 +609,7 @@ class TestSqlLab(SupersetTestCase):
self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN)
data = self.run_sql(
- "SELECT * FROM birth_names LIMIT {}".format(test_limit),
+ f"SELECT * FROM birth_names LIMIT {test_limit}",
client_id="sql_limit_5",
query_limit=test_limit,
)
diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py
index e54ae865e3..f6d664c649 100644
--- a/tests/integration_tests/strategy_tests.py
+++ b/tests/integration_tests/strategy_tests.py
@@ -16,8 +16,6 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset cache warmup"""
-import datetime
-import json
from unittest.mock import MagicMock
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py
index c3f9b350f8..77e007a2dd 100644
--- a/tests/integration_tests/superset_test_config.py
+++ b/tests/integration_tests/superset_test_config.py
@@ -130,7 +130,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
ALERT_REPORTS_QUERY_EXECUTION_MAX_TRIES = 3
-class CeleryConfig(object):
+class CeleryConfig:
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
CELERY_IMPORTS = ("superset.sql_lab",)
CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"
diff --git a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py
index 9f6dd2ead1..31d14ef71b 100644
--- a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py
+++ b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py
@@ -16,8 +16,6 @@
# under the License.
# flake8: noqa
# type: ignore
-import os
-from copy import copy
from .superset_test_config import *
diff --git a/tests/integration_tests/superset_test_config_thumbnails.py b/tests/integration_tests/superset_test_config_thumbnails.py
index 9f621efabb..5bd02e7b0f 100644
--- a/tests/integration_tests/superset_test_config_thumbnails.py
+++ b/tests/integration_tests/superset_test_config_thumbnails.py
@@ -61,7 +61,7 @@ REDIS_CELERY_DB = os.environ.get("REDIS_CELERY_DB", 2)
REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3)
-class CeleryConfig(object):
+class CeleryConfig:
BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks.thumbnails")
CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}}
diff --git a/tests/integration_tests/tagging_tests.py b/tests/integration_tests/tagging_tests.py
index 71fb7e4e4e..72ba577d9f 100644
--- a/tests/integration_tests/tagging_tests.py
+++ b/tests/integration_tests/tagging_tests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-from unittest import mock
import pytest
diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py
index 7bf21da4fc..b047388a68 100644
--- a/tests/integration_tests/tags/api_tests.py
+++ b/tests/integration_tests/tags/api_tests.py
@@ -16,10 +16,7 @@
# under the License.
# isort:skip_file
"""Unit tests for Superset"""
-from datetime import datetime, timedelta
import json
-import random
-import string
import pytest
import prison
diff --git a/tests/integration_tests/tags/commands_tests.py b/tests/integration_tests/tags/commands_tests.py
index 8f44d2ebda..cd5a024840 100644
--- a/tests/integration_tests/tags/commands_tests.py
+++ b/tests/integration_tests/tags/commands_tests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import itertools
-import json
from unittest.mock import MagicMock, patch
import pytest
diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py
index f46abaa723..49b22d260b 100644
--- a/tests/integration_tests/tags/dao_tests.py
+++ b/tests/integration_tests/tags/dao_tests.py
@@ -15,10 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
-import copy
-import json
from operator import and_
-import time
from unittest.mock import patch
import pytest
from superset.dao.exceptions import DAOCreateFailedError, DAOException
diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py
index 228da6de79..eb2be859ba 100644
--- a/tests/integration_tests/thumbnails_tests.py
+++ b/tests/integration_tests/thumbnails_tests.py
@@ -20,7 +20,6 @@
import json
import urllib.request
from io import BytesIO
-from typing import Tuple
from unittest import skipUnless
from unittest.mock import ANY, call, MagicMock, patch
@@ -203,7 +202,7 @@ class TestThumbnails(SupersetTestCase):
digest_return_value = "foo_bar"
digest_hash = "5c7d96a3dd7a87850a2ef34087565a6e"
- def _get_id_and_thumbnail_url(self, url: str) -> Tuple[int, str]:
+ def _get_id_and_thumbnail_url(self, url: str) -> tuple[int, str]:
rv = self.client.get(url)
resp = json.loads(rv.data.decode("utf-8"))
obj = resp["result"][0]
diff --git a/tests/integration_tests/users/__init__.py b/tests/integration_tests/users/__init__.py
index fd9417fe5c..13a83393a9 100644
--- a/tests/integration_tests/users/__init__.py
+++ b/tests/integration_tests/users/__init__.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/tests/integration_tests/utils/csv_tests.py b/tests/integration_tests/utils/csv_tests.py
index e514efb1d2..38c1dd51ac 100644
--- a/tests/integration_tests/utils/csv_tests.py
+++ b/tests/integration_tests/utils/csv_tests.py
@@ -43,13 +43,13 @@ def test_escape_value():
assert result == "'=value"
result = csv.escape_value("|value")
- assert result == "'\|value"
+ assert result == r"'\|value"
result = csv.escape_value("%value")
assert result == "'%value"
result = csv.escape_value("=cmd|' /C calc'!A0")
- assert result == "'=cmd\|' /C calc'!A0"
+ assert result == r"'=cmd\|' /C calc'!A0"
result = csv.escape_value('""=10+2')
assert result == '\'""=10+2'
@@ -74,7 +74,7 @@ def test_df_to_escaped_csv():
assert escaped_csv_rows == [
["col_a", "'=func()"],
- ["-10", "'=cmd\|' /C calc'!A0"],
+ ["-10", r"'=cmd\|' /C calc'!A0"],
["a", "'=b"], # pandas seems to be removing the leading ""
["' =a", "b"],
]
diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py
index 2199783529..45fd291ee8 100644
--- a/tests/integration_tests/utils/encrypt_tests.py
+++ b/tests/integration_tests/utils/encrypt_tests.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
from sqlalchemy import String, TypeDecorator
from sqlalchemy_utils import EncryptedType
@@ -28,9 +28,9 @@ from tests.integration_tests.base_tests import SupersetTestCase
class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter):
def create(
self,
- app_config: Optional[Dict[str, Any]],
- *args: List[Any],
- **kwargs: Optional[Dict[str, Any]]
+ app_config: Optional[dict[str, Any]],
+ *args: list[Any],
+ **kwargs: Optional[dict[str, Any]]
) -> TypeDecorator:
if app_config:
return StringEncryptedType(*args, app_config["SECRET_KEY"], **kwargs)
diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py
index 03260fb94d..7012bf08a0 100644
--- a/tests/integration_tests/utils/get_dashboards.py
+++ b/tests/integration_tests/utils/get_dashboards.py
@@ -14,14 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import List
from flask_appbuilder import SQLA
from superset.models.dashboard import Dashboard
-def get_dashboards_ids(db: SQLA, dashboard_slugs: List[str]) -> List[int]:
+def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]:
result = (
db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all()
)
diff --git a/tests/integration_tests/utils/public_interfaces_test.py b/tests/integration_tests/utils/public_interfaces_test.py
index 7b5d671246..af67bb6ca3 100644
--- a/tests/integration_tests/utils/public_interfaces_test.py
+++ b/tests/integration_tests/utils/public_interfaces_test.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Callable, Dict
+from typing import Any, Callable
import pytest
@@ -23,7 +23,7 @@ from superset.utils.public_interfaces import compute_hash, get_warning_message
# These are public interfaces exposed by Superset. Make sure
# to only change the interfaces and update the hashes in new
# major versions of Superset.
-hashes: Dict[Callable[..., Any], str] = {}
+hashes: dict[Callable[..., Any], str] = {}
@pytest.mark.parametrize("interface,expected_hash", list(hashes.items()))
diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py
index 2db008fdb7..b4d750c8d0 100644
--- a/tests/integration_tests/utils_tests.py
+++ b/tests/integration_tests/utils_tests.py
@@ -21,7 +21,7 @@ from decimal import Decimal
import json
import os
import re
-from typing import Any, Tuple, List, Optional
+from typing import Any, Optional
from unittest.mock import Mock, patch
from superset.databases.commands.exceptions import DatabaseInvalidError
@@ -121,12 +121,12 @@ class TestUtils(SupersetTestCase):
assert isinstance(base_json_conv(np.int64(1)), int)
assert isinstance(base_json_conv(np.array([1, 2, 3])), list)
assert base_json_conv(np.array(None)) is None
- assert isinstance(base_json_conv(set([1])), list)
+ assert isinstance(base_json_conv({1}), list)
assert isinstance(base_json_conv(Decimal("1.0")), float)
assert isinstance(base_json_conv(uuid.uuid4()), str)
assert isinstance(base_json_conv(time()), str)
assert isinstance(base_json_conv(timedelta(0)), str)
- assert isinstance(base_json_conv(bytes()), str)
+ assert isinstance(base_json_conv(b""), str)
assert base_json_conv(bytes("", encoding="utf-16")) == "[bytes]"
with pytest.raises(TypeError):
@@ -1054,7 +1054,7 @@ class TestUtils(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_extract_dataframe_dtypes(self):
slc = self.get_slice("Girls", db.session)
- cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
+ cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = (
("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
(
"dttm",
diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py
index 137e2a474c..30d4d1e183 100644
--- a/tests/integration_tests/viz_tests.py
+++ b/tests/integration_tests/viz_tests.py
@@ -19,7 +19,6 @@ from datetime import date, datetime, timezone
import logging
from math import nan
from unittest.mock import Mock, patch
-from typing import Any, Dict, List, Set
import numpy as np
import pandas as pd
@@ -1009,7 +1008,7 @@ class TestTimeSeriesTableViz(SupersetTestCase):
test_viz = viz.TimeTableViz(datasource, form_data)
data = test_viz.get_data(df)
# Check method correctly transforms data
- self.assertEqual(set(["count", "sum__A"]), set(data["columns"]))
+ self.assertEqual({"count", "sum__A"}, set(data["columns"]))
time_format = "%Y-%m-%d %H:%M:%S"
expected = {
t1.strftime(time_format): {"sum__A": 15, "count": 6},
@@ -1030,7 +1029,7 @@ class TestTimeSeriesTableViz(SupersetTestCase):
test_viz = viz.TimeTableViz(datasource, form_data)
data = test_viz.get_data(df)
# Check method correctly transforms data
- self.assertEqual(set(["a1", "a2", "a3"]), set(data["columns"]))
+ self.assertEqual({"a1", "a2", "a3"}, set(data["columns"]))
time_format = "%Y-%m-%d %H:%M:%S"
expected = {
t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25},
diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py
index 72ae9dbba7..b1d5cc6488 100644
--- a/tests/unit_tests/charts/dao/dao_tests.py
+++ b/tests/unit_tests/charts/dao/dao_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py
index 84496bf1cf..b7cdda6e68 100644
--- a/tests/unit_tests/charts/test_post_processing.py
+++ b/tests/unit_tests/charts/test_post_processing.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-import json
import pandas as pd
import pytest
diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py
index 4fd906f648..02304828dc 100644
--- a/tests/unit_tests/common/test_query_object_factory.py
+++ b/tests/unit_tests/common/test_query_object_factory.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from unittest.mock import Mock, patch
from pytest import fixture, mark
@@ -23,7 +23,7 @@ from superset.common.query_object_factory import QueryObjectFactory
from tests.common.query_context_generator import QueryContextGenerator
-def create_app_config() -> Dict[str, Any]:
+def create_app_config() -> dict[str, Any]:
return {
"ROW_LIMIT": 5000,
"DEFAULT_RELATIVE_START_TIME": "today",
@@ -34,7 +34,7 @@ def create_app_config() -> Dict[str, Any]:
@fixture
-def app_config() -> Dict[str, Any]:
+def app_config() -> dict[str, Any]:
return create_app_config().copy()
@@ -58,7 +58,7 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int:
@fixture
def query_object_factory(
- app_config: Dict[str, Any], connector_registry: Mock, session_factory: Mock
+ app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock
) -> QueryObjectFactory:
import superset.common.query_object_factory as mod
@@ -67,7 +67,7 @@ def query_object_factory(
@fixture
-def raw_query_context() -> Dict[str, Any]:
+def raw_query_context() -> dict[str, Any]:
return QueryContextGenerator().generate("birth_names")
@@ -75,7 +75,7 @@ class TestQueryObjectFactory:
def test_query_context_limit_and_offset_defaults(
self,
query_object_factory: QueryObjectFactory,
- raw_query_context: Dict[str, Any],
+ raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object.pop("row_limit", None)
@@ -89,7 +89,7 @@ class TestQueryObjectFactory:
def test_query_context_limit(
self,
query_object_factory: QueryObjectFactory,
- raw_query_context: Dict[str, Any],
+ raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["row_limit"] = 100
@@ -104,7 +104,7 @@ class TestQueryObjectFactory:
def test_query_context_null_post_processing_op(
self,
query_object_factory: QueryObjectFactory,
- raw_query_context: Dict[str, Any],
+ raw_query_context: dict[str, Any],
):
raw_query_object = raw_query_context["queries"][0]
raw_query_object["post_processing"] = [None]
diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py
index 021193a6cd..4a62f26e6f 100644
--- a/tests/unit_tests/config_test.py
+++ b/tests/unit_tests/config_test.py
@@ -17,7 +17,7 @@
# pylint: disable=import-outside-toplevel, unused-argument, redefined-outer-name, invalid-name
from functools import partial
-from typing import Any, Dict, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import pytest
from pytest_mock import MockerFixture
@@ -44,7 +44,7 @@ FULL_DTTM_DEFAULTS_EXAMPLE = {
}
-def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: Dict[str, Any]) -> None:
+def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: dict[str, Any]) -> None:
"""Applies dttm defaults to the table, mutates in place."""
for dbcol in table.columns:
# Set is_dttm is column is listed in dttm_columns.
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index 6740a8b6e2..6a4f1e550c 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -19,7 +19,8 @@
import importlib
import os
import unittest.mock
-from typing import Any, Callable, Iterator
+from collections.abc import Iterator
+from typing import Any, Callable
import pytest
from _pytest.fixtures import SubRequest
diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py
index 62eeff3106..d0ab3ec8a5 100644
--- a/tests/unit_tests/dao/queries_test.py
+++ b/tests/unit_tests/dao/queries_test.py
@@ -14,9 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-import json
from datetime import datetime, timedelta
-from typing import Any, Iterator
+from typing import Any
import pytest
from pytest_mock import MockFixture
diff --git a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py
index 0392acb315..60a659159a 100644
--- a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py
+++ b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py
@@ -16,7 +16,7 @@
# under the License.
# pylint: disable=import-outside-toplevel, unused-argument
-from typing import Any, Dict
+from typing import Any
def test_update_id_refs_immune_missing( # pylint: disable=invalid-name
@@ -59,7 +59,7 @@ def test_update_id_refs_immune_missing( # pylint: disable=invalid-name
},
}
chart_ids = {"uuid1": 1, "uuid2": 2}
- dataset_info: Dict[str, Dict[str, Any]] = {} # not used
+ dataset_info: dict[str, dict[str, Any]] = {} # not used
fixed = update_id_refs(config, chart_ids, dataset_info)
assert fixed == {
@@ -103,7 +103,7 @@ def test_update_native_filter_config_scope_excluded():
},
}
chart_ids = {"uuid1": 1, "uuid2": 2}
- dataset_info: Dict[str, Dict[str, Any]] = {} # not used
+ dataset_info: dict[str, dict[str, Any]] = {} # not used
fixed = update_id_refs(config, chart_ids, dataset_info)
assert fixed == {
diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py
index a8f93e7513..c94d2ab157 100644
--- a/tests/unit_tests/dashboards/dao_tests.py
+++ b/tests/unit_tests/dashboards/dao_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py
index 47db402670..f085cb53c7 100644
--- a/tests/unit_tests/databases/dao/dao_tests.py
+++ b/tests/unit_tests/databases/dao/dao_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
index 2a5738ebd3..fbad104c1d 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
index b5adf765fa..de0b70db9c 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from pytest_mock import MockFixture
diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
index 58f90054cc..544cf3434a 100644
--- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
+++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
index ae5b6e9bd3..27f9c3b8ad 100644
--- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
+++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py
@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
index 839374425b..9e8690e6e3 100644
--- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py
+++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py
@@ -20,7 +20,7 @@ import copy
import json
import re
import uuid
-from typing import Any, Dict
+from typing import Any
from unittest.mock import Mock, patch
import pytest
@@ -296,7 +296,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) ->
session.flush()
dataset_uuid = uuid.uuid4()
- yaml_config: Dict[str, Any] = {
+ yaml_config: dict[str, Any] = {
"version": "1.0.0",
"table_name": "my_table",
"main_dttm_col": "ds",
@@ -388,7 +388,7 @@ def test_import_column_allowed_data_url(
session.flush()
dataset_uuid = uuid.uuid4()
- yaml_config: Dict[str, Any] = {
+ yaml_config: dict[str, Any] = {
"version": "1.0.0",
"table_name": "my_table",
"main_dttm_col": "ds",
diff --git a/tests/unit_tests/datasets/conftest.py b/tests/unit_tests/datasets/conftest.py
index 8d217ae27a..8bef6945a6 100644
--- a/tests/unit_tests/datasets/conftest.py
+++ b/tests/unit_tests/datasets/conftest.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
import pytest
@@ -23,7 +23,7 @@ if TYPE_CHECKING:
@pytest.fixture
-def columns_default() -> Dict[str, Any]:
+def columns_default() -> dict[str, Any]:
"""Default props for new columns"""
return {
"changed_by": 1,
@@ -49,7 +49,7 @@ def columns_default() -> Dict[str, Any]:
@pytest.fixture
-def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
+def sample_columns() -> dict["TableColumn", dict[str, Any]]:
from superset.connectors.sqla.models import TableColumn
return {
@@ -93,7 +93,7 @@ def sample_columns() -> Dict["TableColumn", Dict[str, Any]]:
@pytest.fixture
-def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]:
+def sample_metrics() -> dict["SqlMetric", dict[str, Any]]:
from superset.connectors.sqla.models import SqlMetric
return {
diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py
index 350425d08e..4eb43cd9de 100644
--- a/tests/unit_tests/datasets/dao/dao_tests.py
+++ b/tests/unit_tests/datasets/dao/dao_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py
index 16334066d7..99a4850301 100644
--- a/tests/unit_tests/datasource/dao_tests.py
+++ b/tests/unit_tests/datasource/dao_tests.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Iterator
+from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py
index 51ec6656aa..f0811a3e14 100644
--- a/tests/unit_tests/db_engine_specs/test_athena.py
+++ b/tests/unit_tests/db_engine_specs/test_athena.py
@@ -81,7 +81,7 @@ def test_get_text_clause_with_colon() -> None:
from superset.db_engine_specs.athena import AthenaEngineSpec
query = (
- "SELECT foo FROM tbl WHERE " "abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'"
+ "SELECT foo FROM tbl WHERE " r"abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'"
)
text_clause = AthenaEngineSpec.get_text_clause(query)
assert text_clause.text == query
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index 868a6bbdc3..33083f0399 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -17,7 +17,7 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
from textwrap import dedent
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
import pytest
from sqlalchemy import types
@@ -130,8 +130,8 @@ def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None:
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py
index 0c437bc009..6dfeddaf37 100644
--- a/tests/unit_tests/db_engine_specs/test_clickhouse.py
+++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py
@@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
from unittest.mock import Mock
import pytest
@@ -189,8 +189,8 @@ def test_connect_convert_dttm(
)
def test_connect_get_column_spec(
native_type: str,
- sqla_type: Type[TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py
index de55c63426..0c15977669 100644
--- a/tests/unit_tests/db_engine_specs/test_elasticsearch.py
+++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional
+from typing import Any, Optional
from unittest.mock import MagicMock
import pytest
@@ -49,7 +49,7 @@ from tests.unit_tests.fixtures.common import dttm
)
def test_elasticsearch_convert_dttm(
target_type: str,
- db_extra: Optional[Dict[str, Any]],
+ db_extra: Optional[dict[str, Any]],
expected_result: Optional[str],
dttm: datetime,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py
index acd35a4ecf..673f4817be 100644
--- a/tests/unit_tests/db_engine_specs/test_mssql.py
+++ b/tests/unit_tests/db_engine_specs/test_mssql.py
@@ -17,7 +17,7 @@
import unittest.mock as mock
from datetime import datetime
from textwrap import dedent
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
import pytest
from sqlalchemy import column, table
@@ -50,8 +50,8 @@ from tests.unit_tests.fixtures.common import dttm
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py
index 07ce6838fc..89abf2321d 100644
--- a/tests/unit_tests/db_engine_specs/test_mysql.py
+++ b/tests/unit_tests/db_engine_specs/test_mysql.py
@@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Tuple, Type
+from typing import Any, Optional
from unittest.mock import Mock, patch
import pytest
@@ -71,8 +71,8 @@ from tests.unit_tests.fixtures.common import dttm
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
@@ -166,7 +166,7 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None:
],
)
def test_adjust_engine_params(
- sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any]
+ sqlalchemy_uri: str, connect_args: dict[str, Any], returns: dict[str, Any]
) -> None:
from superset.db_engine_specs.mysql import MySQLEngineSpec
diff --git a/tests/unit_tests/db_engine_specs/test_ocient.py b/tests/unit_tests/db_engine_specs/test_ocient.py
index af9fd2ad16..a58f31d242 100644
--- a/tests/unit_tests/db_engine_specs/test_ocient.py
+++ b/tests/unit_tests/db_engine_specs/test_ocient.py
@@ -17,7 +17,7 @@
# pylint: disable=import-outside-toplevel
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable
import pytest
@@ -33,7 +33,7 @@ def ocient_is_installed() -> bool:
# (msg,expected)
-MARSHALED_OCIENT_ERRORS: List[Tuple[str, SupersetError]] = [
+MARSHALED_OCIENT_ERRORS: list[tuple[str, SupersetError]] = [
(
"The referenced user does not exist (User 'mj' not found)",
SupersetError(
@@ -224,7 +224,7 @@ def test_connection_errors(msg: str, expected: SupersetError) -> None:
def _generate_gis_type_sanitization_test_cases() -> (
- List[Tuple[str, int, Any, Dict[str, Any]]]
+ list[tuple[str, int, Any, dict[str, Any]]]
):
if not ocient_is_installed():
return []
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index fef8647959..145d398898 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -16,7 +16,7 @@
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
import pytest
from sqlalchemy import types
@@ -82,8 +82,8 @@ def test_convert_dttm(
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py
index df2ed58c37..7739361cf3 100644
--- a/tests/unit_tests/db_engine_specs/test_presto.py
+++ b/tests/unit_tests/db_engine_specs/test_presto.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
from unittest import mock
import pytest
@@ -77,8 +77,8 @@ def test_convert_dttm(
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py
index 7812a16830..ac246e3d5b 100644
--- a/tests/unit_tests/db_engine_specs/test_starrocks.py
+++ b/tests/unit_tests/db_engine_specs/test_starrocks.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
import pytest
from sqlalchemy import types
@@ -45,8 +45,8 @@ from tests.unit_tests.db_engine_specs.utils import assert_column_spec
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
@@ -74,9 +74,9 @@ def test_get_column_spec(
)
def test_adjust_engine_params(
sqlalchemy_uri: str,
- connect_args: Dict[str, Any],
+ connect_args: dict[str, Any],
return_schema: str,
- return_connect_args: Dict[str, Any],
+ return_connect_args: dict[str, Any],
) -> None:
from superset.db_engine_specs.starrocks import StarRocksEngineSpec
diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py
index 0ea296a075..963953d18b 100644
--- a/tests/unit_tests/db_engine_specs/test_trino.py
+++ b/tests/unit_tests/db_engine_specs/test_trino.py
@@ -17,7 +17,7 @@
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import json
from datetime import datetime
-from typing import Any, Dict, Optional, Type
+from typing import Any, Optional
from unittest.mock import Mock, patch
import pandas as pd
@@ -57,7 +57,7 @@ from tests.unit_tests.fixtures.common import dttm
),
],
)
-def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None:
+def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
database = Mock()
@@ -95,7 +95,7 @@ def test_auth_basic(mock_auth: Mock) -> None:
{"auth_method": "basic", "auth_params": auth_params}
)
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
@@ -117,7 +117,7 @@ def test_auth_kerberos(mock_auth: Mock) -> None:
{"auth_method": "kerberos", "auth_params": auth_params}
)
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
@@ -134,7 +134,7 @@ def test_auth_certificate(mock_auth: Mock) -> None:
{"auth_method": "certificate", "auth_params": auth_params}
)
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
@@ -152,7 +152,7 @@ def test_auth_jwt(mock_auth: Mock) -> None:
{"auth_method": "jwt", "auth_params": auth_params}
)
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
assert connect_args.get("http_scheme") == "https"
@@ -176,7 +176,7 @@ def test_auth_custom_auth() -> None:
{"trino": {"custom_auth": auth_class}},
clear=True,
):
- params: Dict[str, Any] = {}
+ params: dict[str, Any] = {}
TrinoEngineSpec.update_params_from_encrypted_extra(database, params)
connect_args = params.setdefault("connect_args", {})
@@ -243,8 +243,8 @@ def test_auth_custom_auth_denied() -> None:
)
def test_get_column_spec(
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: Optional[dict[str, Any]],
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
@@ -324,8 +324,8 @@ def test_cancel_query_failed(engine_mock: Mock) -> None:
],
)
def test_prepare_cancel_query(
- initial_extra: Dict[str, Any],
- final_extra: Dict[str, Any],
+ initial_extra: dict[str, Any],
+ final_extra: dict[str, Any],
mocker: MockerFixture,
) -> None:
from superset.db_engine_specs.trino import TrinoEngineSpec
diff --git a/tests/unit_tests/db_engine_specs/utils.py b/tests/unit_tests/db_engine_specs/utils.py
index 13ae7a34d2..774ca3eaf2 100644
--- a/tests/unit_tests/db_engine_specs/utils.py
+++ b/tests/unit_tests/db_engine_specs/utils.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
-from typing import Any, Dict, Optional, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
from sqlalchemy import types
@@ -28,11 +28,11 @@ if TYPE_CHECKING:
def assert_convert_dttm(
- db_engine_spec: Type[BaseEngineSpec],
+ db_engine_spec: type[BaseEngineSpec],
target_type: str,
- expected_result: Optional[str],
+ expected_result: str | None,
dttm: datetime,
- db_extra: Optional[Dict[str, Any]] = None,
+ db_extra: dict[str, Any] | None = None,
) -> None:
for target in (
target_type,
@@ -50,10 +50,10 @@ def assert_convert_dttm(
def assert_column_spec(
- db_engine_spec: Type[BaseEngineSpec],
+ db_engine_spec: type[BaseEngineSpec],
native_type: str,
- sqla_type: Type[types.TypeEngine],
- attrs: Optional[Dict[str, Any]],
+ sqla_type: type[types.TypeEngine],
+ attrs: dict[str, Any] | None,
generic_type: GenericDataType,
is_dttm: bool,
) -> None:
diff --git a/tests/unit_tests/extensions/ssh_test.py b/tests/unit_tests/extensions/ssh_test.py
index 0e997729d9..4538d71969 100644
--- a/tests/unit_tests/extensions/ssh_test.py
+++ b/tests/unit_tests/extensions/ssh_test.py
@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any
from unittest.mock import Mock, patch
import pytest
diff --git a/tests/unit_tests/fixtures/assets_configs.py b/tests/unit_tests/fixtures/assets_configs.py
index 73bc5921ec..bda84c1335 100644
--- a/tests/unit_tests/fixtures/assets_configs.py
+++ b/tests/unit_tests/fixtures/assets_configs.py
@@ -14,9 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
-databases_config: Dict[str, Any] = {
+databases_config: dict[str, Any] = {
"databases/examples.yaml": {
"database_name": "examples",
"sqlalchemy_uri": "sqlite:///test.db",
@@ -32,7 +32,7 @@ databases_config: Dict[str, Any] = {
"allow_csv_upload": False,
},
}
-datasets_config: Dict[str, Any] = {
+datasets_config: dict[str, Any] = {
"datasets/examples/video_game_sales.yaml": {
"table_name": "video_game_sales",
"main_dttm_col": None,
@@ -80,7 +80,7 @@ datasets_config: Dict[str, Any] = {
"database_uuid": "a2dc77af-e654-49bb-b321-40f6b559a1ee",
},
}
-charts_config_1: Dict[str, Any] = {
+charts_config_1: dict[str, Any] = {
"charts/Games_per_Genre_over_time_95.yaml": {
"slice_name": "Games per Genre over time",
"viz_type": "line",
@@ -100,7 +100,7 @@ charts_config_1: Dict[str, Any] = {
"dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283",
},
}
-dashboards_config_1: Dict[str, Any] = {
+dashboards_config_1: dict[str, Any] = {
"dashboards/Video_Game_Sales_11.yaml": {
"dashboard_title": "Video Game Sales",
"description": None,
@@ -182,7 +182,7 @@ dashboards_config_1: Dict[str, Any] = {
},
}
-charts_config_2: Dict[str, Any] = {
+charts_config_2: dict[str, Any] = {
"charts/Games_per_Genre_131.yaml": {
"slice_name": "Games per Genre",
"viz_type": "treemap",
@@ -193,7 +193,7 @@ charts_config_2: Dict[str, Any] = {
"dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283",
},
}
-dashboards_config_2: Dict[str, Any] = {
+dashboards_config_2: dict[str, Any] = {
"dashboards/Video_Game_Sales_11.yaml": {
"dashboard_title": "Video Game Sales",
"description": None,
diff --git a/tests/unit_tests/fixtures/datasets.py b/tests/unit_tests/fixtures/datasets.py
index 5d5466a5e8..7bddae0b81 100644
--- a/tests/unit_tests/fixtures/datasets.py
+++ b/tests/unit_tests/fixtures/datasets.py
@@ -14,11 +14,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Any, Dict
+from typing import Any
from unittest.mock import Mock
-def get_column_mock(params: Dict[str, Any]) -> Mock:
+def get_column_mock(params: dict[str, Any]) -> Mock:
mock = Mock()
mock.id = params["id"]
mock.column_name = params["column_name"]
@@ -32,7 +32,7 @@ def get_column_mock(params: Dict[str, Any]) -> Mock:
return mock
-def get_metric_mock(params: Dict[str, Any]) -> Mock:
+def get_metric_mock(params: dict[str, Any]) -> Mock:
mock = Mock()
mock.id = params["id"]
mock.metric_name = params["metric_name"]
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index bf8f589913..d37296447a 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -18,7 +18,7 @@
# pylint: disable=import-outside-toplevel
import json
from datetime import datetime
-from typing import List, Optional
+from typing import Optional
import pytest
from pytest_mock import MockFixture
@@ -54,7 +54,7 @@ def test_get_metrics(mocker: MockFixture) -> None:
inspector: Inspector,
table_name: str,
schema: Optional[str],
- ) -> List[MetricType]:
+ ) -> list[MetricType]:
return [
{
"expression": "COUNT(DISTINCT user_id)",
diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py
index 07366b1577..fa9fa30d36 100644
--- a/tests/unit_tests/pandas_postprocessing/utils.py
+++ b/tests/unit_tests/pandas_postprocessing/utils.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import math
-from typing import Any, List, Optional
+from typing import Any, Optional
from pandas import Series
@@ -26,7 +26,7 @@ AGGREGATES_MULTIPLE = {
}
-def series_to_list(series: Series) -> List[Any]:
+def series_to_list(series: Series) -> list[Any]:
"""
Converts a `Series` to a regular list, and replaces non-numeric values to
Nones.
@@ -43,8 +43,8 @@ def series_to_list(series: Series) -> List[Any]:
def round_floats(
- floats: List[Optional[float]], precision: int
-) -> List[Optional[float]]:
+ floats: list[Optional[float]], precision: int
+) -> list[Optional[float]]:
"""
Round list of floats to certain precision
diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py
index cfe6e213b2..e00dc3166e 100644
--- a/tests/unit_tests/sql_parse_tests.py
+++ b/tests/unit_tests/sql_parse_tests.py
@@ -16,8 +16,7 @@
# under the License.
# pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines
-import unittest
-from typing import Optional, Set
+from typing import Optional
import pytest
import sqlparse
@@ -40,7 +39,7 @@ from superset.sql_parse import (
)
-def extract_tables(query: str) -> Set[Table]:
+def extract_tables(query: str) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py
index 282dc99860..5bc22273f5 100644
--- a/tests/unit_tests/tasks/test_cron_util.py
+++ b/tests/unit_tests/tasks/test_cron_util.py
@@ -14,8 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from datetime import datetime
-from typing import List
import pytest
import pytz
@@ -49,7 +47,7 @@ from superset.tasks.cron_util import cron_schedule_window
],
)
def test_cron_schedule_window_los_angeles(
- current_dttm: str, cron: str, expected: List[FakeDatetime]
+ current_dttm: str, cron: str, expected: list[FakeDatetime]
) -> None:
"""
Reports scheduler: Test cron schedule window for "America/Los_Angeles"
@@ -86,7 +84,7 @@ def test_cron_schedule_window_los_angeles(
],
)
def test_cron_schedule_window_invalid_timezone(
- current_dttm: str, cron: str, expected: List[FakeDatetime]
+ current_dttm: str, cron: str, expected: list[FakeDatetime]
) -> None:
"""
Reports scheduler: Test cron schedule window for "invalid timezone"
@@ -124,7 +122,7 @@ def test_cron_schedule_window_invalid_timezone(
],
)
def test_cron_schedule_window_new_york(
- current_dttm: str, cron: str, expected: List[FakeDatetime]
+ current_dttm: str, cron: str, expected: list[FakeDatetime]
) -> None:
"""
Reports scheduler: Test cron schedule window for "America/New_York"
@@ -161,7 +159,7 @@ def test_cron_schedule_window_new_york(
],
)
def test_cron_schedule_window_chicago(
- current_dttm: str, cron: str, expected: List[FakeDatetime]
+ current_dttm: str, cron: str, expected: list[FakeDatetime]
) -> None:
"""
Reports scheduler: Test cron schedule window for "America/Chicago"
@@ -198,7 +196,7 @@ def test_cron_schedule_window_chicago(
],
)
def test_cron_schedule_window_chicago_daylight(
- current_dttm: str, cron: str, expected: List[FakeDatetime]
+ current_dttm: str, cron: str, expected: list[FakeDatetime]
) -> None:
"""
Reports scheduler: Test cron schedule window for "America/Chicago"
diff --git a/tests/unit_tests/tasks/test_utils.py b/tests/unit_tests/tasks/test_utils.py
index 7854717201..b3fbfca8a2 100644
--- a/tests/unit_tests/tasks/test_utils.py
+++ b/tests/unit_tests/tasks/test_utils.py
@@ -18,7 +18,7 @@
from contextlib import nullcontext
from dataclasses import dataclass
from enum import Enum
-from typing import Any, Dict, List, Optional, Tuple, Type, Union
+from typing import Any, Optional, Union
import pytest
from flask_appbuilder.security.sqla.models import User
@@ -31,8 +31,8 @@ SELENIUM_USERNAME = "admin"
def _get_users(
- params: Optional[Union[int, List[int]]]
-) -> Optional[Union[User, List[User]]]:
+ params: Optional[Union[int, list[int]]]
+) -> Optional[Union[User, list[User]]]:
if params is None:
return None
if isinstance(params, int):
@@ -42,7 +42,7 @@ def _get_users(
@dataclass
class ModelConfig:
- owners: List[int]
+ owners: list[int]
creator: Optional[int] = None
modifier: Optional[int] = None
@@ -268,18 +268,18 @@ class ModelType(int, Enum):
)
def test_get_executor(
model_type: ModelType,
- executor_types: List[ExecutorType],
+ executor_types: list[ExecutorType],
model_config: ModelConfig,
current_user: Optional[int],
- expected_result: Tuple[int, ExecutorNotFoundError],
+ expected_result: tuple[int, ExecutorNotFoundError],
) -> None:
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.reports.models import ReportSchedule
from superset.tasks.utils import get_executor
- model: Type[Union[Dashboard, ReportSchedule, Slice]]
- model_kwargs: Dict[str, Any] = {}
+ model: type[Union[Dashboard, ReportSchedule, Slice]]
+ model_kwargs: dict[str, Any] = {}
if model_type == ModelType.REPORT_SCHEDULE:
model = ReportSchedule
model_kwargs = {
diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py
index 04f244e629..68bd7a58f7 100644
--- a/tests/unit_tests/thumbnails/test_digest.py
+++ b/tests/unit_tests/thumbnails/test_digest.py
@@ -17,7 +17,7 @@
from __future__ import annotations
from contextlib import nullcontext
-from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
+from typing import Any, TYPE_CHECKING
from unittest.mock import patch
import pytest
@@ -31,7 +31,7 @@ if TYPE_CHECKING:
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
-_DEFAULT_DASHBOARD_KWARGS: Dict[str, Any] = {
+_DEFAULT_DASHBOARD_KWARGS: dict[str, Any] = {
"id": 1,
"dashboard_title": "My Title",
"slices": [{"id": 1, "slice_name": "My Chart"}],
@@ -150,11 +150,11 @@ def CUSTOM_CHART_FUNC(
],
)
def test_dashboard_digest(
- dashboard_overrides: Optional[Dict[str, Any]],
- execute_as: List[ExecutorType],
+ dashboard_overrides: dict[str, Any] | None,
+ execute_as: list[ExecutorType],
has_current_user: bool,
use_custom_digest: bool,
- expected_result: Union[str, Exception],
+ expected_result: str | Exception,
) -> None:
from superset import app
from superset.models.dashboard import Dashboard
@@ -167,7 +167,7 @@ def test_dashboard_digest(
}
slices = [Slice(**slice_kwargs) for slice_kwargs in kwargs.pop("slices")]
dashboard = Dashboard(**kwargs, slices=slices)
- user: Optional[User] = None
+ user: User | None = None
if has_current_user:
user = User(id=1, username="1")
func = CUSTOM_DASHBOARD_FUNC if use_custom_digest else None
@@ -222,11 +222,11 @@ def test_dashboard_digest(
],
)
def test_chart_digest(
- chart_overrides: Optional[Dict[str, Any]],
- execute_as: List[ExecutorType],
+ chart_overrides: dict[str, Any] | None,
+ execute_as: list[ExecutorType],
has_current_user: bool,
use_custom_digest: bool,
- expected_result: Union[str, Exception],
+ expected_result: str | Exception,
) -> None:
from superset import app
from superset.models.slice import Slice
@@ -237,7 +237,7 @@ def test_chart_digest(
**(chart_overrides or {}),
}
chart = Slice(**kwargs)
- user: Optional[User] = None
+ user: User | None = None
if has_current_user:
user = User(id=1, username="1")
func = CUSTOM_CHART_FUNC if use_custom_digest else None
diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py
index 53650e1d20..bd6179957e 100644
--- a/tests/unit_tests/utils/cache_test.py
+++ b/tests/unit_tests/utils/cache_test.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/tests/unit_tests/utils/date_parser_tests.py b/tests/unit_tests/utils/date_parser_tests.py
index fb0fe07d29..a2ec20901a 100644
--- a/tests/unit_tests/utils/date_parser_tests.py
+++ b/tests/unit_tests/utils/date_parser_tests.py
@@ -16,7 +16,7 @@
# under the License.
import re
from datetime import date, datetime, timedelta
-from typing import Optional, Tuple
+from typing import Optional
from unittest.mock import Mock, patch
import pytest
@@ -74,8 +74,8 @@ def mock_parse_human_datetime(s: str) -> Optional[datetime]:
@patch("superset.utils.date_parser.parse_human_datetime", mock_parse_human_datetime)
def test_get_since_until() -> None:
- result: Tuple[Optional[datetime], Optional[datetime]]
- expected: Tuple[Optional[datetime], Optional[datetime]]
+ result: tuple[Optional[datetime], Optional[datetime]]
+ expected: tuple[Optional[datetime], Optional[datetime]]
result = get_since_until()
expected = None, datetime(2016, 11, 7)
diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py
index 3636983156..996bd1948f 100644
--- a/tests/unit_tests/utils/test_core.py
+++ b/tests/unit_tests/utils/test_core.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
@@ -16,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import os
-from typing import Any, Dict, Optional
+from typing import Any, Optional
import pytest
@@ -86,7 +85,7 @@ EXTRA_FILTER: QueryObjectFilterClause = {
],
)
def test_remove_extra_adhoc_filters(
- original: Dict[str, Any], expected: Dict[str, Any]
+ original: dict[str, Any], expected: dict[str, Any]
) -> None:
remove_extra_adhoc_filters(original)
assert expected == original
diff --git a/tests/unit_tests/utils/test_file.py b/tests/unit_tests/utils/test_file.py
index de20402e5c..a2168a7d92 100644
--- a/tests/unit_tests/utils/test_file.py
+++ b/tests/unit_tests/utils/test_file.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
diff --git a/tests/unit_tests/utils/urls_tests.py b/tests/unit_tests/utils/urls_tests.py
index 208d6caea4..287f346c3d 100644
--- a/tests/unit_tests/utils/urls_tests.py
+++ b/tests/unit_tests/utils/urls_tests.py
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information