diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3f524b3658..07544d66d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,14 +15,28 @@ # limitations under the License. # repos: + - repo: https://github.com/MarcoGorelli/auto-walrus + rev: v0.2.2 + hooks: + - id: auto-walrus + - repo: https://github.com/asottile/pyupgrade + rev: v3.4.0 + hooks: + - id: pyupgrade + args: + - --py39-plus + - repo: https://github.com/hadialqattan/pycln + rev: v2.1.2 + hooks: + - id: pycln + args: + - --disable-all-dunder-policy + - --exclude=superset/config.py + - --extend-exclude=tests/integration_tests/superset_test_config.*.py - repo: https://github.com/PyCQA/isort rev: 5.12.0 hooks: - id: isort - - repo: https://github.com/MarcoGorelli/auto-walrus - rev: v0.2.2 - hooks: - - id: auto-walrus - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.3.0 hooks: diff --git a/RELEASING/changelog.py b/RELEASING/changelog.py index 68a54e10be..d1ba06a620 100644 --- a/RELEASING/changelog.py +++ b/RELEASING/changelog.py @@ -17,8 +17,9 @@ import csv as lib_csv import os import re import sys +from collections.abc import Iterator from dataclasses import dataclass -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Optional, Union import click from click.core import Context @@ -67,15 +68,15 @@ class GitChangeLog: def __init__( self, version: str, - logs: List[GitLog], + logs: list[GitLog], access_token: Optional[str] = None, risk: Optional[bool] = False, ) -> None: self._version = version self._logs = logs - self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {} - self._github_login_cache: Dict[str, Optional[str]] = {} - self._github_prs: Dict[int, Any] = {} + self._pr_logs_with_details: dict[int, dict[str, Any]] = {} + self._github_login_cache: dict[str, Optional[str]] = {} + self._github_prs: dict[int, Any] = {} self._wait = 10 github_token = access_token or os.environ.get("GITHUB_TOKEN") self._github = Github(github_token) @@ -126,7 +127,7 @@ class GitChangeLog: "superset/migrations/versions/" in file.filename for file in commit.files ) - def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]: + def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]: pr_number = git_log.pr_number if pr_number: detail = self._pr_logs_with_details.get(pr_number) @@ -156,7 +157,7 @@ class GitChangeLog: return detail - def _is_risk_pull_request(self, labels: List[Any]) -> bool: + def _is_risk_pull_request(self, labels: list[Any]) -> bool: for label in labels: risk_label = re.match(SUPERSET_RISKY_LABELS, label.name) if risk_label is not None: @@ -174,8 +175,8 @@ class GitChangeLog: def _parse_change_log( self, - changelog: Dict[str, str], - pr_info: Dict[str, str], + changelog: dict[str, str], + pr_info: dict[str, str], github_login: str, ) -> None: formatted_pr = ( @@ -227,7 +228,7 @@ class GitChangeLog: result += f"**{key}** {changelog[key]}\n" return result - def __iter__(self) -> Iterator[Dict[str, Any]]: + def __iter__(self) -> Iterator[dict[str, Any]]: for log in self._logs: yield { "pr_number": log.pr_number, @@ -250,20 +251,20 @@ class GitLogs: def __init__(self, git_ref: str) -> None: self._git_ref = git_ref - self._logs: List[GitLog] = [] + self._logs: list[GitLog] = [] @property def git_ref(self) -> str: return self._git_ref @property - def logs(self) -> List[GitLog]: + def logs(self) -> list[GitLog]: return self._logs def fetch(self) -> None: self._logs = list(map(self._parse_log, self._git_logs()))[::-1] - def diff(self, git_logs: "GitLogs") -> List[GitLog]: + def diff(self, git_logs: "GitLogs") -> list[GitLog]: return [log for log in git_logs.logs if log not in self._logs] def __repr__(self) -> str: @@ -284,7 +285,7 @@ class GitLogs: print(f"Could not checkout {git_ref}") sys.exit(1) - def _git_logs(self) -> List[str]: + def _git_logs(self) -> list[str]: # let's get current git ref so we can revert it back current_git_ref = self._git_get_current_head() self._git_checkout(self._git_ref) diff --git a/RELEASING/generate_email.py b/RELEASING/generate_email.py index 92536670cd..29142557c0 100755 --- a/RELEASING/generate_email.py +++ b/RELEASING/generate_email.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any, Dict, List +from typing import Any from click.core import Context @@ -34,7 +34,7 @@ PROJECT_MODULE = "superset" PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application" -def string_comma_to_list(message: str) -> List[str]: +def string_comma_to_list(message: str) -> list[str]: if not message: return [] return [element.strip() for element in message.split(",")] @@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str: return template.render(kwargs) -class BaseParameters(object): +class BaseParameters: def __init__( self, version: str, @@ -60,7 +60,7 @@ class BaseParameters(object): ) -> None: self.version = version self.version_rc = version_rc - self.template_arguments: Dict[str, Any] = {} + self.template_arguments: dict[str, Any] = {} def __repr__(self) -> str: return f"Apache Credentials: {self.version}/{self.version_rc}" diff --git a/docker/pythonpath_dev/superset_config.py b/docker/pythonpath_dev/superset_config.py index 6ea9abf63c..199e79f66e 100644 --- a/docker/pythonpath_dev/superset_config.py +++ b/docker/pythonpath_dev/superset_config.py @@ -22,7 +22,6 @@ # import logging import os -from datetime import timedelta from typing import Optional from cachelib.file import FileSystemCache @@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str: error_msg = "The environment variable {} was missing, abort...".format( var_name ) - raise EnvironmentError(error_msg) + raise OSError(error_msg) DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT") @@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT") DATABASE_DB = get_env_variable("DATABASE_DB") # The SQLAlchemy connection string. -SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % ( +SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format( DATABASE_DIALECT, DATABASE_USER, DATABASE_PASSWORD, @@ -80,7 +79,7 @@ CACHE_CONFIG = { DATA_CACHE_CONFIG = CACHE_CONFIG -class CeleryConfig(object): +class CeleryConfig: broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" imports = ("superset.sql_lab",) result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}" diff --git a/scripts/benchmark_migration.py b/scripts/benchmark_migration.py index 83c06456a1..466fab6f13 100644 --- a/scripts/benchmark_migration.py +++ b/scripts/benchmark_migration.py @@ -23,7 +23,7 @@ from graphlib import TopologicalSorter from inspect import getsource from pathlib import Path from types import ModuleType -from typing import Any, Dict, List, Set, Type +from typing import Any import click from flask import current_app @@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType: module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) # type: ignore return module - raise Exception( - "No module spec found in location: `{path}`".format(path=str(filepath)) - ) + raise Exception(f"No module spec found in location: `{str(filepath)}`") -def extract_modified_tables(module: ModuleType) -> Set[str]: +def extract_modified_tables(module: ModuleType) -> set[str]: """ Extract the tables being modified by a migration script. @@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]: actually traversing the AST. """ - tables: Set[str] = set() + tables: set[str] = set() for function in {"upgrade", "downgrade"}: source = getsource(getattr(module, function)) tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL)) @@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]: return tables -def find_models(module: ModuleType) -> List[Type[Model]]: +def find_models(module: ModuleType) -> list[type[Model]]: """ Find all models in a migration script. """ - models: List[Type[Model]] = [] + models: list[type[Model]] = [] tables = extract_modified_tables(module) # add models defined explicitly in the migration script @@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]: sorter: TopologicalSorter[Any] = TopologicalSorter() for model in models: inspector = inspect(model) - dependent_tables: List[str] = [] + dependent_tables: list[str] = [] for column in inspector.columns.values(): for foreign_key in column.foreign_keys: if foreign_key.column.table.name != model.__tablename__: @@ -174,7 +172,7 @@ def main( print("\nIdentifying models used in the migration:") models = find_models(module) - model_rows: Dict[Type[Model], int] = {} + model_rows: dict[type[Model], int] = {} for model in models: rows = session.query(model).count() print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})") @@ -182,7 +180,7 @@ def main( session.close() print("Benchmarking migration") - results: Dict[str, float] = {} + results: dict[str, float] = {} start = time.time() upgrade(revision=revision) duration = time.time() - start @@ -190,14 +188,14 @@ def main( print(f"Migration on current DB took: {duration:.2f} seconds") min_entities = 10 - new_models: Dict[Type[Model], List[Model]] = defaultdict(list) + new_models: dict[type[Model], list[Model]] = defaultdict(list) while min_entities <= limit: downgrade(revision=down_revision) print(f"Running with at least {min_entities} entities of each model") for model in models: missing = min_entities - model_rows[model] if missing > 0: - entities: List[Model] = [] + entities: list[Model] = [] print(f"- Adding {missing} entities to the {model.__name__} model") bar = ChargingBar("Processing", max=missing) try: diff --git a/scripts/cancel_github_workflows.py b/scripts/cancel_github_workflows.py index 4d30d34adf..70744c2954 100755 --- a/scripts/cancel_github_workflows.py +++ b/scripts/cancel_github_workflows.py @@ -33,13 +33,13 @@ Example: ./cancel_github_workflows.py 1024 --include-last """ import os -from typing import Any, Dict, Iterable, Iterator, List, Optional, Union +from collections.abc import Iterable, Iterator +from typing import Any, Literal, Optional, Union import click import requests from click.exceptions import ClickException from dateutil import parser -from typing_extensions import Literal github_token = os.environ.get("GITHUB_TOKEN") github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset") @@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset") def request( method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any -) -> Dict[str, Any]: +) -> dict[str, Any]: resp = requests.request( method, f"https://api.github.com/{endpoint.lstrip('/')}", @@ -61,8 +61,8 @@ def request( def list_runs( repo: str, - params: Optional[Dict[str, str]] = None, -) -> Iterator[Dict[str, Any]]: + params: Optional[dict[str, str]] = None, +) -> Iterator[dict[str, Any]]: """List all github workflow runs. Returns: An iterator that will iterate through all pages of matching runs.""" @@ -77,16 +77,15 @@ def list_runs( params={**params, "per_page": 100, "page": page}, ) total_count = result["total_count"] - for item in result["workflow_runs"]: - yield item + yield from result["workflow_runs"] page += 1 -def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]: +def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]: return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel") -def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]: +def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]: return request("GET", f"/repos/{repo}/pulls/{pull_number}") @@ -96,7 +95,7 @@ def get_runs( user: Optional[str] = None, statuses: Iterable[str] = ("queued", "in_progress"), events: Iterable[str] = ("pull_request", "push"), -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: """Get workflow runs associated with the given branch""" return [ item @@ -108,7 +107,7 @@ def get_runs( ] -def print_commit(commit: Dict[str, Any], branch: str) -> None: +def print_commit(commit: dict[str, Any], branch: str) -> None: """Print out commit message for verification""" indented_message = " \n".join(commit["message"].split("\n")) date_str = ( @@ -155,7 +154,7 @@ Date: {date_str} def cancel_github_workflows( branch_or_pull: Optional[str], repo: str, - event: List[str], + event: list[str], include_last: bool, include_running: bool, ) -> None: diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py index 5ca75e394c..0416f55806 100644 --- a/scripts/permissions_cleanup.py +++ b/scripts/permissions_cleanup.py @@ -24,7 +24,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("# of permission view menus is: {}".format(len(pvms))) + print(f"# of permission view menus is: {len(pvms)}") pvms_dict = defaultdict(list) for pvm in pvms: pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm) @@ -43,7 +43,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("Stage 1: # of permission view menus is: {}".format(len(pvms))) + print(f"Stage 1: # of permission view menus is: {len(pvms)}") # 2. Clean up None permissions or view menus pvms = security_manager.get_session.query( @@ -57,7 +57,7 @@ def cleanup_permissions() -> None: pvms = security_manager.get_session.query( security_manager.permissionview_model ).all() - print("Stage 2: # of permission view menus is: {}".format(len(pvms))) + print(f"Stage 2: # of permission view menus is: {len(pvms)}") # 3. Delete empty permission view menus from roles roles = security_manager.get_session.query(security_manager.role_model).all() diff --git a/setup.py b/setup.py index 41f7e11e38..d8adea3285 100644 --- a/setup.py +++ b/setup.py @@ -14,21 +14,19 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import io import json import os import subprocess -import sys from setuptools import find_packages, setup BASE_DIR = os.path.abspath(os.path.dirname(__file__)) PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json") -with open(PACKAGE_JSON, "r") as package_file: +with open(PACKAGE_JSON) as package_file: version_string = json.load(package_file)["version"] -with io.open("README.md", "r", encoding="utf-8") as f: +with open("README.md", encoding="utf-8") as f: long_description = f.read() diff --git a/superset/advanced_data_type/plugins/internet_address.py b/superset/advanced_data_type/plugins/internet_address.py index 08a0925846..8ab20fe2d0 100644 --- a/superset/advanced_data_type/plugins/internet_address.py +++ b/superset/advanced_data_type/plugins/internet_address.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import ipaddress -from typing import Any, List +from typing import Any from sqlalchemy import Column @@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse: # Make this return a single clause def cidr_translate_filter_func( - col: Column, operator: FilterOperator, values: List[Any] + col: Column, operator: FilterOperator, values: list[Any] ) -> Any: """ Convert a passed in column, FilterOperator and diff --git a/superset/advanced_data_type/plugins/internet_port.py b/superset/advanced_data_type/plugins/internet_port.py index 60a594bfd9..8983e41422 100644 --- a/superset/advanced_data_type/plugins/internet_port.py +++ b/superset/advanced_data_type/plugins/internet_port.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import itertools -from typing import Any, Dict, List +from typing import Any from sqlalchemy import Column @@ -26,7 +26,7 @@ from superset.advanced_data_type.types import ( ) from superset.utils.core import FilterOperator, FilterStringOperators -port_conversion_dict: Dict[str, List[int]] = { +port_conversion_dict: dict[str, list[int]] = { "http": [80], "ssh": [22], "https": [443], @@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo def port_translate_filter_func( - col: Column, operator: FilterOperator, values: List[Any] + col: Column, operator: FilterOperator, values: list[Any] ) -> Any: """ Convert a passed in column, FilterOperator diff --git a/superset/advanced_data_type/types.py b/superset/advanced_data_type/types.py index 316922f339..e8d5de9143 100644 --- a/superset/advanced_data_type/types.py +++ b/superset/advanced_data_type/types.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from dataclasses import dataclass -from typing import Any, Callable, List, Optional, TypedDict, Union +from typing import Any, Callable, Optional, TypedDict, Union from sqlalchemy import Column from sqlalchemy.sql.expression import BinaryExpression @@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict): """ advanced_data_type: str - values: List[ + values: list[ Union[FilterValues, None] ] # unparsed value (usually text when passed from text box) @@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False): """ error_message: Optional[str] - values: List[Any] # parsed value (can be any value) + values: list[Any] # parsed value (can be any value) display_value: str # The string representation of the parsed values - valid_filter_operators: List[FilterStringOperators] + valid_filter_operators: list[FilterStringOperators] @dataclass @@ -54,6 +54,6 @@ class AdvancedDataType: verbose_name: str description: str - valid_data_types: List[str] + valid_data_types: list[str] translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse] translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression] diff --git a/superset/annotation_layers/annotations/api.py b/superset/annotation_layers/annotations/api.py index 0a6a2767f0..70e0a1ad02 100644 --- a/superset/annotation_layers/annotations/api.py +++ b/superset/annotation_layers/annotations/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask import request, Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe @@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi): @staticmethod def _apply_layered_relation_to_rison( # pylint: disable=invalid-name - layer_id: int, rison_parameters: Dict[str, Any] + layer_id: int, rison_parameters: dict[str, Any] ) -> None: if "filters" not in rison_parameters: rison_parameters["filters"] = [] diff --git a/superset/annotation_layers/annotations/commands/bulk_delete.py b/superset/annotation_layers/annotations/commands/bulk_delete.py index 113725050f..dd47047788 100644 --- a/superset/annotation_layers/annotations/commands/bulk_delete.py +++ b/superset/annotation_layers/annotations/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset.annotation_layers.annotations.commands.exceptions import ( AnnotationBulkDeleteFailedError, @@ -30,9 +30,9 @@ logger = logging.getLogger(__name__) class BulkDeleteAnnotationCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Annotation]] = None + self._models: Optional[list[Annotation]] = None def run(self) -> None: self.validate() diff --git a/superset/annotation_layers/annotations/commands/create.py b/superset/annotation_layers/annotations/commands/create.py index 0974624561..986b564291 100644 --- a/superset/annotation_layers/annotations/commands/create.py +++ b/superset/annotation_layers/annotations/commands/create.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class CreateAnnotationCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand): return annotation def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] layer_id: Optional[int] = self._properties.get("layer") start_dttm: Optional[datetime] = self._properties.get("start_dttm") end_dttm: Optional[datetime] = self._properties.get("end_dttm") diff --git a/superset/annotation_layers/annotations/commands/update.py b/superset/annotation_layers/annotations/commands/update.py index b644ddc362..99ab209165 100644 --- a/superset/annotation_layers/annotations/commands/update.py +++ b/superset/annotation_layers/annotations/commands/update.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -39,7 +39,7 @@ logger = logging.getLogger(__name__) class UpdateAnnotationCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Annotation] = None @@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand): return annotation def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] layer_id: Optional[int] = self._properties.get("layer") short_descr: str = self._properties.get("short_descr", "") diff --git a/superset/annotation_layers/annotations/dao.py b/superset/annotation_layers/annotations/dao.py index 0c8a9e47c5..da69e576e5 100644 --- a/superset/annotation_layers/annotations/dao.py +++ b/superset/annotation_layers/annotations/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO): model_cls = Annotation @staticmethod - def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete( diff --git a/superset/annotation_layers/commands/bulk_delete.py b/superset/annotation_layers/commands/bulk_delete.py index b9bc17e82f..4910dc4275 100644 --- a/superset/annotation_layers/commands/bulk_delete.py +++ b/superset/annotation_layers/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset.annotation_layers.commands.exceptions import ( AnnotationLayerBulkDeleteFailedError, @@ -31,9 +31,9 @@ logger = logging.getLogger(__name__) class BulkDeleteAnnotationLayerCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[AnnotationLayer]] = None + self._models: Optional[list[AnnotationLayer]] = None def run(self) -> None: self.validate() diff --git a/superset/annotation_layers/commands/create.py b/superset/annotation_layers/commands/create.py index 97431568a9..86b0cb3b85 100644 --- a/superset/annotation_layers/commands/create.py +++ b/superset/annotation_layers/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -33,7 +33,7 @@ logger = logging.getLogger(__name__) class CreateAnnotationLayerCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand): return annotation_layer def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] name = self._properties.get("name", "") diff --git a/superset/annotation_layers/commands/update.py b/superset/annotation_layers/commands/update.py index 4a9cc31be5..67d869c005 100644 --- a/superset/annotation_layers/commands/update.py +++ b/superset/annotation_layers/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -35,7 +35,7 @@ logger = logging.getLogger(__name__) class UpdateAnnotationLayerCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[AnnotationLayer] = None @@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand): return annotation_layer def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] name = self._properties.get("name", "") self._model = AnnotationLayerDAO.find_by_id(self._model_id) diff --git a/superset/annotation_layers/dao.py b/superset/annotation_layers/dao.py index d9db4b582d..67efc19f88 100644 --- a/superset/annotation_layers/dao.py +++ b/superset/annotation_layers/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional, Union +from typing import Optional, Union from sqlalchemy.exc import SQLAlchemyError @@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO): @staticmethod def bulk_delete( - models: Optional[List[AnnotationLayer]], commit: bool = True + models: Optional[list[AnnotationLayer]], commit: bool = True ) -> None: item_ids = [model.id for model in models] if models else [] try: @@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO): raise DAODeleteFailedError() from ex @staticmethod - def has_annotations(model_id: Union[int, List[int]]) -> bool: + def has_annotations(model_id: Union[int, list[int]]) -> bool: if isinstance(model_id, list): return ( db.session.query(AnnotationLayer) diff --git a/superset/charts/commands/bulk_delete.py b/superset/charts/commands/bulk_delete.py index c252f0be4c..ac801b7421 100644 --- a/superset/charts/commands/bulk_delete.py +++ b/superset/charts/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_babel import lazy_gettext as _ @@ -37,9 +37,9 @@ logger = logging.getLogger(__name__) class BulkDeleteChartCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Slice]] = None + self._models: Optional[list[Slice]] = None def run(self) -> None: self.validate() diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py index 38076fb9cd..78706b3a66 100644 --- a/superset/charts/commands/create.py +++ b/superset/charts/commands/create.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import g from flask_appbuilder.models.sqla import Model @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class CreateChartCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand): datasource_type = self._properties["datasource_type"] datasource_id = self._properties["datasource_id"] dashboard_ids = self._properties.get("dashboards", []) - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate/Populate datasource try: diff --git a/superset/charts/commands/export.py b/superset/charts/commands/export.py index 9d445cb54e..22310ade99 100644 --- a/superset/charts/commands/export.py +++ b/superset/charts/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml @@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand): not_found = ChartNotFoundError @staticmethod - def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]: + def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]: file_name = get_filename(model.slice_name, model.id) file_path = f"charts/{file_name}.yaml" diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py index afeb9c2820..fb5007a50c 100644 --- a/superset/charts/commands/importers/dispatcher.py +++ b/superset/charts/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/charts/commands/importers/v1/__init__.py b/superset/charts/commands/importers/v1/__init__.py index ab88038aaa..132df21b08 100644 --- a/superset/charts/commands/importers/v1/__init__.py +++ b/superset/charts/commands/importers/v1/__init__.py @@ -15,8 +15,7 @@ # specific language governing permissions and limitations # under the License. -import json -from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand): dao = ChartDAO model_name = "chart" prefix = "charts/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "datasets/": ImportV1DatasetSchema(), "databases/": ImportV1DatabaseSchema(), @@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover datasets associated with charts - dataset_uuids: Set[str] = set() + dataset_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("charts/"): dataset_uuids.add(config["dataset_uuid"]) # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids: database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref - datasets: Dict[str, SqlaTable] = {} + datasets: dict[str, SqlaTable] = {} for file_name, config in configs.items(): if ( file_name.startswith("datasets/") diff --git a/superset/charts/commands/importers/v1/utils.py b/superset/charts/commands/importers/v1/utils.py index d4aeb17a1e..399e6c2243 100644 --- a/superset/charts/commands/importers/v1/utils.py +++ b/superset/charts/commands/importers/v1/utils.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Any, Dict +from typing import Any from flask import g from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ from superset.models.slice import Slice def import_chart( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Slice: diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py index f5fc2616a5..a4265d0835 100644 --- a/superset/charts/commands/update.py +++ b/superset/charts/commands/update.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import g from flask_appbuilder.models.sqla import Model @@ -42,14 +42,14 @@ from superset.models.slice import Slice logger = logging.getLogger(__name__) -def is_query_context_update(properties: Dict[str, Any]) -> bool: +def is_query_context_update(properties: dict[str, Any]) -> bool: return set(properties) == {"query_context", "query_context_generation"} and bool( properties.get("query_context_generation") ) class UpdateChartCommand(UpdateMixin, BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Slice] = None @@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand): return chart def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] dashboard_ids = self._properties.get("dashboards") - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate if datasource_id is provided datasource_type is required datasource_id = self._properties.get("datasource_id") diff --git a/superset/charts/dao.py b/superset/charts/dao.py index 7102e6ad23..9c6b2c26ef 100644 --- a/superset/charts/dao.py +++ b/superset/charts/dao.py @@ -17,7 +17,7 @@ # pylint: disable=arguments-renamed import logging from datetime import datetime -from typing import List, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from sqlalchemy.exc import SQLAlchemyError @@ -39,7 +39,7 @@ class ChartDAO(BaseDAO): base_filter = ChartFilter @staticmethod - def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: @@ -71,7 +71,7 @@ class ChartDAO(BaseDAO): db.session.commit() @staticmethod - def favorited_ids(charts: List[Slice]) -> List[FavStar]: + def favorited_ids(charts: list[Slice]) -> list[FavStar]: ids = [chart.id for chart in charts] return [ star.obj_id diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 9c620dcf5d..552044ebfa 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -18,7 +18,7 @@ from __future__ import annotations import json import logging -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING import simplejson from flask import current_app, g, make_response, request, Response @@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi): return self._get_data_response(command, True) def _run_async( - self, form_data: Dict[str, Any], command: ChartDataCommand + self, form_data: dict[str, Any], command: ChartDataCommand ) -> Response: """ Execute command as an async query. @@ -344,9 +344,9 @@ class ChartDataRestApi(ChartRestApi): def _send_chart_response( self, - result: Dict[Any, Any], - form_data: Optional[Dict[str, Any]] = None, - datasource: Optional[Union[BaseDatasource, Query]] = None, + result: dict[Any, Any], + form_data: dict[str, Any] | None = None, + datasource: BaseDatasource | Query | None = None, ) -> Response: result_type = result["query_context"].result_type result_format = result["query_context"].result_format @@ -408,8 +408,8 @@ class ChartDataRestApi(ChartRestApi): self, command: ChartDataCommand, force_cached: bool = False, - form_data: Optional[Dict[str, Any]] = None, - datasource: Optional[Union[BaseDatasource, Query]] = None, + form_data: dict[str, Any] | None = None, + datasource: BaseDatasource | Query | None = None, ) -> Response: try: result = command.run(force_cached=force_cached) @@ -421,12 +421,12 @@ class ChartDataRestApi(ChartRestApi): return self._send_chart_response(result, form_data, datasource) # pylint: disable=invalid-name, no-self-use - def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]: + def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]: return QueryContextCacheLoader.load(cache_key) # pylint: disable=no-self-use def _create_query_context_from_form( - self, form_data: Dict[str, Any] + self, form_data: dict[str, Any] ) -> QueryContext: try: return ChartDataQueryContextSchema().load(form_data) diff --git a/superset/charts/data/commands/create_async_job_command.py b/superset/charts/data/commands/create_async_job_command.py index c4e25f742b..fb6e3f3dbf 100644 --- a/superset/charts/data/commands/create_async_job_command.py +++ b/superset/charts/data/commands/create_async_job_command.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Request @@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand: jwt_data = async_query_manager.parse_jwt_from_request(request) self._async_channel_id = jwt_data["channel"] - def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]: + def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]: job_metadata = async_query_manager.init_job(self._async_channel_id, user_id) load_chart_data_into_cache.delay(job_metadata, form_data) return job_metadata diff --git a/superset/charts/data/commands/get_data_command.py b/superset/charts/data/commands/get_data_command.py index 819693607b..a84870a1dd 100644 --- a/superset/charts/data/commands/get_data_command.py +++ b/superset/charts/data/commands/get_data_command.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_babel import lazy_gettext as _ @@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand): def __init__(self, query_context: QueryContext): self._query_context = query_context - def run(self, **kwargs: Any) -> Dict[str, Any]: + def run(self, **kwargs: Any) -> dict[str, Any]: # caching is handled in query_context.get_df_payload # (also evals `force` property) cache_query_context = kwargs.get("cache", False) diff --git a/superset/charts/data/query_context_cache_loader.py b/superset/charts/data/query_context_cache_loader.py index b5ff3bdae8..97fa733a3e 100644 --- a/superset/charts/data/query_context_cache_loader.py +++ b/superset/charts/data/query_context_cache_loader.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from superset import cache from superset.charts.commands.exceptions import ChartDataCacheLoadError @@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError class QueryContextCacheLoader: # pylint: disable=too-few-public-methods @staticmethod - def load(cache_key: str) -> Dict[str, Any]: + def load(cache_key: str) -> dict[str, Any]: cache_value = cache.get(cache_key) if not cache_value: raise ChartDataCacheLoadError("Cached data not found") diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index 1165769fc8..a6b64c08d6 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -27,7 +27,7 @@ for these chart types. """ from io import StringIO -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Optional, TYPE_CHECKING, Union import pandas as pd from flask_babel import gettext as __ @@ -45,14 +45,14 @@ if TYPE_CHECKING: from superset.models.sql_lab import Query -def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: +def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]: """ Sort columns when combining metrics. MultiIndex labels have the metric name as the last element in the tuple. We want to sort these according to the list of passed metrics. """ - parts: List[Any] = list(label) + parts: list[Any] = list(label) metric = parts[-1] parts[-1] = metrics.index(metric) return tuple(parts) @@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ... def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches df: pd.DataFrame, - rows: List[str], - columns: List[str], - metrics: List[str], + rows: list[str], + columns: list[str], + metrics: list[str], aggfunc: str = "Sum", transpose_pivot: bool = False, combine_metrics: bool = False, @@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str: """ List unique values in a series. """ - return ", ".join(set(str(v) for v in pd.Series.unique(series))) + return ", ".join({str(v) for v in pd.Series.unique(series)}) pivot_v2_aggfunc_map = { @@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = { def pivot_table_v2( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[Union["BaseDatasource", "Query"]] = None, ) -> pd.DataFrame: """ @@ -249,7 +249,7 @@ def pivot_table_v2( def pivot_table( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[Union["BaseDatasource", "Query"]] = None, ) -> pd.DataFrame: """ @@ -285,7 +285,7 @@ def pivot_table( def table( df: pd.DataFrame, - form_data: Dict[str, Any], + form_data: dict[str, Any], datasource: Optional[ # pylint: disable=unused-argument Union["BaseDatasource", "Query"] ] = None, @@ -315,10 +315,10 @@ post_processors = { def apply_post_process( - result: Dict[Any, Any], - form_data: Optional[Dict[str, Any]] = None, + result: dict[Any, Any], + form_data: Optional[dict[str, Any]] = None, datasource: Optional[Union["BaseDatasource", "Query"]] = None, -) -> Dict[Any, Any]: +) -> dict[Any, Any]: form_data = form_data or {} viz_type = form_data.get("viz_type") diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 44252ef06f..373600cd08 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -18,7 +18,7 @@ from __future__ import annotations import inspect -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask_babel import gettext as _ from marshmallow import EXCLUDE, fields, post_load, Schema, validate @@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema): class ChartDataQueryContextSchema(Schema): - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None datasource = fields.Nested(ChartDataDatasourceSchema) queries = fields.List(fields.Nested(ChartDataQueryObjectSchema)) custom_cache_timeout = fields.Integer( @@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema): # pylint: disable=unused-argument @post_load - def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext: + def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext: query_context = self.get_query_context_factory().create(**data) return query_context diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py index c7689569c2..86f6fe9b67 100755 --- a/superset/cli/importexport.py +++ b/superset/cli/importexport.py @@ -18,7 +18,7 @@ import logging import sys from datetime import datetime from pathlib import Path -from typing import List, Optional +from typing import Optional from zipfile import is_zipfile, ZipFile import click @@ -309,7 +309,7 @@ else: from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand path_object = Path(path) - files: List[Path] = [] + files: list[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: @@ -363,7 +363,7 @@ else: sync_metrics = "metrics" in sync_array path_object = Path(path) - files: List[Path] = [] + files: list[Path] = [] if path_object.is_file(): files.append(path_object) elif path_object.exists() and not recursive: diff --git a/superset/cli/main.py b/superset/cli/main.py index 006f8eb5c9..536617cadd 100755 --- a/superset/cli/main.py +++ b/superset/cli/main.py @@ -18,7 +18,7 @@ import importlib import logging import pkgutil -from typing import Any, Dict +from typing import Any import click from colorama import Fore, Style @@ -40,7 +40,7 @@ def superset() -> None: """This is a management script for the Superset application.""" @app.shell_context_processor - def make_shell_context() -> Dict[str, Any]: + def make_shell_context() -> dict[str, Any]: return dict(app=app, db=db) @@ -79,5 +79,5 @@ def version(verbose: bool) -> None: ) print(Fore.BLUE + "-=" * 15) if verbose: - print("[DB] : " + "{}".format(db.engine)) + print("[DB] : " + f"{db.engine}") print(Style.RESET_ALL) diff --git a/superset/cli/native_filters.py b/superset/cli/native_filters.py index 63cc185e8e..a25724d38d 100644 --- a/superset/cli/native_filters.py +++ b/superset/cli/native_filters.py @@ -17,7 +17,6 @@ import json from copy import deepcopy from textwrap import dedent -from typing import Set, Tuple import click from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup @@ -102,7 +101,7 @@ def native_filters() -> None: ) def upgrade( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Upgrade legacy filter-box charts to native dashboard filters. @@ -251,7 +250,7 @@ def upgrade( ) def downgrade( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Downgrade native dashboard filters to legacy filter-box charts (where applicable). @@ -347,7 +346,7 @@ def downgrade( ) def cleanup( all_: bool, # pylint: disable=unused-argument - dashboard_ids: Tuple[int, ...], + dashboard_ids: tuple[int, ...], ) -> None: """ Cleanup obsolete legacy filter-box charts and interim metadata. @@ -355,7 +354,7 @@ def cleanup( Note this operation is irreversible. """ - slice_ids: Set[int] = set() + slice_ids: set[int] = set() # Cleanup the dashboard which contains legacy fields used for downgrading. for dashboard in ( diff --git a/superset/cli/thumbnails.py b/superset/cli/thumbnails.py index 276d9981c1..325fab6853 100755 --- a/superset/cli/thumbnails.py +++ b/superset/cli/thumbnails.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Type, Union +from typing import Union import click from celery.utils.abstract import CallableTask @@ -75,7 +75,7 @@ def compute_thumbnails( def compute_generic_thumbnail( friendly_type: str, - model_cls: Union[Type[Dashboard], Type[Slice]], + model_cls: Union[type[Dashboard], type[Slice]], model_id: int, compute_func: CallableTask, ) -> None: diff --git a/superset/commands/base.py b/superset/commands/base.py index 42d5956312..caca50755d 100644 --- a/superset/commands/base.py +++ b/superset/commands/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, List, Optional +from typing import Any, Optional from flask_appbuilder.security.sqla.models import User @@ -45,7 +45,7 @@ class BaseCommand(ABC): class CreateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: + def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]: """ Populate list of owners, defaulting to the current user if `owner_ids` is undefined or empty. If current user is missing in `owner_ids`, current user @@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods class UpdateMixin: # pylint: disable=too-few-public-methods @staticmethod - def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]: + def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]: """ Populate list of owners. If current user is missing in `owner_ids`, current user is added unless belonging to the Admin role. diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index db9d1b6c63..4398d740c5 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ from marshmallow import ValidationError @@ -59,7 +59,7 @@ class CommandInvalidError(CommandException): def __init__( self, message: str = "", - exceptions: Optional[List[ValidationError]] = None, + exceptions: Optional[list[ValidationError]] = None, ) -> None: self._exceptions = exceptions or [] super().__init__(message) @@ -67,14 +67,14 @@ class CommandInvalidError(CommandException): def append(self, exception: ValidationError) -> None: self._exceptions.append(exception) - def extend(self, exceptions: List[ValidationError]) -> None: + def extend(self, exceptions: list[ValidationError]) -> None: self._exceptions.extend(exceptions) - def get_list_classnames(self) -> List[str]: + def get_list_classnames(self) -> list[str]: return list(sorted({ex.__class__.__name__ for ex in self._exceptions})) - def normalized_messages(self) -> Dict[Any, Any]: - errors: Dict[Any, Any] = {} + def normalized_messages(self) -> dict[Any, Any]: + errors: dict[Any, Any] = {} for exception in self._exceptions: errors.update(exception.normalized_messages()) return errors diff --git a/superset/commands/export/assets.py b/superset/commands/export/assets.py index 9f088af428..1bd2cf6d61 100644 --- a/superset/commands/export/assets.py +++ b/superset/commands/export/assets.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Iterator from datetime import datetime, timezone -from typing import Iterator, Tuple import yaml @@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand): Command that exports all databases, datasets, charts, dashboards and saved queries. """ - def run(self) -> Iterator[Tuple[str, str]]: + def run(self) -> Iterator[tuple[str, str]]: metadata = { "version": EXPORT_VERSION, "type": "assets", diff --git a/superset/commands/export/models.py b/superset/commands/export/models.py index 4edafaa746..3f21f29281 100644 --- a/superset/commands/export/models.py +++ b/superset/commands/export/models.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. +from collections.abc import Iterator from datetime import datetime, timezone -from typing import Iterator, List, Tuple, Type import yaml from flask_appbuilder import Model @@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml" class ExportModelsCommand(BaseCommand): - dao: Type[BaseDAO] = BaseDAO - not_found: Type[CommandException] = CommandException + dao: type[BaseDAO] = BaseDAO + not_found: type[CommandException] = CommandException - def __init__(self, model_ids: List[int], export_related: bool = True): + def __init__(self, model_ids: list[int], export_related: bool = True): self.model_ids = model_ids self.export_related = export_related # this will be set when calling validate() - self._models: List[Model] = [] + self._models: list[Model] = [] @staticmethod - def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]: + def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]: raise NotImplementedError("Subclasses MUST implement _export") - def run(self) -> Iterator[Tuple[str, str]]: + def run(self) -> Iterator[tuple[str, str]]: self.validate() metadata = { diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index a67828bdb2..09830bf3cf 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Set +from typing import Any, Optional from marshmallow import Schema, validate from marshmallow.exceptions import ValidationError @@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand): dao = BaseDAO model_name = "model" prefix = "" - schemas: Dict[str, Schema] = {} + schemas: dict[str, Schema] = {} import_error = CommandException # pylint: disable=unused-argument - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents - self.passwords: Dict[str, str] = kwargs.get("passwords") or {} - self.ssh_tunnel_passwords: Dict[str, str] = ( + self.passwords: dict[str, str] = kwargs.get("passwords") or {} + self.ssh_tunnel_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_passwords") or {} ) - self.ssh_tunnel_private_keys: Dict[str, str] = ( + self.ssh_tunnel_private_keys: dict[str, str] = ( kwargs.get("ssh_tunnel_private_keys") or {} ) - self.ssh_tunnel_priv_key_passwords: Dict[str, str] = ( + self.ssh_tunnel_priv_key_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_priv_key_passwords") or {} ) self.overwrite: bool = kwargs.get("overwrite", False) - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: raise NotImplementedError("Subclasses MUST implement _import") @classmethod - def _get_uuids(cls) -> Set[str]: + def _get_uuids(cls) -> set[str]: return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()} def run(self) -> None: @@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand): raise self.import_error() from ex def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # verify that the metadata file is present and valid try: - metadata: Optional[Dict[str, str]] = load_metadata(self.contents) + metadata: Optional[dict[str, str]] = load_metadata(self.contents) except ValidationError as exc: exceptions.append(exc) metadata = None @@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand): ) def _prevent_overwrite_existing_model( # pylint: disable=invalid-name - self, exceptions: List[ValidationError] + self, exceptions: list[ValidationError] ) -> None: """check if the object exists and shouldn't be overwritten""" if not self.overwrite: diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index ce8b46c2a0..1ab2e486cf 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from marshmallow import Schema from marshmallow.exceptions import ValidationError @@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand): and will overwrite everything. """ - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand): } # pylint: disable=unused-argument - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents - self.passwords: Dict[str, str] = kwargs.get("passwords") or {} - self.ssh_tunnel_passwords: Dict[str, str] = ( + self.passwords: dict[str, str] = kwargs.get("passwords") or {} + self.ssh_tunnel_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_passwords") or {} ) - self.ssh_tunnel_private_keys: Dict[str, str] = ( + self.ssh_tunnel_private_keys: dict[str, str] = ( kwargs.get("ssh_tunnel_private_keys") or {} ) - self.ssh_tunnel_priv_key_passwords: Dict[str, str] = ( + self.ssh_tunnel_priv_key_passwords: dict[str, str] = ( kwargs.get("ssh_tunnel_priv_key_passwords") or {} ) - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} @staticmethod - def _import(session: Session, configs: Dict[str, Any]) -> None: + def _import(session: Session, configs: dict[str, Any]) -> None: # import databases first - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database(session, config, overwrite=True) @@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand): import_saved_query(session, config, overwrite=True) # import datasets - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if file_name.startswith("datasets/"): config["database_id"] = database_ids[config["database_uuid"]] @@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand): } # import charts - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("charts/"): config.update(dataset_info[config["dataset_uuid"]]) @@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand): dashboard = import_dashboard(session, config, overwrite=True) # set ref in the dashboard_slices table - dashboard_chart_ids: List[Dict[str, int]] = [] + dashboard_chart_ids: list[dict[str, int]] = [] for uuid in find_chart_uuids(config["position"]): if uuid not in chart_ids: break @@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand): raise ImportFailedError() from ex def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # verify that the metadata file is present and valid try: - metadata: Optional[Dict[str, str]] = load_metadata(self.contents) + metadata: Optional[dict[str, str]] = load_metadata(self.contents) except ValidationError as exc: exceptions.append(exc) metadata = None diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index 35efdb1393..4c20e93ff7 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand): dao = BaseDAO model_name = "model" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand): } import_error = CommandException - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): super().__init__(contents, *args, **kwargs) self.force_data = kwargs.get("force_data", False) @@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand): raise self.import_error() from ex @classmethod - def _get_uuids(cls) -> Set[str]: + def _get_uuids(cls) -> set[str]: # pylint: disable=protected-access return ( ImportDatabasesCommand._get_uuids() @@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand): @staticmethod def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches session: Session, - configs: Dict[str, Any], + configs: dict[str, Any], overwrite: bool = False, force_data: bool = False, ) -> None: # import databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database( @@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand): # database was created before its UUID was frozen, so it has a random UUID. # We need to determine its ID so we can point the dataset to it. examples_db = get_example_database() - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if file_name.startswith("datasets/"): # find the ID of the corresponding database @@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand): } # import charts - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if ( file_name.startswith("charts/") @@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand): ).fetchall() # import dashboards - dashboard_chart_ids: List[Tuple[int, int]] = [] + dashboard_chart_ids: list[tuple[int, int]] = [] for file_name, config in configs.items(): if file_name.startswith("dashboards/"): try: diff --git a/superset/commands/importers/v1/utils.py b/superset/commands/importers/v1/utils.py index c8fb97c53d..8ca008b3e2 100644 --- a/superset/commands/importers/v1/utils.py +++ b/superset/commands/importers/v1/utils.py @@ -15,7 +15,7 @@ import logging from pathlib import Path, PurePosixPath -from typing import Any, Dict, List, Optional +from typing import Any, Optional from zipfile import ZipFile import yaml @@ -46,7 +46,7 @@ class MetadataSchema(Schema): timestamp = fields.DateTime() -def load_yaml(file_name: str, content: str) -> Dict[str, Any]: +def load_yaml(file_name: str, content: str) -> dict[str, Any]: """Try to load a YAML file""" try: return yaml.safe_load(content) @@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]: raise ValidationError({file_name: "Not a valid YAML file"}) from ex -def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: +def load_metadata(contents: dict[str, str]) -> dict[str, str]: """Apply validation and load a metadata file""" if METADATA_FILE_NAME not in contents: # if the contents have no METADATA_FILE_NAME this is probably @@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]: def validate_metadata_type( - metadata: Optional[Dict[str, str]], + metadata: Optional[dict[str, str]], type_: str, - exceptions: List[ValidationError], + exceptions: list[ValidationError], ) -> None: """Validate that the type declared in METADATA_FILE_NAME is correct""" if metadata and "type" in metadata: @@ -96,35 +96,35 @@ def validate_metadata_type( # pylint: disable=too-many-locals,too-many-arguments def load_configs( - contents: Dict[str, str], - schemas: Dict[str, Schema], - passwords: Dict[str, str], - exceptions: List[ValidationError], - ssh_tunnel_passwords: Dict[str, str], - ssh_tunnel_private_keys: Dict[str, str], - ssh_tunnel_priv_key_passwords: Dict[str, str], -) -> Dict[str, Any]: - configs: Dict[str, Any] = {} + contents: dict[str, str], + schemas: dict[str, Schema], + passwords: dict[str, str], + exceptions: list[ValidationError], + ssh_tunnel_passwords: dict[str, str], + ssh_tunnel_private_keys: dict[str, str], + ssh_tunnel_priv_key_passwords: dict[str, str], +) -> dict[str, Any]: + configs: dict[str, Any] = {} # load existing databases so we can apply the password validation - db_passwords: Dict[str, str] = { + db_passwords: dict[str, str] = { str(uuid): password for uuid, password in db.session.query(Database.uuid, Database.password).all() } # load existing ssh_tunnels so we can apply the password validation - db_ssh_tunnel_passwords: Dict[str, str] = { + db_ssh_tunnel_passwords: dict[str, str] = { str(uuid): password for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all() } # load existing ssh_tunnels so we can apply the private_key validation - db_ssh_tunnel_private_keys: Dict[str, str] = { + db_ssh_tunnel_private_keys: dict[str, str] = { str(uuid): private_key for uuid, private_key in db.session.query( SSHTunnel.uuid, SSHTunnel.private_key ).all() } # load existing ssh_tunnels so we can apply the private_key_password validation - db_ssh_tunnel_priv_key_passws: Dict[str, str] = { + db_ssh_tunnel_priv_key_passws: dict[str, str] = { str(uuid): private_key_password for uuid, private_key_password in db.session.query( SSHTunnel.uuid, SSHTunnel.private_key_password @@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool: return True -def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]: +def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]: return { remove_root(file_name): bundle.read(file_name).decode() for file_name in bundle.namelist() diff --git a/superset/commands/utils.py b/superset/commands/utils.py index ad58bb4050..7bb13984f8 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import List, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask import g from flask_appbuilder.security.sqla.models import Role, User @@ -37,9 +37,9 @@ if TYPE_CHECKING: def populate_owners( - owner_ids: Optional[List[int]], + owner_ids: list[int] | None, default_to_user: bool, -) -> List[User]: +) -> list[User]: """ Helper function for commands, will fetch all users from owners id's @@ -63,13 +63,13 @@ def populate_owners( return owners -def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]: +def populate_roles(role_ids: list[int] | None = None) -> list[Role]: """ Helper function for commands, will fetch all roles from roles id's :raises RolesNotFoundValidationError: If a role in the input list is not found :param role_ids: A List of roles by id's """ - roles: List[Role] = [] + roles: list[Role] = [] if role_ids: roles = security_manager.find_roles_by_id(role_ids) if len(roles) != len(role_ids): diff --git a/superset/common/chart_data.py b/superset/common/chart_data.py index 659a640159..65c0c43c11 100644 --- a/superset/common/chart_data.py +++ b/superset/common/chart_data.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. from enum import Enum -from typing import Set class ChartDataResultFormat(str, Enum): @@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum): XLSX = "xlsx" @classmethod - def table_like(cls) -> Set["ChartDataResultFormat"]: + def table_like(cls) -> set["ChartDataResultFormat"]: return {cls.CSV} | {cls.XLSX} diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index f6f5a5cd62..22c778b77b 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -17,7 +17,7 @@ from __future__ import annotations import copy -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import _ @@ -49,7 +49,7 @@ def _get_datasource( def _get_columns( query_context: QueryContext, query_obj: QueryObject, _: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) return { "data": [ @@ -65,7 +65,7 @@ def _get_columns( def _get_timegrains( query_context: QueryContext, query_obj: QueryObject, _: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) return { "data": [ @@ -83,7 +83,7 @@ def _get_query( query_context: QueryContext, query_obj: QueryObject, _: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result = {"language": datasource.query_language} try: @@ -96,8 +96,8 @@ def _get_query( def _get_full( query_context: QueryContext, query_obj: QueryObject, - force_cached: Optional[bool] = False, -) -> Dict[str, Any]: + force_cached: bool | None = False, +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) result_type = query_obj.result_type or query_context.result_type payload = query_context.get_df_payload(query_obj, force_cached=force_cached) @@ -141,7 +141,7 @@ def _get_full( def _get_samples( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = _get_datasource(query_context, query_obj) query_obj = copy.copy(query_obj) query_obj.is_timeseries = False @@ -162,7 +162,7 @@ def _get_samples( def _get_drill_detail( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False -) -> Dict[str, Any]: +) -> dict[str, Any]: # todo(yongjie): Remove this function, # when determining whether samples should be applied to the time filter. datasource = _get_datasource(query_context, query_obj) @@ -183,13 +183,13 @@ def _get_drill_detail( def _get_results( query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False -) -> Dict[str, Any]: +) -> dict[str, Any]: payload = _get_full(query_context, query_obj, force_cached) return payload -_result_type_functions: Dict[ - ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]] +_result_type_functions: dict[ + ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]] ] = { ChartDataResultType.COLUMNS: _get_columns, ChartDataResultType.TIMEGRAINS: _get_timegrains, @@ -210,7 +210,7 @@ def get_query_results( query_context: QueryContext, query_obj: QueryObject, force_cached: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Return result payload for a chart data request. diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 78eb8800c4..1a8d3c518b 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, ClassVar, TYPE_CHECKING import pandas as pd @@ -47,15 +47,15 @@ class QueryContext: enforce_numerical_metrics: ClassVar[bool] = True datasource: BaseDatasource - slice_: Optional[Slice] = None - queries: List[QueryObject] - form_data: Optional[Dict[str, Any]] + slice_: Slice | None = None + queries: list[QueryObject] + form_data: dict[str, Any] | None result_type: ChartDataResultType result_format: ChartDataResultFormat force: bool - custom_cache_timeout: Optional[int] + custom_cache_timeout: int | None - cache_values: Dict[str, Any] + cache_values: dict[str, Any] _processor: QueryContextProcessor @@ -65,14 +65,14 @@ class QueryContext: self, *, datasource: BaseDatasource, - queries: List[QueryObject], - slice_: Optional[Slice], - form_data: Optional[Dict[str, Any]], + queries: list[QueryObject], + slice_: Slice | None, + form_data: dict[str, Any] | None, result_type: ChartDataResultType, result_format: ChartDataResultFormat, force: bool = False, - custom_cache_timeout: Optional[int] = None, - cache_values: Dict[str, Any], + custom_cache_timeout: int | None = None, + cache_values: dict[str, Any], ) -> None: self.datasource = datasource self.slice_ = slice_ @@ -88,18 +88,18 @@ class QueryContext: def get_data( self, df: pd.DataFrame, - ) -> Union[str, List[Dict[str, Any]]]: + ) -> str | list[dict[str, Any]]: return self._processor.get_data(df) def get_payload( self, - cache_query_context: Optional[bool] = False, + cache_query_context: bool | None = False, force_cached: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the query results with both metadata and data""" return self._processor.get_payload(cache_query_context, force_cached) - def get_cache_timeout(self) -> Optional[int]: + def get_cache_timeout(self) -> int | None: if self.custom_cache_timeout is not None: return self.custom_cache_timeout if self.slice_ and self.slice_.cache_timeout is not None: @@ -110,14 +110,14 @@ class QueryContext: return self.datasource.database.cache_timeout return None - def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None: return self._processor.query_cache_key(query_obj, **kwargs) def get_df_payload( self, query_obj: QueryObject, - force_cached: Optional[bool] = False, - ) -> Dict[str, Any]: + force_cached: bool | None = False, + ) -> dict[str, Any]: return self._processor.get_df_payload( query_obj=query_obj, force_cached=force_cached, diff --git a/superset/common/query_context_factory.py b/superset/common/query_context_factory.py index 84c0415722..62018def8d 100644 --- a/superset/common/query_context_factory.py +++ b/superset/common/query_context_factory.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset import app, db from superset.charts.dao import ChartDAO @@ -48,12 +48,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods self, *, datasource: DatasourceDict, - queries: List[Dict[str, Any]], - form_data: Optional[Dict[str, Any]] = None, - result_type: Optional[ChartDataResultType] = None, - result_format: Optional[ChartDataResultFormat] = None, + queries: list[dict[str, Any]], + form_data: dict[str, Any] | None = None, + result_type: ChartDataResultType | None = None, + result_format: ChartDataResultFormat | None = None, force: bool = False, - custom_cache_timeout: Optional[int] = None, + custom_cache_timeout: int | None = None, ) -> QueryContext: datasource_model_instance = None if datasource: @@ -101,13 +101,13 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods datasource_id=int(datasource["id"]), ) - def _get_slice(self, slice_id: Any) -> Optional[Slice]: + def _get_slice(self, slice_id: Any) -> Slice | None: return ChartDAO.find_by_id(slice_id) def _process_query_object( self, datasource: BaseDatasource, - form_data: Optional[Dict[str, Any]], + form_data: dict[str, Any] | None, query_object: QueryObject, ) -> QueryObject: self._apply_granularity(query_object, form_data, datasource) @@ -117,7 +117,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods def _apply_granularity( self, query_object: QueryObject, - form_data: Optional[Dict[str, Any]], + form_data: dict[str, Any] | None, datasource: BaseDatasource, ) -> None: temporal_columns = { diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 85a2b5d97a..ecb8db4246 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -19,7 +19,7 @@ from __future__ import annotations import copy import logging import re -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, ClassVar, TYPE_CHECKING import numpy as np import pandas as pd @@ -77,8 +77,8 @@ logger = logging.getLogger(__name__) class CachedTimeOffset(TypedDict): df: pd.DataFrame - queries: List[str] - cache_keys: List[Optional[str]] + queries: list[str] + cache_keys: list[str | None] class QueryContextProcessor: @@ -102,8 +102,8 @@ class QueryContextProcessor: enforce_numerical_metrics: ClassVar[bool] = True def get_df_payload( - self, query_obj: QueryObject, force_cached: Optional[bool] = False - ) -> Dict[str, Any]: + self, query_obj: QueryObject, force_cached: bool | None = False + ) -> dict[str, Any]: """Handles caching around the df payload retrieval""" cache_key = self.query_cache_key(query_obj) timeout = self.get_cache_timeout() @@ -181,7 +181,7 @@ class QueryContextProcessor: "label_map": label_map, } - def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]: + def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None: """ Returns a QueryObject cache key for objects in self.queries """ @@ -248,8 +248,8 @@ class QueryContextProcessor: def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame: # todo: should support "python_date_format" and "get_column" in each datasource def _get_timestamp_format( - source: BaseDatasource, column: Optional[str] - ) -> Optional[str]: + source: BaseDatasource, column: str | None + ) -> str | None: column_obj = source.get_column(column) if ( column_obj @@ -315,9 +315,9 @@ class QueryContextProcessor: query_context = self._query_context # ensure query_object is immutable query_object_clone = copy.copy(query_object) - queries: List[str] = [] - cache_keys: List[Optional[str]] = [] - rv_dfs: List[pd.DataFrame] = [df] + queries: list[str] = [] + cache_keys: list[str | None] = [] + rv_dfs: list[pd.DataFrame] = [df] time_offsets = query_object.time_offsets outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object) @@ -449,7 +449,7 @@ class QueryContextProcessor: rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys) - def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]: + def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]: if self._query_context.result_format in ChartDataResultFormat.table_like(): include_index = not isinstance(df.index, pd.RangeIndex) columns = list(df.columns) @@ -470,9 +470,9 @@ class QueryContextProcessor: def get_payload( self, - cache_query_context: Optional[bool] = False, + cache_query_context: bool | None = False, force_cached: bool = False, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Returns the query results with both metadata and data""" # Get all the payloads from the QueryObjects @@ -522,13 +522,13 @@ class QueryContextProcessor: return generate_cache_key(cache_dict, key_prefix) - def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]: + def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]: """ :param query_context: :param query_obj: :return: """ - annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj) + annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj) for annotation_layer in [ layer for layer in query_obj.annotation_layers @@ -541,7 +541,7 @@ class QueryContextProcessor: return annotation_data @staticmethod - def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]: + def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]: annotation_data = {} annotation_layers = [ layer @@ -576,8 +576,8 @@ class QueryContextProcessor: @staticmethod def get_viz_annotation_data( - annotation_layer: Dict[str, Any], force: bool - ) -> Dict[str, Any]: + annotation_layer: dict[str, Any], force: bool + ) -> dict[str, Any]: chart = ChartDAO.find_by_id(annotation_layer["value"]) if not chart: raise QueryObjectValidationError(_("The chart does not exist")) diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 802a1eed5b..dc02b774e5 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -21,7 +21,7 @@ import json import logging from datetime import datetime from pprint import pformat -from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING +from typing import Any, NamedTuple, TYPE_CHECKING from flask import g from flask_babel import gettext as _ @@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes and druid. The query objects are constructed on the client. """ - annotation_layers: List[Dict[str, Any]] - applied_time_extras: Dict[str, str] + annotation_layers: list[dict[str, Any]] + applied_time_extras: dict[str, str] apply_fetch_values_predicate: bool - columns: List[Column] - datasource: Optional[BaseDatasource] - extras: Dict[str, Any] - filter: List[QueryObjectFilterClause] - from_dttm: Optional[datetime] - granularity: Optional[str] - inner_from_dttm: Optional[datetime] - inner_to_dttm: Optional[datetime] + columns: list[Column] + datasource: BaseDatasource | None + extras: dict[str, Any] + filter: list[QueryObjectFilterClause] + from_dttm: datetime | None + granularity: str | None + inner_from_dttm: datetime | None + inner_to_dttm: datetime | None is_rowcount: bool is_timeseries: bool - metrics: Optional[List[Metric]] + metrics: list[Metric] | None order_desc: bool - orderby: List[OrderBy] - post_processing: List[Dict[str, Any]] - result_type: Optional[ChartDataResultType] - row_limit: Optional[int] + orderby: list[OrderBy] + post_processing: list[dict[str, Any]] + result_type: ChartDataResultType | None + row_limit: int | None row_offset: int - series_columns: List[Column] + series_columns: list[Column] series_limit: int - series_limit_metric: Optional[Metric] - time_offsets: List[str] - time_shift: Optional[str] - time_range: Optional[str] - to_dttm: Optional[datetime] + series_limit_metric: Metric | None + time_offsets: list[str] + time_shift: str | None + time_range: str | None + to_dttm: datetime | None def __init__( # pylint: disable=too-many-locals self, *, - annotation_layers: Optional[List[Dict[str, Any]]] = None, - applied_time_extras: Optional[Dict[str, str]] = None, + annotation_layers: list[dict[str, Any]] | None = None, + applied_time_extras: dict[str, str] | None = None, apply_fetch_values_predicate: bool = False, - columns: Optional[List[Column]] = None, - datasource: Optional[BaseDatasource] = None, - extras: Optional[Dict[str, Any]] = None, - filters: Optional[List[QueryObjectFilterClause]] = None, - granularity: Optional[str] = None, + columns: list[Column] | None = None, + datasource: BaseDatasource | None = None, + extras: dict[str, Any] | None = None, + filters: list[QueryObjectFilterClause] | None = None, + granularity: str | None = None, is_rowcount: bool = False, - is_timeseries: Optional[bool] = None, - metrics: Optional[List[Metric]] = None, + is_timeseries: bool | None = None, + metrics: list[Metric] | None = None, order_desc: bool = True, - orderby: Optional[List[OrderBy]] = None, - post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, - row_limit: Optional[int], - row_offset: Optional[int] = None, - series_columns: Optional[List[Column]] = None, + orderby: list[OrderBy] | None = None, + post_processing: list[dict[str, Any] | None] | None = None, + row_limit: int | None, + row_offset: int | None = None, + series_columns: list[Column] | None = None, series_limit: int = 0, - series_limit_metric: Optional[Metric] = None, - time_range: Optional[str] = None, - time_shift: Optional[str] = None, + series_limit_metric: Metric | None = None, + time_range: str | None = None, + time_shift: str | None = None, **kwargs: Any, ): self._set_annotation_layers(annotation_layers) @@ -166,7 +166,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes self._move_deprecated_extra_fields(kwargs) def _set_annotation_layers( - self, annotation_layers: Optional[List[Dict[str, Any]]] + self, annotation_layers: list[dict[str, Any]] | None ) -> None: self.annotation_layers = [ layer @@ -175,14 +175,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes if layer["annotationType"] != "FORMULA" ] - def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None: + def _set_is_timeseries(self, is_timeseries: bool | None) -> None: # is_timeseries is True if time column is in either columns or groupby # (both are dimensions) self.is_timeseries = ( is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns ) - def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None: + def _set_metrics(self, metrics: list[Metric] | None = None) -> None: # Support metric reference/definition in the format of # 1. 'metric_name' - name of predefined metric # 2. { label: 'label_name' } - legacy format for a predefined metric @@ -195,16 +195,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes ] def _set_post_processing( - self, post_processing: Optional[List[Optional[Dict[str, Any]]]] + self, post_processing: list[dict[str, Any] | None] | None ) -> None: post_processing = post_processing or [] self.post_processing = [post_proc for post_proc in post_processing if post_proc] def _init_series_columns( self, - series_columns: Optional[List[Column]], - metrics: Optional[List[Metric]], - is_timeseries: Optional[bool], + series_columns: list[Column] | None, + metrics: list[Metric] | None, + is_timeseries: bool | None, ) -> None: if series_columns: self.series_columns = series_columns @@ -213,7 +213,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes else: self.series_columns = [] - def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None: + def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None: # rename deprecated fields for field in DEPRECATED_FIELDS: if field.old_name in kwargs: @@ -233,7 +233,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes ) setattr(self, field.new_name, value) - def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None: + def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None: # move deprecated extras fields to extras for field in DEPRECATED_EXTRAS_FIELDS: if field.old_name in kwargs: @@ -256,19 +256,19 @@ class QueryObject: # pylint: disable=too-many-instance-attributes self.extras[field.new_name] = value @property - def metric_names(self) -> List[str]: + def metric_names(self) -> list[str]: """Return metrics names (labels), coerce adhoc metrics to strings.""" return get_metric_names(self.metrics or []) @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: """Return column names (labels). Gives priority to groupbys if both groupbys and metrics are non-empty, otherwise returns column labels.""" return get_column_names(self.columns) def validate( - self, raise_exceptions: Optional[bool] = True - ) -> Optional[QueryObjectValidationError]: + self, raise_exceptions: bool | None = True + ) -> QueryObjectValidationError | None: """Validate query object""" try: self._validate_there_are_no_missing_series() @@ -314,7 +314,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes ) ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: query_object_dict = { "apply_fetch_values_predicate": self.apply_fetch_values_predicate, "columns": self.columns, diff --git a/superset/common/query_object_factory.py b/superset/common/query_object_factory.py index 88cc7ca1b4..5676dc9eda 100644 --- a/superset/common/query_object_factory.py +++ b/superset/common/query_object_factory.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset.common.chart_data import ChartDataResultType from superset.common.query_object import QueryObject @@ -31,13 +31,13 @@ if TYPE_CHECKING: class QueryObjectFactory: # pylint: disable=too-few-public-methods - _config: Dict[str, Any] + _config: dict[str, Any] _datasource_dao: DatasourceDAO _session_maker: sessionmaker def __init__( self, - app_configurations: Dict[str, Any], + app_configurations: dict[str, Any], _datasource_dao: DatasourceDAO, session_maker: sessionmaker, ): @@ -48,11 +48,11 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods def create( # pylint: disable=too-many-arguments self, parent_result_type: ChartDataResultType, - datasource: Optional[DatasourceDict] = None, - extras: Optional[Dict[str, Any]] = None, - row_limit: Optional[int] = None, - time_range: Optional[str] = None, - time_shift: Optional[str] = None, + datasource: DatasourceDict | None = None, + extras: dict[str, Any] | None = None, + row_limit: int | None = None, + time_range: str | None = None, + time_shift: str | None = None, **kwargs: Any, ) -> QueryObject: datasource_model_instance = None @@ -84,13 +84,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods def _process_extras( # pylint: disable=no-self-use self, - extras: Optional[Dict[str, Any]], - ) -> Dict[str, Any]: + extras: dict[str, Any] | None, + ) -> dict[str, Any]: extras = extras or {} return extras def _process_row_limit( - self, row_limit: Optional[int], result_type: ChartDataResultType + self, row_limit: int | None, result_type: ChartDataResultType ) -> int: default_row_limit = ( self._config["SAMPLES_ROW_LIMIT"] diff --git a/superset/common/tags.py b/superset/common/tags.py index 706192913a..6066d0eec7 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List +from typing import Any from sqlalchemy import MetaData from sqlalchemy.exc import IntegrityError @@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes def add_types_to_charts( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: slices = metadata.tables["slices"] @@ -57,7 +57,7 @@ def add_types_to_charts( def add_types_to_dashboards( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -89,7 +89,7 @@ def add_types_to_dashboards( def add_types_to_saved_queries( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -121,7 +121,7 @@ def add_types_to_saved_queries( def add_types_to_datasets( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: tables = metadata.tables["tables"] @@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None: def add_owners_to_charts( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: slices = metadata.tables["slices"] @@ -273,7 +273,7 @@ def add_owners_to_charts( def add_owners_to_dashboards( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -309,7 +309,7 @@ def add_owners_to_dashboards( def add_owners_to_saved_queries( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -345,7 +345,7 @@ def add_owners_to_saved_queries( def add_owners_to_datasets( - metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str] ) -> None: tables = metadata.tables["tables"] diff --git a/superset/common/utils/dataframe_utils.py b/superset/common/utils/dataframe_utils.py index 4dd62e3b5d..a3421f6431 100644 --- a/superset/common/utils/dataframe_utils.py +++ b/superset/common/utils/dataframe_utils.py @@ -17,7 +17,7 @@ from __future__ import annotations import datetime -from typing import Any, List, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import numpy as np import pandas as pd @@ -29,7 +29,7 @@ if TYPE_CHECKING: def left_join_df( left_df: pd.DataFrame, right_df: pd.DataFrame, - join_keys: List[str], + join_keys: list[str], ) -> pd.DataFrame: df = left_df.set_index(join_keys).join(right_df.set_index(join_keys)) df.reset_index(inplace=True) diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index 6c1b268f46..a0fb65b20d 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List, Optional +from typing import Any from flask_caching import Cache from pandas import DataFrame @@ -37,7 +37,7 @@ config = app.config stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) -_cache: Dict[CacheRegion, Cache] = { +_cache: dict[CacheRegion, Cache] = { CacheRegion.DEFAULT: cache_manager.cache, CacheRegion.DATA: cache_manager.data_cache, } @@ -53,17 +53,17 @@ class QueryCacheManager: self, df: DataFrame = DataFrame(), query: str = "", - annotation_data: Optional[Dict[str, Any]] = None, - applied_template_filters: Optional[List[str]] = None, - applied_filter_columns: Optional[List[Column]] = None, - rejected_filter_columns: Optional[List[Column]] = None, - status: Optional[str] = None, - error_message: Optional[str] = None, + annotation_data: dict[str, Any] | None = None, + applied_template_filters: list[str] | None = None, + applied_filter_columns: list[Column] | None = None, + rejected_filter_columns: list[Column] | None = None, + status: str | None = None, + error_message: str | None = None, is_loaded: bool = False, - stacktrace: Optional[str] = None, - is_cached: Optional[bool] = None, - cache_dttm: Optional[str] = None, - cache_value: Optional[Dict[str, Any]] = None, + stacktrace: str | None = None, + is_cached: bool | None = None, + cache_dttm: str | None = None, + cache_value: dict[str, Any] | None = None, ) -> None: self.df = df self.query = query @@ -85,10 +85,10 @@ class QueryCacheManager: self, key: str, query_result: QueryResult, - annotation_data: Optional[Dict[str, Any]] = None, - force_query: Optional[bool] = False, - timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + annotation_data: dict[str, Any] | None = None, + force_query: bool | None = False, + timeout: int | None = None, + datasource_uid: str | None = None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: """ @@ -136,11 +136,11 @@ class QueryCacheManager: @classmethod def get( cls, - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, - force_query: Optional[bool] = False, - force_cached: Optional[bool] = False, - ) -> "QueryCacheManager": + force_query: bool | None = False, + force_cached: bool | None = False, + ) -> QueryCacheManager: """ Initialize QueryCacheManager by query-cache key """ @@ -190,10 +190,10 @@ class QueryCacheManager: @staticmethod def set( - key: Optional[str], - value: Dict[str, Any], - timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + key: str | None, + value: dict[str, Any], + timeout: int | None = None, + datasource_uid: str | None = None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: """ @@ -204,7 +204,7 @@ class QueryCacheManager: @staticmethod def delete( - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, ) -> None: if key: @@ -212,7 +212,7 @@ class QueryCacheManager: @staticmethod def has( - key: Optional[str], + key: str | None, region: CacheRegion = CacheRegion.DEFAULT, ) -> bool: return bool(_cache[region].get(key)) if key else False diff --git a/superset/common/utils/time_range_utils.py b/superset/common/utils/time_range_utils.py index fa6a5244b2..5f9139c047 100644 --- a/superset/common/utils/time_range_utils.py +++ b/superset/common/utils/time_range_utils.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, cast, Dict, Optional, Tuple +from typing import Any, cast from superset import app from superset.common.query_object import QueryObject @@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until def get_since_until_from_time_range( - time_range: Optional[str] = None, - time_shift: Optional[str] = None, - extras: Optional[Dict[str, Any]] = None, -) -> Tuple[Optional[datetime], Optional[datetime]]: + time_range: str | None = None, + time_shift: str | None = None, + extras: dict[str, Any] | None = None, +) -> tuple[datetime | None, datetime | None]: return get_since_until( relative_start=(extras or {}).get( "relative_start", app.config["DEFAULT_RELATIVE_START_TIME"] @@ -45,7 +45,7 @@ def get_since_until_from_time_range( # pylint: disable=invalid-name def get_since_until_from_query_object( query_object: QueryObject, -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[datetime | None, datetime | None]: """ this function will return since and until by tuple if 1) the time_range is in the query object. diff --git a/superset/config.py b/superset/config.py index 7d9359d14f..434456386d 100644 --- a/superset/config.py +++ b/superset/config.py @@ -33,20 +33,7 @@ import sys from collections import OrderedDict from datetime import timedelta from email.mime.multipart import MIMEMultipart -from typing import ( - Any, - Callable, - Dict, - List, - Literal, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - TypedDict, - Union, -) +from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict import pkg_resources from cachelib.base import BaseCache @@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename( FAVICONS = [{"href": "/static/assets/images/favicon.png"}] -def _try_json_readversion(filepath: str) -> Optional[str]: +def _try_json_readversion(filepath: str) -> str | None: try: - with open(filepath, "r") as f: + with open(filepath) as f: return json.load(f).get("version") except Exception: # pylint: disable=broad-except return None -def _try_json_readsha(filepath: str, length: int) -> Optional[str]: +def _try_json_readsha(filepath: str, length: int) -> str | None: try: - with open(filepath, "r") as f: + with open(filepath) as f: return json.load(f).get("GIT_SHA")[:length] except Exception: # pylint: disable=broad-except return None @@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1} # Configuration for scheduling queries from SQL Lab. -SCHEDULED_QUERIES: Dict[str, Any] = {} +SCHEDULED_QUERIES: dict[str, Any] = {} # ------------------------------ # GLOBALS FOR APP Builder @@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None LOGO_TOOLTIP = "" # Specify any text that should appear to the right of the logo -LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = "" +LOGO_RIGHT_TEXT: Callable[[], str] | str = "" # Enables SWAGGER UI for superset openapi spec # ex: http://localhost:8080/swagger/v1 @@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB # Grant public role the same set of permissions as for a selected builtin role. # This is useful if one wants to enable anonymous users to view # dashboards. Explicit grant on specific datasets is still required. -PUBLIC_ROLE_LIKE: Optional[str] = None +PUBLIC_ROLE_LIKE: str | None = None # --------------------------------------------------- # Babel config for translations @@ -390,8 +377,8 @@ LANGUAGES = {} class D3Format(TypedDict, total=False): decimal: str thousands: str - grouping: List[int] - currency: List[str] + grouping: list[int] + currency: list[str] D3_FORMAT: D3Format = {} @@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {} # For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here # and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py # will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True } -DEFAULT_FEATURE_FLAGS: Dict[str, bool] = { +DEFAULT_FEATURE_FLAGS: dict[str, bool] = { # Experimental feature introducing a client (browser) cache "CLIENT_CACHE": False, # deprecated "DISABLE_DATASET_SOURCE_EDIT": False, # deprecated @@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update( ) # This is merely a default. -FEATURE_FLAGS: Dict[str, bool] = {} +FEATURE_FLAGS: dict[str, bool] = {} # A function that receives a dict of all feature flags # (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS) @@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {} # if hasattr(g, "user") and g.user.is_active: # feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5 # return feature_flags_dict -GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None +GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None # A function that receives a feature flag name and an optional default value. # Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the # evaluation of all feature flags when just evaluating a single one. @@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = # Note that the default `get_feature_flags` will evaluate each feature with this # callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC # and IS_FEATURE_ENABLED_FUNC in conjunction. -IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None +IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None # A function that expands/overrides the frontend `bootstrap_data.common` object. # Can be used to implement custom frontend functionality, # or dynamically change certain configs. @@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None # Takes as a parameter the common bootstrap payload before transformations. # Returns a dict containing data that should be added or overridden to the payload. COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ - [Dict[str, Any]], Dict[str, Any] + [dict[str, Any]], dict[str, Any] ] = lambda data: {} # default: empty dict # EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes @@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[ # }] # This is merely a default -EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] +EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = [] # THEME_OVERRIDES is used for adding custom theme to superset # example code for "My theme" custom scheme @@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] # } # } -THEME_OVERRIDES: Dict[str, Any] = {} +THEME_OVERRIDES: dict[str, Any] = {} # EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes # EXTRA_SEQUENTIAL_COLOR_SCHEMES = [ @@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {} # }] # This is merely a default -EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] +EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = [] # --------------------------------------------------- # Thumbnail config (behind feature flag) @@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = [] # `superset.tasks.types.ExecutorType` for a full list of executor options. # To always use a fixed user account, use the following configuration: # THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM] -THUMBNAIL_SELENIUM_USER: Optional[str] = "admin" +THUMBNAIL_SELENIUM_USER: str | None = "admin" THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM] # By default, thumbnail digests are calculated based on various parameters in the @@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM] # `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in # user if the executor type is equal to `ExecutorType.CURRENT_USER`) # and return the final digest string: -THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[ +THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | ( Callable[[Dashboard, ExecutorType, str], str] -] = None -THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None +) = None +THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None THUMBNAIL_CACHE_CONFIG: CacheConfig = { "CACHE_TYPE": "NullCache", @@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False # CORS Options ENABLE_CORS = False -CORS_OPTIONS: Dict[Any, Any] = {} +CORS_OPTIONS: dict[Any, Any] = {} # Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner. # Disabling this option is not recommended for security reasons. If you wish to allow @@ -736,7 +723,7 @@ HTML_SANITIZATION = True # } # } # Be careful when extending the default schema to avoid XSS attacks. -HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {} +HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {} # Chrome allows up to 6 open connections per domain at a time. When there are more # than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for @@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"} # time grains in superset/db_engine_specs/base.py). # For example: to disable 1 second time grain: # TIME_GRAIN_DENYLIST = ['PT1S'] -TIME_GRAIN_DENYLIST: List[str] = [] +TIME_GRAIN_DENYLIST: list[str] = [] # Additional time grains to be supported using similar definitions as in # superset/db_engine_specs/base.py. # For example: To add a new 2 second time grain: # TIME_GRAIN_ADDONS = {'PT2S': '2 second'} -TIME_GRAIN_ADDONS: Dict[str, str] = {} +TIME_GRAIN_ADDONS: dict[str, str] = {} # Implementation of additional time grains per engine. # The column to be truncated is denoted `{col}` in the expression. @@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {} # 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)' # } # } -TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {} +TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {} # --------------------------------------------------- # List of viz_types not allowed in your environment @@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {} # VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap'] # --------------------------------------------------- -VIZ_TYPE_DENYLIST: List[str] = [] +VIZ_TYPE_DENYLIST: list[str] = [] # -------------------------------------------------- # Modules, datasources and middleware to be registered @@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict( ("superset.connectors.sqla.models", ["SqlaTable"]), ] ) -ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {} -ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = [] +ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {} +ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = [] # 1) https://docs.python-guide.org/writing/logging/ # 2) https://docs.python.org/2/library/logging.config.html @@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig # pylint: disable=invalid-name # within the app # OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will # override anything set within the app -DEFAULT_HTTP_HEADERS: Dict[str, Any] = {} -OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {} -HTTP_HEADERS: Dict[str, Any] = {} +DEFAULT_HTTP_HEADERS: dict[str, Any] = {} +OVERRIDE_HTTP_HEADERS: dict[str, Any] = {} +HTTP_HEADERS: dict[str, Any] = {} # The db id here results in selecting this one as a default in SQL Lab DEFAULT_DB_ID = None @@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds()) # return out # # QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter} -QUERY_COST_FORMATTERS_BY_ENGINE: Dict[ - str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]] +QUERY_COST_FORMATTERS_BY_ENGINE: dict[ + str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]] ] = {} # Flag that controls if limit should be enforced on the CTA (create table as queries). @@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False # else: # return f'tmp_{schema}' # Function accepts database object, user object, schema name and sql that will be run. -SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[ +SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | ( Callable[[Database, models.User, str, str], str] -] = None +) = None # If enabled, it can be used to store the results of long-running queries # in SQL Lab by using the "Run Async" button/feature -RESULTS_BACKEND: Optional[BaseCache] = None +RESULTS_BACKEND: BaseCache | None = None # Use PyArrow and MessagePack for async query results serialization, # rather than JSON. This feature requires additional testing from the @@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/" def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name database: Database, user: models.User, # pylint: disable=unused-argument - schema: Optional[str], + schema: str | None, ) -> str: # Note the final empty path enforces a trailing slash. return os.path.join( @@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # The namespace within hive where the tables created from # uploading CSVs will be stored. -UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None +UPLOADED_CSV_HIVE_NAMESPACE: str | None = None # Function that computes the allowed schemas for the CSV uploads. # Allowed schemas will be a union of schemas_allowed_for_file_upload # db configuration and a result of this function. # mypy doesn't catch that if case ensures list content being always str -ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = ( +ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = ( lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE] if UPLOADED_CSV_HIVE_NAMESPACE else [] @@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES) # It's important to make sure that the objects exposed (as well as objects attached # to those objets) are harmless. We recommend only exposing simple/pure functions that # return native types. -JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {} +JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {} # A dictionary of macro template processors (by engine) that gets merged into global # template processors. The existing template processors get updated with this @@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {} # dictionary. The customized addons don't necessarily need to use Jinja templating # language. This allows you to define custom logic to process templates on a per-engine # basis. Example value = `{"presto": CustomPrestoTemplateProcessor}` -CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {} +CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {} # Roles that are controlled by the API / Superset and should not be changes # by humans. @@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = "" # Integrate external Blueprints to the app by passing them to your # configuration. These blueprints will get integrated in the app -BLUEPRINTS: List[Blueprint] = [] +BLUEPRINTS: list[Blueprint] = [] # Provide a callable that receives a tracking_url and returns another # URL. This is used to translate internal Hadoop job tracker URL @@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url # customize the polling time of each engine -DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {} +DB_POLL_INTERVAL_SECONDS: dict[str, int] = {} # Interval between consecutive polls when using Presto Engine # See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression @@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds()) # "another_auth_method": auth_method, # }, # } -ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {} +ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {} # The id of a template dashboard that should be copied to every new user DASHBOARD_TEMPLATE_ID = None @@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument # Owners, filters for created_by, etc. # The users can also be excluded by overriding the get_exclude_users_from_lists method # in security manager -EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None +EXCLUDE_USERS_FROM_LISTS: list[str] | None = None # For database connections, this dictionary will remove engines from the available # list/dropdown if you do not want these dbs to show as available. # The available list is generated by driver installed, and some engines have multiple # drivers. # e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}} -DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {} +DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {} # This auth provider is used by background (offline) tasks that need to access # protected resources. Can be overridden by end users in order to support @@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True # ExecutorType.OWNER, # ExecutorType.SELENIUM, # ] -ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER] +ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER] # if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout # Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds()) @@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] " EMAIL_REPORTS_CTA = "Explore in Superset" # Slack API token for the superset reports, either string or callable -SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None +SLACK_API_TOKEN: Callable[[], str] | str | None = None SLACK_PROXY = None # The webdriver to use for generating reports. Use one of the following @@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = { WEBDRIVER_AUTH_FUNC = None # Any config options to be passed as-is to the webdriver -WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"} +WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"} # Additional args to be passed as arguments to the config object # Note: If using Chrome, you'll want to add the "--marionette" arg. @@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = { # displayed prominently in the "Add Database" dialog. You should # use the "engine_name" attribute of the corresponding DB engine spec # in `superset/db_engine_specs/`. -PREFERRED_DATABASES: List[str] = [ +PREFERRED_DATABASES: list[str] = [ "PostgreSQL", "Presto", "MySQL", @@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = { # SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS? SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls? -SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax" +SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax" # Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection SESSION_PROTECTION = "strong" @@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"] # Path used to store SSL certificates that are generated when using custom certs. # Defaults to temporary directory. # Example: SSL_CERT_PATH = "/certs" -SSL_CERT_PATH: Optional[str] = None +SSL_CERT_PATH: str | None = None # SQLA table mutator, every time we fetch the metadata for a certain table # (superset.connectors.sqla.models.SqlaTable), we call this hook @@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000 GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000 GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token" GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False -GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[ +GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | ( Literal["None", "Lax", "Strict"] -] = None +) = None GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me" GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling" @@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256" GUEST_TOKEN_HEADER_NAME = "X-GuestToken" GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes # Guest token audience for the embedded superset, either string or callable -GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None +GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None # A SQL dataset health check. Note if enabled it is strongly advised that the callable # be memoized to aid with performance, i.e., @@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None # cache_manager.cache.delete_memoized(func) # cache_manager.cache.set(name, code, timeout=0) # -DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None +DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None # Do not show user info or profile in the menu MENU_HIDE_USER_INFO = False @@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False ENABLE_BROAD_ACTIVITY_ACCESS = True # the advanced data type key should correspond to that set in the column metadata -ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = { +ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = { "internet_address": internet_address, "port": internet_port, } @@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = { # "Xyz", # [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}], # ) -WELCOME_PAGE_LAST_TAB: Union[ - Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]] -] = "all" +WELCOME_PAGE_LAST_TAB: ( + Literal["examples", "all"] | tuple[str, list[dict[str, Any]]] +) = "all" # Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag. # 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base) diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index 2cb0d54c51..d43d078639 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -16,21 +16,12 @@ # under the License. from __future__ import annotations +import builtins import json +from collections.abc import Hashable from datetime import datetime from enum import Enum -from typing import ( - Any, - Dict, - Hashable, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, TYPE_CHECKING from flask_appbuilder.security.sqla.models import User from flask_babel import gettext as __ @@ -89,23 +80,23 @@ class BaseDatasource( # --------------------------------------------------------------- # class attributes to define when deriving BaseDatasource # --------------------------------------------------------------- - __tablename__: Optional[str] = None # {connector_name}_datasource - baselink: Optional[str] = None # url portion pointing to ModelView endpoint + __tablename__: str | None = None # {connector_name}_datasource + baselink: str | None = None # url portion pointing to ModelView endpoint @property - def column_class(self) -> Type["BaseColumn"]: + def column_class(self) -> type[BaseColumn]: # link to derivative of BaseColumn raise NotImplementedError() @property - def metric_class(self) -> Type["BaseMetric"]: + def metric_class(self) -> type[BaseMetric]: # link to derivative of BaseMetric raise NotImplementedError() - owner_class: Optional[User] = None + owner_class: User | None = None # Used to do code highlighting when displaying the query in the UI - query_language: Optional[str] = None + query_language: str | None = None # Only some datasources support Row Level Security is_rls_supported: bool = False @@ -131,9 +122,9 @@ class BaseDatasource( is_managed_externally = Column(Boolean, nullable=False, default=False) external_url = Column(Text, nullable=True) - sql: Optional[str] = None - owners: List[User] - update_from_object_fields: List[str] + sql: str | None = None + owners: list[User] + update_from_object_fields: list[str] extra_import_fields = ["is_managed_externally", "external_url"] @@ -142,7 +133,7 @@ class BaseDatasource( return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL @property - def owners_data(self) -> List[Dict[str, Any]]: + def owners_data(self) -> list[dict[str, Any]]: return [ { "first_name": o.first_name, @@ -167,8 +158,8 @@ class BaseDatasource( ), ) - columns: List["BaseColumn"] = [] - metrics: List["BaseMetric"] = [] + columns: list[BaseColumn] = [] + metrics: list[BaseMetric] = [] @property def type(self) -> str: @@ -180,11 +171,11 @@ class BaseDatasource( return f"{self.id}__{self.type}" @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: return sorted([c.column_name for c in self.columns], key=lambda x: x or "") @property - def columns_types(self) -> Dict[str, str]: + def columns_types(self) -> dict[str, str]: return {c.column_name: c.type for c in self.columns} @property @@ -196,26 +187,26 @@ class BaseDatasource( raise NotImplementedError() @property - def connection(self) -> Optional[str]: + def connection(self) -> str | None: """String representing the context of the Datasource""" return None @property - def schema(self) -> Optional[str]: + def schema(self) -> str | None: """String representing the schema of the Datasource (if it applies)""" return None @property - def filterable_column_names(self) -> List[str]: + def filterable_column_names(self) -> list[str]: return sorted([c.column_name for c in self.columns if c.filterable]) @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: return [] @property def url(self) -> str: - return "/{}/edit/{}".format(self.baselink, self.id) + return f"/{self.baselink}/edit/{self.id}" @property def explore_url(self) -> str: @@ -224,10 +215,10 @@ class BaseDatasource( return f"/explore/?datasource_type={self.type}&datasource_id={self.id}" @property - def column_formats(self) -> Dict[str, Optional[str]]: + def column_formats(self) -> dict[str, str | None]: return {m.metric_name: m.d3format for m in self.metrics if m.d3format} - def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None: + def add_missing_metrics(self, metrics: list[BaseMetric]) -> None: existing_metrics = {m.metric_name for m in self.metrics} for metric in metrics: if metric.metric_name not in existing_metrics: @@ -235,7 +226,7 @@ class BaseDatasource( self.metrics.append(metric) @property - def short_data(self) -> Dict[str, Any]: + def short_data(self) -> dict[str, Any]: """Data representation of the datasource sent to the frontend""" return { "edit_url": self.url, @@ -249,11 +240,11 @@ class BaseDatasource( } @property - def select_star(self) -> Optional[str]: + def select_star(self) -> str | None: pass @property - def order_by_choices(self) -> List[Tuple[str, str]]: + def order_by_choices(self) -> list[tuple[str, str]]: choices = [] # self.column_names return sorted column_names for column_name in self.column_names: @@ -267,7 +258,7 @@ class BaseDatasource( return choices @property - def verbose_map(self) -> Dict[str, str]: + def verbose_map(self) -> dict[str, str]: verb_map = {"__timestamp": "Time"} verb_map.update( {o.metric_name: o.verbose_name or o.metric_name for o in self.metrics} @@ -278,7 +269,7 @@ class BaseDatasource( return verb_map @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """Data representation of the datasource sent to the frontend""" return { # simple fields @@ -313,8 +304,8 @@ class BaseDatasource( } def data_for_slices( # pylint: disable=too-many-locals - self, slices: List[Slice] - ) -> Dict[str, Any]: + self, slices: list[Slice] + ) -> dict[str, Any]: """ The representation of the datasource containing only the required data to render the provided slices. @@ -381,8 +372,8 @@ class BaseDatasource( if metric["metric_name"] in metric_names ] - filtered_columns: List[Column] = [] - column_types: Set[GenericDataType] = set() + filtered_columns: list[Column] = [] + column_types: set[GenericDataType] = set() for column in data["columns"]: generic_type = column.get("type_generic") if generic_type is not None: @@ -413,18 +404,18 @@ class BaseDatasource( @staticmethod def filter_values_handler( # pylint: disable=too-many-arguments - values: Optional[FilterValues], + values: FilterValues | None, operator: str, target_generic_type: GenericDataType, - target_native_type: Optional[str] = None, + target_native_type: str | None = None, is_list_target: bool = False, - db_engine_spec: Optional[Type[BaseEngineSpec]] = None, - db_extra: Optional[Dict[str, Any]] = None, - ) -> Optional[FilterValues]: + db_engine_spec: builtins.type[BaseEngineSpec] | None = None, + db_extra: dict[str, Any] | None = None, + ) -> FilterValues | None: if values is None: return None - def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: + def handle_single_value(value: FilterValue | None) -> FilterValue | None: if operator == utils.FilterOperator.TEMPORAL_RANGE: return value if ( @@ -464,7 +455,7 @@ class BaseDatasource( values = values[0] if values else None return values - def external_metadata(self) -> List[Dict[str, str]]: + def external_metadata(self) -> list[dict[str, str]]: """Returns column information from the external system""" raise NotImplementedError() @@ -483,7 +474,7 @@ class BaseDatasource( """ raise NotImplementedError() - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Given a column, returns an iterable of distinct values This is used to populate the dropdown showing a list of @@ -494,7 +485,7 @@ class BaseDatasource( def default_query(qry: Query) -> Query: return qry - def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]: + def get_column(self, column_name: str | None) -> BaseColumn | None: if not column_name: return None for col in self.columns: @@ -504,11 +495,11 @@ class BaseDatasource( @staticmethod def get_fk_many_from_list( - object_list: List[Any], - fkmany: List[Column], - fkmany_class: Type[Union["BaseColumn", "BaseMetric"]], + object_list: list[Any], + fkmany: list[Column], + fkmany_class: builtins.type[BaseColumn | BaseMetric], key_attr: str, - ) -> List[Column]: + ) -> list[Column]: """Update ORM one-to-many list from object list Used for syncing metrics and columns using the same code""" @@ -541,7 +532,7 @@ class BaseDatasource( fkmany += new_fks return fkmany - def update_from_object(self, obj: Dict[str, Any]) -> None: + def update_from_object(self, obj: dict[str, Any]) -> None: """Update datasource from a data structure The UI's table editor crafts a complex data structure that @@ -578,7 +569,7 @@ class BaseDatasource( def get_extra_cache_keys( # pylint: disable=no-self-use self, query_obj: QueryObjectDict # pylint: disable=unused-argument - ) -> List[Hashable]: + ) -> list[Hashable]: """If a datasource needs to provide additional keys for calculation of cache keys, those can be provided via this method @@ -607,14 +598,14 @@ class BaseDatasource( @classmethod def get_datasource_by_name( cls, session: Session, datasource_name: str, schema: str, database_name: str - ) -> Optional["BaseDatasource"]: + ) -> BaseDatasource | None: raise NotImplementedError() class BaseColumn(AuditMixinNullable, ImportExportMixin): """Interface for column""" - __tablename__: Optional[str] = None # {connector_name}_column + __tablename__: str | None = None # {connector_name}_column id = Column(Integer, primary_key=True) column_name = Column(String(255), nullable=False) @@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): is_dttm = None # [optional] Set this to support import/export functionality - export_fields: List[Any] = [] + export_fields: list[Any] = [] def __repr__(self) -> str: return str(self.column_name) @@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types)) @property - def type_generic(self) -> Optional[utils.GenericDataType]: + def type_generic(self) -> utils.GenericDataType | None: if self.is_string: return utils.GenericDataType.STRING if self.is_boolean: @@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): raise NotImplementedError() @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "column_name", @@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin): class BaseMetric(AuditMixinNullable, ImportExportMixin): """Interface for Metrics""" - __tablename__: Optional[str] = None # {connector_name}_metric + __tablename__: str | None = None # {connector_name}_metric id = Column(Integer, primary_key=True) metric_name = Column(String(255), nullable=False) @@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin): """ @property - def perm(self) -> Optional[str]: + def perm(self) -> str | None: raise NotImplementedError() @property @@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin): raise NotImplementedError() @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "metric_name", diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8833d6f6cb..41a9c89757 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -22,21 +22,10 @@ import json import logging import re from collections import defaultdict +from collections.abc import Hashable from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import ( - Any, - Callable, - cast, - Dict, - Hashable, - List, - Optional, - Set, - Tuple, - Type, - Union, -) +from typing import Any, Callable, cast import dateutil.parser import numpy as np @@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES} @dataclass class MetadataResult: - added: List[str] = field(default_factory=list) - removed: List[str] = field(default_factory=list) - modified: List[str] = field(default_factory=list) + added: list[str] = field(default_factory=list) + removed: list[str] = field(default_factory=list) + modified: list[str] = field(default_factory=list) class AnnotationDatasource(BaseDatasource): @@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource): def get_query_str(self, query_obj: QueryObjectDict) -> str: raise NotImplementedError() - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: raise NotImplementedError() @@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin): __tablename__ = "table_columns" __table_args__ = (UniqueConstraint("table_id", "column_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table: Mapped["SqlaTable"] = relationship( + table: Mapped[SqlaTable] = relationship( "SqlaTable", back_populates="columns", ) @@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin): return self.type_generic == GenericDataType.TEMPORAL @property - def db_engine_spec(self) -> Type[BaseEngineSpec]: + def db_engine_spec(self) -> type[BaseEngineSpec]: return self.table.db_engine_spec @property - def db_extra(self) -> Dict[str, Any]: + def db_extra(self) -> dict[str, Any]: return self.table.database.get_extra() @property - def type_generic(self) -> Optional[utils.GenericDataType]: + def type_generic(self) -> utils.GenericDataType | None: if self.is_dttm: return GenericDataType.TEMPORAL @@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin): def get_sqla_col( self, - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.column_name db_engine_spec = self.db_engine_spec @@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin): def get_timestamp_expression( self, - time_grain: Optional[str], - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, - ) -> Union[TimestampExpression, Label]: + time_grain: str | None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, + ) -> TimestampExpression | Label: """ Return a SQLAlchemy Core element representation of self to be used in a query. @@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin): return self.table.make_sqla_column_compatible(time_expr, label) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "id", "column_name", @@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): __tablename__ = "sql_metrics" __table_args__ = (UniqueConstraint("table_id", "metric_name"),) table_id = Column(Integer, ForeignKey("tables.id")) - table: Mapped["SqlaTable"] = relationship( + table: Mapped[SqlaTable] = relationship( "SqlaTable", back_populates="metrics", ) @@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): def get_sqla_col( self, - label: Optional[str] = None, - template_processor: Optional[BaseTemplateProcessor] = None, + label: str | None = None, + template_processor: BaseTemplateProcessor | None = None, ) -> Column: label = label or self.metric_name expression = self.expression @@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): return self.table.make_sqla_column_compatible(sqla_col, label) @property - def perm(self) -> Optional[str]: + def perm(self) -> str | None: return ( ("{parent_name}.[{obj.metric_name}](id:{obj.id})").format( obj=self, parent_name=self.table.full_name @@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin): else None ) - def get_perm(self) -> Optional[str]: + def get_perm(self) -> str | None: return self.perm @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: attrs = ( "is_certified", "certified_by", @@ -473,11 +462,11 @@ sqlatable_user = Table( def _process_sql_expression( - expression: Optional[str], + expression: str | None, database_id: int, schema: str, - template_processor: Optional[BaseTemplateProcessor] = None, -) -> Optional[str]: + template_processor: BaseTemplateProcessor | None = None, +) -> str | None: if template_processor and expression: expression = template_processor.process_template(expression) if expression: @@ -501,12 +490,12 @@ class SqlaTable( type = "table" query_language = "sql" is_rls_supported = True - columns: Mapped[List[TableColumn]] = relationship( + columns: Mapped[list[TableColumn]] = relationship( TableColumn, back_populates="table", cascade="all, delete-orphan", ) - metrics: Mapped[List[SqlMetric]] = relationship( + metrics: Mapped[list[SqlMetric]] = relationship( SqlMetric, back_populates="table", cascade="all, delete-orphan", @@ -577,11 +566,11 @@ class SqlaTable( return self.name @property - def db_extra(self) -> Dict[str, Any]: + def db_extra(self) -> dict[str, Any]: return self.database.get_extra() @staticmethod - def _apply_cte(sql: str, cte: Optional[str]) -> str: + def _apply_cte(sql: str, cte: str | None) -> str: """ Append a CTE before the SELECT statement if defined @@ -594,7 +583,7 @@ class SqlaTable( return sql @property - def db_engine_spec(self) -> Type[BaseEngineSpec]: + def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]: return self.database.db_engine_spec @property @@ -637,9 +626,9 @@ class SqlaTable( cls, session: Session, datasource_name: str, - schema: Optional[str], + schema: str | None, database_name: str, - ) -> Optional[SqlaTable]: + ) -> SqlaTable | None: schema = schema or None query = ( session.query(cls) @@ -660,7 +649,7 @@ class SqlaTable( anchor = f'{name}' return Markup(anchor) - def get_schema_perm(self) -> Optional[str]: + def get_schema_perm(self) -> str | None: """Returns schema permission if present, database one otherwise.""" return security_manager.get_schema_perm(self.database, self.schema) @@ -685,18 +674,18 @@ class SqlaTable( ) @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: l = [c.column_name for c in self.columns if c.is_dttm] if self.main_dttm_col and self.main_dttm_col not in l: l.append(self.main_dttm_col) return l @property - def num_cols(self) -> List[str]: + def num_cols(self) -> list[str]: return [c.column_name for c in self.columns if c.is_numeric] @property - def any_dttm_col(self) -> Optional[str]: + def any_dttm_col(self) -> str | None: cols = self.dttm_cols return cols[0] if cols else None @@ -713,7 +702,7 @@ class SqlaTable( def sql_url(self) -> str: return self.database.sql_url + "?table_name=" + str(self.table_name) - def external_metadata(self) -> List[Dict[str, str]]: + def external_metadata(self) -> list[dict[str, str]]: # todo(yongjie): create a physical table column type in a separate PR if self.sql: return get_virtual_table_metadata(dataset=self) # type: ignore @@ -724,14 +713,14 @@ class SqlaTable( ) @property - def time_column_grains(self) -> Dict[str, Any]: + def time_column_grains(self) -> dict[str, Any]: return { "time_columns": self.dttm_cols, "time_grains": [grain.name for grain in self.database.grains()], } @property - def select_star(self) -> Optional[str]: + def select_star(self) -> str | None: # show_cols and latest_partition set to false to avoid # the expensive cost of inspecting the DB return self.database.select_star( @@ -739,20 +728,20 @@ class SqlaTable( ) @property - def health_check_message(self) -> Optional[str]: + def health_check_message(self) -> str | None: check = config["DATASET_HEALTH_CHECK"] return check(self) if check else None @property - def granularity_sqla(self) -> List[Tuple[Any, Any]]: + def granularity_sqla(self) -> list[tuple[Any, Any]]: return utils.choicify(self.dttm_cols) @property - def time_grain_sqla(self) -> List[Tuple[Any, Any]]: + def time_grain_sqla(self) -> list[tuple[Any, Any]]: return [(g.duration, g.name) for g in self.database.grains() or []] @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: data_ = super().data if self.type == "table": data_["granularity_sqla"] = self.granularity_sqla @@ -767,7 +756,7 @@ class SqlaTable( return data_ @property - def extra_dict(self) -> Dict[str, Any]: + def extra_dict(self) -> dict[str, Any]: try: return json.loads(self.extra) except (TypeError, json.JSONDecodeError): @@ -775,7 +764,7 @@ class SqlaTable( def get_fetch_values_predicate( self, - template_processor: Optional[BaseTemplateProcessor] = None, + template_processor: BaseTemplateProcessor | None = None, ) -> TextClause: fetch_values_predicate = self.fetch_values_predicate if template_processor: @@ -792,7 +781,7 @@ class SqlaTable( ) ) from ex - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -869,8 +858,8 @@ class SqlaTable( return tbl def get_from_clause( - self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Tuple[Union[TableClause, Alias], Optional[str]]: + self, template_processor: BaseTemplateProcessor | None = None + ) -> tuple[TableClause | Alias, str | None]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a @@ -899,7 +888,7 @@ class SqlaTable( return from_clause, cte def get_rendered_sql( - self, template_processor: Optional[BaseTemplateProcessor] = None + self, template_processor: BaseTemplateProcessor | None = None ) -> str: """ Render sql with template engine (Jinja). @@ -928,8 +917,8 @@ class SqlaTable( def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, TableColumn], - template_processor: Optional[BaseTemplateProcessor] = None, + columns_by_name: dict[str, TableColumn], + template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. @@ -946,7 +935,7 @@ class SqlaTable( if expression_type == utils.AdhocMetricExpressionType.SIMPLE: metric_column = metric.get("column") or {} column_name = cast(str, metric_column.get("column_name")) - table_column: Optional[TableColumn] = columns_by_name.get(column_name) + table_column: TableColumn | None = columns_by_name.get(column_name) if table_column: sqla_column = table_column.get_sqla_col( template_processor=template_processor @@ -971,7 +960,7 @@ class SqlaTable( self, col: AdhocColumn, force_type_check: bool = False, - template_processor: Optional[BaseTemplateProcessor] = None, + template_processor: BaseTemplateProcessor | None = None, ) -> ColumnElement: """ Turn an adhoc column into a sqlalchemy column. @@ -1021,7 +1010,7 @@ class SqlaTable( return self.make_sqla_column_compatible(sqla_column, label) def make_sqla_column_compatible( - self, sqla_col: ColumnElement, label: Optional[str] = None + self, sqla_col: ColumnElement, label: str | None = None ) -> ColumnElement: """Takes a sqlalchemy column object and adds label info if supported by engine. :param sqla_col: sqlalchemy column instance @@ -1038,7 +1027,7 @@ class SqlaTable( return sqla_col def make_orderby_compatible( - self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement] ) -> None: """ If needed, make sure aliases for selected columns are not used in @@ -1069,7 +1058,7 @@ class SqlaTable( def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, - ) -> List[TextClause]: + ) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the current user. A custom username can be passed when the user is not present in the @@ -1078,8 +1067,8 @@ class SqlaTable( :param template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. """ - all_filters: List[TextClause] = [] - filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + all_filters: list[TextClause] = [] + filter_groups: dict[int | str, list[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): clause = self.text( @@ -1114,9 +1103,9 @@ class SqlaTable( def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Dict[str, SqlMetric], - columns_by_name: Dict[str, TableColumn], - template_processor: Optional[BaseTemplateProcessor] = None, + metrics_by_name: dict[str, SqlMetric], + columns_by_name: dict[str, TableColumn], + template_processor: BaseTemplateProcessor | None = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): assert isinstance(series_limit_metric, dict) @@ -1138,8 +1127,8 @@ class SqlaTable( self, row: pd.Series, dimension: str, - columns_by_name: Dict[str, TableColumn], - ) -> Union[str, int, float, bool, Text]: + columns_by_name: dict[str, TableColumn], + ) -> str | int | float | bool | Text: """ Convert a prequery result type to its equivalent Python type. @@ -1159,7 +1148,7 @@ class SqlaTable( value = value.item() column_ = columns_by_name[dimension] - db_extra: Dict[str, Any] = self.database.get_extra() + db_extra: dict[str, Any] = self.database.get_extra() if column_.type and column_.is_temporal and isinstance(value, str): sql = self.db_engine_spec.convert_dttm( @@ -1174,9 +1163,9 @@ class SqlaTable( def _get_top_groups( self, df: pd.DataFrame, - dimensions: List[str], - groupby_exprs: Dict[str, Any], - columns_by_name: Dict[str, TableColumn], + dimensions: list[str], + groupby_exprs: dict[str, Any], + columns_by_name: dict[str, TableColumn], ) -> ColumnElement: groups = [] for _unused, row in df.iterrows(): @@ -1201,7 +1190,7 @@ class SqlaTable( errors = None error_message = None - def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]: + def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None: """ Some engines change the case or generate bespoke column names, either by default or due to lack of support for aliasing. This function ensures that @@ -1283,7 +1272,7 @@ class SqlaTable( else self.columns ) - old_columns_by_name: Dict[str, TableColumn] = { + old_columns_by_name: dict[str, TableColumn] = { col.column_name: col for col in old_columns } results = MetadataResult( @@ -1341,8 +1330,8 @@ class SqlaTable( session: Session, database: Database, datasource_name: str, - schema: Optional[str] = None, - ) -> List[SqlaTable]: + schema: str | None = None, + ) -> list[SqlaTable]: query = ( session.query(cls) .filter_by(database_id=database.id) @@ -1357,9 +1346,9 @@ class SqlaTable( cls, session: Session, database: Database, - permissions: Set[str], - schema_perms: Set[str], - ) -> List[SqlaTable]: + permissions: set[str], + schema_perms: set[str], + ) -> list[SqlaTable]: # TODO(hughhhh): add unit test return ( session.query(cls) @@ -1389,7 +1378,7 @@ class SqlaTable( ) @classmethod - def get_all_datasources(cls, session: Session) -> List[SqlaTable]: + def get_all_datasources(cls, session: Session) -> list[SqlaTable]: qry = session.query(cls) qry = cls.default_query(qry) return qry.all() @@ -1409,7 +1398,7 @@ class SqlaTable( :param query_obj: query object to analyze :return: True if there are call(s) to an `ExtraCache` method, False otherwise """ - templatable_statements: List[str] = [] + templatable_statements: list[str] = [] if self.sql: templatable_statements.append(self.sql) if self.fetch_values_predicate: @@ -1428,7 +1417,7 @@ class SqlaTable( return True return False - def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]: + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: """ The cache key of a SqlaTable needs to consider any keys added by the parent class and any keys added via `ExtraCache`. @@ -1489,7 +1478,7 @@ class SqlaTable( @staticmethod def update_column( # pylint: disable=unused-argument - mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] + mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn ) -> None: """ :param mapper: Unused. diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 698311dab6..d41c0555d3 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -17,19 +17,9 @@ from __future__ import annotations import logging +from collections.abc import Iterable, Iterator from functools import lru_cache -from typing import ( - Any, - Callable, - Dict, - Iterable, - Iterator, - List, - Optional, - Type, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, TYPE_CHECKING, TypeVar from uuid import UUID from flask_babel import lazy_gettext as _ @@ -58,8 +48,8 @@ if TYPE_CHECKING: def get_physical_table_metadata( database: Database, table_name: str, - schema_name: Optional[str] = None, -) -> List[Dict[str, Any]]: + schema_name: str | None = None, +) -> list[dict[str, Any]]: """Use SQLAlchemy inspector to get table metadata""" db_engine_spec = database.db_engine_spec db_dialect = database.get_dialect() @@ -103,7 +93,7 @@ def get_physical_table_metadata( return cols -def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: +def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]: """Use SQLparser to get virtual dataset metadata""" if not dataset.sql: raise SupersetGenericDBErrorException( @@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]: def get_columns_description( database: Database, query: str, -) -> List[ResultSetColumnType]: +) -> list[ResultSetColumnType]: db_engine_spec = database.db_engine_spec try: with database.get_raw_connection() as conn: @@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str: @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]: +def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]: return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote @@ -181,9 +171,9 @@ logger = logging.getLogger(__name__) def find_cached_objects_in_session( session: Session, - cls: Type[DeclarativeModel], - ids: Optional[Iterable[int]] = None, - uuids: Optional[Iterable[UUID]] = None, + cls: type[DeclarativeModel], + ids: Iterable[int] | None = None, + uuids: Iterable[UUID] | None = None, ) -> Iterator[DeclarativeModel]: """Find known ORM instances in cached SQLA session states. diff --git a/superset/connectors/sqla/views.py b/superset/connectors/sqla/views.py index 0989a545fd..9116b9636e 100644 --- a/superset/connectors/sqla/views.py +++ b/superset/connectors/sqla/views.py @@ -447,7 +447,7 @@ class TableModelView( # pylint: disable=too-many-ancestors resp = super().edit(pk) if isinstance(resp, str): return resp - return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk)) + return redirect(f"/explore/?datasource_type=table&datasource_id={pk}") @expose("/list/") @has_access diff --git a/superset/css_templates/commands/bulk_delete.py b/superset/css_templates/commands/bulk_delete.py index 93564208c4..57612d9048 100644 --- a/superset/css_templates/commands/bulk_delete.py +++ b/superset/css_templates/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset.commands.base import BaseCommand from superset.css_templates.commands.exceptions import ( @@ -30,9 +30,9 @@ logger = logging.getLogger(__name__) class BulkDeleteCssTemplateCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[CssTemplate]] = None + self._models: Optional[list[CssTemplate]] = None def run(self) -> None: self.validate() diff --git a/superset/css_templates/dao.py b/superset/css_templates/dao.py index 1862fb7aaf..bc1a796269 100644 --- a/superset/css_templates/dao.py +++ b/superset/css_templates/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO): model_cls = CssTemplate @staticmethod - def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete( diff --git a/superset/dao/base.py b/superset/dao/base.py index d3675a0e17..539dbab2d5 100644 --- a/superset/dao/base.py +++ b/superset/dao/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=isinstance-second-argument-not-valid-type -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Optional, Union from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model @@ -37,7 +37,7 @@ class BaseDAO: Base DAO, implement base CRUD sqlalchemy operations """ - model_cls: Optional[Type[Model]] = None + model_cls: Optional[type[Model]] = None """ Child classes need to state the Model class so they don't need to implement basic create, update and delete methods @@ -75,10 +75,10 @@ class BaseDAO: @classmethod def find_by_ids( cls, - model_ids: Union[List[str], List[int]], + model_ids: Union[list[str], list[int]], session: Session = None, skip_base_filter: bool = False, - ) -> List[Model]: + ) -> list[Model]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -95,7 +95,7 @@ class BaseDAO: return query.all() @classmethod - def find_all(cls) -> List[Model]: + def find_all(cls) -> list[Model]: """ Get all that fit the `base_filter` """ @@ -121,7 +121,7 @@ class BaseDAO: return query.filter_by(**filter_by).one_or_none() @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: """ Generic for creating models :raises: DAOCreateFailedError @@ -163,7 +163,7 @@ class BaseDAO: @classmethod def update( - cls, model: Model, properties: Dict[str, Any], commit: bool = True + cls, model: Model, properties: dict[str, Any], commit: bool = True ) -> Model: """ Generic update a model @@ -196,7 +196,7 @@ class BaseDAO: return model @classmethod - def bulk_delete(cls, models: List[Model], commit: bool = True) -> None: + def bulk_delete(cls, models: list[Model], commit: bool = True) -> None: try: for model in models: cls.delete(model, False) diff --git a/superset/dashboards/commands/bulk_delete.py b/superset/dashboards/commands/bulk_delete.py index 13541cd946..385f1fbc6d 100644 --- a/superset/dashboards/commands/bulk_delete.py +++ b/superset/dashboards/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_babel import lazy_gettext as _ @@ -37,9 +37,9 @@ logger = logging.getLogger(__name__) class BulkDeleteDashboardCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Dashboard]] = None + self._models: Optional[list[Dashboard]] = None def run(self) -> None: self.validate() diff --git a/superset/dashboards/commands/create.py b/superset/dashboards/commands/create.py index 0ad8ddee7c..58acc379ba 100644 --- a/superset/dashboards/commands/create.py +++ b/superset/dashboards/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) class CreateDashboardCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand): return dashboard def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") - role_ids: Optional[List[int]] = self._properties.get("roles") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") + role_ids: Optional[list[int]] = self._properties.get("roles") slug: str = self._properties.get("slug", "") # Validate slug uniqueness diff --git a/superset/dashboards/commands/export.py b/superset/dashboards/commands/export.py index 886b84ffa6..2e70e29bb0 100644 --- a/superset/dashboards/commands/export.py +++ b/superset/dashboards/commands/export.py @@ -20,7 +20,8 @@ import json import logging import random import string -from typing import Any, Dict, Iterator, Optional, Set, Tuple +from typing import Any, Optional +from collections.abc import Iterator import yaml @@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str: ) -def get_default_position(title: str) -> Dict[str, Any]: +def get_default_position(title: str) -> dict[str, Any]: return { "DASHBOARD_VERSION_KEY": "v2", "ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"}, @@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]: } -def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]: +def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]: chart_hashes = [f"CHART-{suffix()}" for _ in charts] # if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid @@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand): @staticmethod def _export( model: Dashboard, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: file_name = get_filename(model.dashboard_title, model.id) file_path = f"dashboards/{file_name}.yaml" diff --git a/superset/dashboards/commands/importers/dispatcher.py b/superset/dashboards/commands/importers/dispatcher.py index dd0121f3e3..d5323b4fe4 100644 --- a/superset/dashboards/commands/importers/dispatcher.py +++ b/superset/dashboards/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/dashboards/commands/importers/v0.py b/superset/dashboards/commands/importers/v0.py index e49c931896..012dbbc5c9 100644 --- a/superset/dashboards/commands/importers/v0.py +++ b/superset/dashboards/commands/importers/v0.py @@ -19,7 +19,7 @@ import logging import time from copy import copy from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ from sqlalchemy.orm import make_transient, Session @@ -83,7 +83,7 @@ def import_chart( def import_dashboard( # pylint: disable=too-many-locals,too-many-statements dashboard_to_import: Dashboard, - dataset_id_mapping: Optional[Dict[int, int]] = None, + dataset_id_mapping: Optional[dict[int, int]] = None, import_time: Optional[int] = None, ) -> int: """Imports the dashboard from the object to the database. @@ -97,7 +97,7 @@ def import_dashboard( """ def alter_positions( - dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int] + dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int] ) -> None: """Updates slice_ids in the position json. @@ -166,7 +166,7 @@ def import_dashboard( dashboard_to_import.slug = None old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}") - old_to_new_slc_id_dict: Dict[int, int] = {} + old_to_new_slc_id_dict: dict[int, int] = {} new_timed_refresh_immune_slices = [] new_expanded_slices = {} new_filter_scopes = {} @@ -268,7 +268,7 @@ def import_dashboard( return dashboard_to_import.id # type: ignore -def decode_dashboards(o: Dict[str, Any]) -> Any: +def decode_dashboards(o: dict[str, Any]) -> Any: """ Function to be passed into json.loads obj_hook parameter Recreates the dashboard object from a json representation. @@ -302,7 +302,7 @@ def import_dashboards( data = json.loads(content, object_hook=decode_dashboards) if not data: raise DashboardImportException(_("No data in file")) - dataset_id_mapping: Dict[int, int] = {} + dataset_id_mapping: dict[int, int] = {} for table in data["datasources"]: new_dataset_id = import_dataset(table, database_id, import_time=import_time) params = json.loads(table.params) @@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand): # pylint: disable=unused-argument def __init__( - self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any + self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any ): self.contents = contents self.database_id = database_id diff --git a/superset/dashboards/commands/importers/v1/__init__.py b/superset/dashboards/commands/importers/v1/__init__.py index 5d83a580bd..597adba6d9 100644 --- a/superset/dashboards/commands/importers/v1/__init__.py +++ b/superset/dashboards/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand): dao = DashboardDAO model_name = "dashboard" prefix = "dashboards/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "charts/": ImportV1ChartSchema(), "dashboards/": ImportV1DashboardSchema(), "datasets/": ImportV1DatasetSchema(), @@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand): # pylint: disable=too-many-branches, too-many-locals @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover charts and datasets associated with dashboards - chart_uuids: Set[str] = set() - dataset_uuids: Set[str] = set() + chart_uuids: set[str] = set() + dataset_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("dashboards/"): chart_uuids.update(find_chart_uuids(config["position"])) @@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand): dataset_uuids.add(config["dataset_uuid"]) # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids: database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref - dataset_info: Dict[str, Dict[str, Any]] = {} + dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if ( file_name.startswith("datasets/") @@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand): } # import charts with the correct parent ref - chart_ids: Dict[str, int] = {} + chart_ids: dict[str, int] = {} for file_name, config in configs.items(): if ( file_name.startswith("charts/") @@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand): ).fetchall() # import dashboards - dashboard_chart_ids: List[Tuple[int, int]] = [] + dashboard_chart_ids: list[tuple[int, int]] = [] for file_name, config in configs.items(): if file_name.startswith("dashboards/"): config = update_id_refs(config, chart_ids, dataset_info) diff --git a/superset/dashboards/commands/importers/v1/utils.py b/superset/dashboards/commands/importers/v1/utils.py index 9f0ffc36a1..1deb44949a 100644 --- a/superset/dashboards/commands/importers/v1/utils.py +++ b/superset/dashboards/commands/importers/v1/utils.py @@ -17,7 +17,7 @@ import json import logging -from typing import Any, Dict, Set +from typing import Any from flask import g from sqlalchemy.orm import Session @@ -32,12 +32,12 @@ logger = logging.getLogger(__name__) JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"} -def find_chart_uuids(position: Dict[str, Any]) -> Set[str]: +def find_chart_uuids(position: dict[str, Any]) -> set[str]: return set(build_uuid_to_id_map(position)) -def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]: - uuids: Set[str] = set() +def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]: + uuids: set[str] = set() for native_filter in metadata.get("native_filter_configuration", []): targets = native_filter.get("targets", []) for target in targets: @@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]: return uuids -def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]: +def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]: return { child["meta"]["uuid"]: child["meta"]["chartId"] for child in position.values() @@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]: def update_id_refs( # pylint: disable=too-many-locals - config: Dict[str, Any], - chart_ids: Dict[str, int], - dataset_info: Dict[str, Dict[str, Any]], -) -> Dict[str, Any]: + config: dict[str, Any], + chart_ids: dict[str, int], + dataset_info: dict[str, dict[str, Any]], +) -> dict[str, Any]: """Update dashboard metadata to use new IDs""" fixed = config.copy() @@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals def import_dashboard( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Dashboard: diff --git a/superset/dashboards/commands/update.py b/superset/dashboards/commands/update.py index 11833a64be..fefa65e3f6 100644 --- a/superset/dashboards/commands/update.py +++ b/superset/dashboards/commands/update.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -41,7 +41,7 @@ logger = logging.getLogger(__name__) class UpdateDashboardCommand(UpdateMixin, BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[Dashboard] = None @@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand): return dashboard def validate(self) -> None: - exceptions: List[ValidationError] = [] - owners_ids: Optional[List[int]] = self._properties.get("owners") - roles_ids: Optional[List[int]] = self._properties.get("roles") + exceptions: list[ValidationError] = [] + owners_ids: Optional[list[int]] = self._properties.get("owners") + roles_ids: Optional[list[int]] = self._properties.get("roles") slug: Optional[str] = self._properties.get("slug") # Validate/populate model exists diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py index 5355d602be..d88fb431b7 100644 --- a/superset/dashboards/dao.py +++ b/superset/dashboards/dao.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from flask import g from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO): return dashboard @staticmethod - def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]: + def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]: dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug) return dashboard.datasets_trimmed_for_slices() @staticmethod - def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]: + def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]: return DashboardDAO.get_by_id_or_slug(id_or_slug).slices @staticmethod @@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO): return model @staticmethod - def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: @@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO): @staticmethod def set_dash_metadata( # pylint: disable=too-many-locals dashboard: Dashboard, - data: Dict[Any, Any], - old_to_new_slice_ids: Optional[Dict[int, int]] = None, + data: dict[Any, Any], + old_to_new_slice_ids: Optional[dict[int, int]] = None, commit: bool = False, ) -> Dashboard: new_filter_scopes = {} @@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO): if "filter_scopes" in data: # replace filter_id and immune ids from old slice id to new slice id: # and remove slice ids that are not in dash anymore - slc_id_dict: Dict[int, int] = {} + slc_id_dict: dict[int, int] = {} if old_to_new_slice_ids: slc_id_dict = { old: new @@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO): return dashboard @staticmethod - def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]: + def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]: ids = [dash.id for dash in dashboards] return [ star.obj_id @@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO): @classmethod def copy_dashboard( - cls, original_dash: Dashboard, data: Dict[str, Any] + cls, original_dash: Dashboard, data: dict[str, Any] ) -> Dashboard: dash = Dashboard() dash.owners = [g.user] if g.user else [] @@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO): dash.css = data.get("css") metadata = json.loads(data["json_metadata"]) - old_to_new_slice_ids: Dict[int, int] = {} + old_to_new_slice_ids: dict[int, int] = {} if data.get("duplicate_slices"): # Duplicating slices as well, mapping old ids to new ones for slc in original_dash.slices: diff --git a/superset/dashboards/filter_sets/commands/create.py b/superset/dashboards/filter_sets/commands/create.py index de1d70daf7..63c4534786 100644 --- a/superset/dashboards/filter_sets/commands/create.py +++ b/superset/dashboards/filter_sets/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class CreateFilterSetCommand(BaseFilterSetCommand): # pylint: disable=C0103 - def __init__(self, dashboard_id: int, data: Dict[str, Any]): + def __init__(self, dashboard_id: int, data: dict[str, Any]): super().__init__(dashboard_id) self._properties = data.copy() diff --git a/superset/dashboards/filter_sets/commands/update.py b/superset/dashboards/filter_sets/commands/update.py index 07d59f93ae..722672d668 100644 --- a/superset/dashboards/filter_sets/commands/update.py +++ b/superset/dashboards/filter_sets/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class UpdateFilterSetCommand(BaseFilterSetCommand): - def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]): + def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]): super().__init__(dashboard_id) self._filter_set_id = filter_set_id self._properties = data.copy() diff --git a/superset/dashboards/filter_sets/dao.py b/superset/dashboards/filter_sets/dao.py index 949aa6d3fd..5f2b0ba418 100644 --- a/superset/dashboards/filter_sets/dao.py +++ b/superset/dashboards/filter_sets/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask_appbuilder.models.sqla import Model from sqlalchemy.exc import SQLAlchemyError @@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO): model_cls = FilterSet @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Model: if cls.model_cls is None: raise DAOConfigError() model = FilterSet() diff --git a/superset/dashboards/filter_sets/schemas.py b/superset/dashboards/filter_sets/schemas.py index c1a13b424e..2309eea99f 100644 --- a/superset/dashboards/filter_sets/schemas.py +++ b/superset/dashboards/filter_sets/schemas.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, cast, Dict, Mapping +from collections.abc import Mapping +from typing import Any, cast from marshmallow import fields, post_load, Schema, ValidationError from marshmallow.validate import Length, OneOf @@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema): @post_load def validate( self, data: Mapping[Any, Any], *, many: Any, partial: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: self._validate_json_meta_data(data[JSON_METADATA_FIELD]) if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data: raise ValidationError("owner_id is mandatory when owner_type is User") - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) class FilterSetPutSchema(FilterSetSchema): @@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema): @post_load def validate( # pylint: disable=unused-argument self, data: Mapping[Any, Any], *, many: Any, partial: Any - ) -> Dict[str, Any]: + ) -> dict[str, Any]: if JSON_METADATA_FIELD in data: self._validate_json_meta_data(data[JSON_METADATA_FIELD]) - return cast(Dict[str, Any], data) + return cast(dict[str, Any], data) -def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None: +def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None: if first_field in data and second_field not in data: raise ValidationError( - "{} must be included alongside {}".format(first_field, second_field) + f"{first_field} must be included alongside {second_field}" ) diff --git a/superset/dashboards/filter_state/api.py b/superset/dashboards/filter_state/api.py index 7a771d6b54..a1b855ca9e 100644 --- a/superset/dashboards/filter_state/api.py +++ b/superset/dashboards/filter_state/api.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Type from flask import Response from flask_appbuilder.api import expose, protect, safe @@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi): resource_name = "dashboard" openapi_spec_tag = "Dashboard Filter State" - def get_create_command(self) -> Type[CreateFilterStateCommand]: + def get_create_command(self) -> type[CreateFilterStateCommand]: return CreateFilterStateCommand - def get_update_command(self) -> Type[UpdateFilterStateCommand]: + def get_update_command(self) -> type[UpdateFilterStateCommand]: return UpdateFilterStateCommand - def get_get_command(self) -> Type[GetFilterStateCommand]: + def get_get_command(self) -> type[GetFilterStateCommand]: return GetFilterStateCommand - def get_delete_command(self) -> Type[DeleteFilterStateCommand]: + def get_delete_command(self) -> type[DeleteFilterStateCommand]: return DeleteFilterStateCommand @expose("//filter_state", methods=("POST",)) diff --git a/superset/dashboards/permalink/types.py b/superset/dashboards/permalink/types.py index 91c5a9620c..4961d2a17b 100644 --- a/superset/dashboards/permalink/types.py +++ b/superset/dashboards/permalink/types.py @@ -14,14 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Optional, TypedDict class DashboardPermalinkState(TypedDict): - dataMask: Optional[Dict[str, Any]] - activeTabs: Optional[List[str]] + dataMask: Optional[dict[str, Any]] + activeTabs: Optional[list[str]] anchor: Optional[str] - urlParams: Optional[List[Tuple[str, str]]] + urlParams: Optional[list[tuple[str, str]]] class DashboardPermalinkValue(TypedDict): diff --git a/superset/dashboards/schemas.py b/superset/dashboards/schemas.py index ab93e4130f..846ed39e82 100644 --- a/superset/dashboards/schemas.py +++ b/superset/dashboards/schemas.py @@ -16,7 +16,7 @@ # under the License. import json import re -from typing import Any, Dict, Union +from typing import Any, Union from marshmallow import fields, post_load, pre_load, Schema from marshmallow.validate import Length, ValidationError @@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema): @pre_load def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use self, - data: Dict[str, Any], + data: dict[str, Any], **kwargs: Any, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Remove ``show_native_filters`` from the JSON metadata. @@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema): class BaseDashboardSchema(Schema): # pylint: disable=no-self-use,unused-argument @post_load - def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: if data.get("slug"): data["slug"] = data["slug"].strip() data["slug"] = data["slug"].replace(" ", "-") diff --git a/superset/databases/api.py b/superset/databases/api.py index 77f9596182..c214065a27 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -19,7 +19,7 @@ import json import logging from datetime import datetime from io import BytesIO -from typing import Any, cast, Dict, List, Optional +from typing import Any, cast, Optional from zipfile import is_zipfile, ZipFile from flask import request, Response, send_file @@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi): 500: $ref: '#/components/responses/500' """ - preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", []) + preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", []) available_databases = [] for engine_spec, drivers in get_available_engine_specs().items(): if not drivers: continue - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "name": engine_spec.engine_name, "engine": engine_spec.engine, "available_drivers": sorted(drivers), diff --git a/superset/databases/commands/create.py b/superset/databases/commands/create.py index 16d27835b3..e3fd667130 100644 --- a/superset/databases/commands/create.py +++ b/superset/databases/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model @@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"] class CreateDatabaseCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand): return database def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri") database_name: Optional[str] = self._properties.get("database_name") if not sqlalchemy_uri: diff --git a/superset/databases/commands/export.py b/superset/databases/commands/export.py index e1f8fc2a25..889cb86c8f 100644 --- a/superset/databases/commands/export.py +++ b/superset/databases/commands/export.py @@ -18,7 +18,8 @@ import json import logging -from typing import Any, Dict, Iterator, Tuple +from typing import Any +from collections.abc import Iterator import yaml @@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info logger = logging.getLogger(__name__) -def parse_extra(extra_payload: str) -> Dict[str, Any]: +def parse_extra(extra_payload: str) -> dict[str, Any]: try: extra = json.loads(extra_payload) except json.decoder.JSONDecodeError: @@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand): @staticmethod def _export( model: Database, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: db_file_name = get_filename(model.database_name, model.id, skip_id=True) file_path = f"databases/{db_file_name}.yaml" diff --git a/superset/databases/commands/importers/dispatcher.py b/superset/databases/commands/importers/dispatcher.py index 88d38bf13b..70031b09e4 100644 --- a/superset/databases/commands/importers/dispatcher.py +++ b/superset/databases/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/databases/commands/importers/v1/__init__.py b/superset/databases/commands/importers/v1/__init__.py index 239bd0977f..ba119beaaa 100644 --- a/superset/databases/commands/importers/v1/__init__.py +++ b/superset/databases/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand): dao = DatabaseDAO model_name = "database" prefix = "databases/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), } @@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # first import databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database(session, config, overwrite=overwrite) diff --git a/superset/databases/commands/importers/v1/utils.py b/superset/databases/commands/importers/v1/utils.py index c0c0ee60d9..8881f78a9c 100644 --- a/superset/databases/commands/importers/v1/utils.py +++ b/superset/databases/commands/importers/v1/utils.py @@ -16,7 +16,7 @@ # under the License. import json -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -28,7 +28,7 @@ from superset.models.core import Database def import_database( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Database: diff --git a/superset/databases/commands/tables.py b/superset/databases/commands/tables.py index 48e9227dea..b7dbb4d461 100644 --- a/superset/databases/commands/tables.py +++ b/superset/databases/commands/tables.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, cast, Dict +from typing import Any, cast from superset.commands.base import BaseCommand from superset.connectors.sqla.models import SqlaTable @@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand): self._schema_name = schema_name self._force = force - def run(self) -> Dict[str, Any]: + def run(self) -> dict[str, Any]: self.validate() try: tables = security_manager.get_datasources_accessible_by_user( diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 9809641d5c..2680c5e8c1 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -17,7 +17,7 @@ import logging import sqlite3 from contextlib import closing -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import current_app as app from flask_babel import gettext as _ @@ -64,7 +64,7 @@ def get_log_connection_action( class TestConnectionDatabaseCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Database] = None diff --git a/superset/databases/commands/update.py b/superset/databases/commands/update.py index 746f7a8152..f12706fa1d 100644 --- a/superset/databases/commands/update.py +++ b/superset/databases/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -47,7 +47,7 @@ logger = logging.getLogger(__name__) class UpdateDatabaseCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None @@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand): raise DatabaseConnectionFailedError() from ex # Update database schema permissions - new_schemas: List[str] = [] + new_schemas: list[str] = [] for schema in schemas: old_view_menu_name = security_manager.get_schema_perm( @@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand): chart.schema_perm = new_view_menu_name def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] # Validate/populate model exists self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index 2a624e32c7..d97ad33af9 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -16,7 +16,7 @@ # under the License. import json from contextlib import closing -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import gettext as __ @@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"} class ValidateDatabaseParametersCommand(BaseCommand): - def __init__(self, properties: Dict[str, Any]): + def __init__(self, properties: dict[str, Any]): self._properties = properties.copy() self._model: Optional[Database] = None diff --git a/superset/databases/commands/validate_sql.py b/superset/databases/commands/validate_sql.py index 346d684a0d..40d88af745 100644 --- a/superset/databases/commands/validate_sql.py +++ b/superset/databases/commands/validate_sql.py @@ -16,7 +16,7 @@ # under the License. import logging import re -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from flask import current_app from flask_babel import gettext as __ @@ -41,13 +41,13 @@ logger = logging.getLogger(__name__) class ValidateSQLCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[Database] = None - self._validator: Optional[Type[BaseSQLValidator]] = None + self._validator: Optional[type[BaseSQLValidator]] = None - def run(self) -> List[Dict[str, Any]]: + def run(self) -> list[dict[str, Any]]: """ Validates a SQL statement @@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand): if not validators_by_engine or spec.engine not in validators_by_engine: raise NoValidatorConfigFoundError( SupersetError( - message=__( - "no SQL validator is configured for {}".format(spec.engine) - ), + message=__(f"no SQL validator is configured for {spec.engine}"), error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, level=ErrorLevel.ERROR, ), diff --git a/superset/databases/dao.py b/superset/databases/dao.py index c82f0db574..9ce3b5e73e 100644 --- a/superset/databases/dao.py +++ b/superset/databases/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from superset.dao.base import BaseDAO from superset.databases.filters import DatabaseFilter @@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO): def update( cls, model: Database, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> Database: """ @@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO): ) @classmethod - def get_related_objects(cls, database_id: int) -> Dict[str, Any]: + def get_related_objects(cls, database_id: int) -> dict[str, Any]: database: Any = cls.find_by_id(database_id) datasets = database.tables dataset_ids = [dataset.id for dataset in datasets] diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 86564e8f15..2ca77b77d1 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Set +from typing import Any from flask import g from flask_babel import lazy_gettext as _ @@ -30,7 +30,7 @@ from superset.views.base import BaseFilter def can_access_databases( view_menu_name: str, -) -> Set[str]: +) -> set[str]: return { security_manager.unpack_database_and_schema(vm).database for vm in security_manager.user_view_menu_names(view_menu_name) diff --git a/superset/databases/schemas.py b/superset/databases/schemas.py index 00e8c3ca53..01a00e8b80 100644 --- a/superset/databases/schemas.py +++ b/superset/databases/schemas.py @@ -19,7 +19,7 @@ import inspect import json -from typing import Any, Dict, List +from typing import Any from flask import current_app from flask_babel import lazy_gettext as _ @@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods @pre_load def build_sqlalchemy_uri( - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Build SQLAlchemy URI from separate parameters. @@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods def rename_encrypted_extra( self: Schema, - data: Dict[str, Any], + data: dict[str, Any], **kwargs: Any, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Rename ``encrypted_extra`` to ``masked_encrypted_extra``. @@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema): class ImportV1DatabaseExtraSchema(Schema): @pre_load def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Fixes for ``schemas_allowed_for_csv_upload``. """ @@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema): class ImportV1DatabaseSchema(Schema): @pre_load def fix_allow_csv_upload( - self, data: Dict[str, Any], **kwargs: Any - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any + ) -> dict[str, Any]: """ Fix for ``allow_csv_upload`` . """ @@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema): ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True) @validates_schema - def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None: + def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None: """If sqlalchemy_uri has a masked password, password is required""" uuid = data["uuid"] existing = db.session.query(Database).filter_by(uuid=uuid).first() @@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema): @validates_schema def validate_ssh_tunnel_credentials( - self, data: Dict[str, Any], **kwargs: Any + self, data: dict[str, Any], **kwargs: Any ) -> None: """If ssh_tunnel has a masked credentials, credentials are required""" uuid = data["uuid"] @@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema): # or there're times where it's masked. # If both are masked, we need to return a list of errors # so the UI ask for both fields at the same time if needed - exception_messages: List[str] = [] + exception_messages: list[str] = [] if private_key is None or private_key == PASSWORD_MASK: # If we get here we need to ask for the private key exception_messages.append( @@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict): pass -def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore +def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore ret = {} if isinstance(field, EncryptedField): if self.openapi_version.major > 2: diff --git a/superset/databases/ssh_tunnel/commands/create.py b/superset/databases/ssh_tunnel/commands/create.py index 45e5af5f44..9c41b83392 100644 --- a/superset/databases/ssh_tunnel/commands/create.py +++ b/superset/databases/ssh_tunnel/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) class CreateSSHTunnelCommand(BaseCommand): - def __init__(self, database_id: int, data: Dict[str, Any]): + def __init__(self, database_id: int, data: dict[str, Any]): self._properties = data.copy() self._properties["database_id"] = database_id @@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand): def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] database_id: Optional[int] = self._properties.get("database_id") server_address: Optional[str] = self._properties.get("server_address") server_port: Optional[int] = self._properties.get("server_port") diff --git a/superset/databases/ssh_tunnel/commands/update.py b/superset/databases/ssh_tunnel/commands/update.py index 42925d1caa..37fd4a94b9 100644 --- a/superset/databases/ssh_tunnel/commands/update.py +++ b/superset/databases/ssh_tunnel/commands/update.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -34,7 +34,7 @@ logger = logging.getLogger(__name__) class UpdateSSHTunnelCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model_id = model_id self._model: Optional[SSHTunnel] = None diff --git a/superset/databases/ssh_tunnel/dao.py b/superset/databases/ssh_tunnel/dao.py index 89562fc05d..731f9183b3 100644 --- a/superset/databases/ssh_tunnel/dao.py +++ b/superset/databases/ssh_tunnel/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from superset.dao.base import BaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel @@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO): def update( cls, model: SSHTunnel, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SSHTunnel: """ diff --git a/superset/databases/ssh_tunnel/models.py b/superset/databases/ssh_tunnel/models.py index 3384679cb7..d9462a63db 100644 --- a/superset/databases/ssh_tunnel/models.py +++ b/superset/databases/ssh_tunnel/models.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any import sqlalchemy as sa from flask import current_app @@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): ] @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: output = { "id": self.id, "server_address": self.server_address, diff --git a/superset/databases/utils.py b/superset/databases/utils.py index 9229bb8cba..74943f4747 100644 --- a/superset/databases/utils.py +++ b/superset/databases/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from sqlalchemy.engine.url import make_url, URL @@ -25,7 +25,7 @@ def get_foreign_keys_metadata( database: Any, table_name: str, schema_name: Optional[str], -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: foreign_keys = database.get_foreign_keys(table_name, schema_name) for fk in foreign_keys: fk["column_names"] = fk.pop("constrained_columns") @@ -35,14 +35,14 @@ def get_foreign_keys_metadata( def get_indexes_metadata( database: Any, table_name: str, schema_name: Optional[str] -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: indexes = database.get_indexes(table_name, schema_name) for idx in indexes: idx["type"] = "index" return indexes -def get_col_type(col: Dict[Any, Any]) -> str: +def get_col_type(col: dict[Any, Any]) -> str: try: dtype = f"{col['type']}" except Exception: # pylint: disable=broad-except @@ -53,7 +53,7 @@ def get_col_type(col: Dict[Any, Any]) -> str: def get_table_metadata( database: Any, table_name: str, schema_name: Optional[str] -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Get table metadata information, including type, pk, fks. This function raises SQLAlchemyError when a schema is not found. @@ -73,7 +73,7 @@ def get_table_metadata( foreign_keys = get_foreign_keys_metadata(database, table_name, schema_name) indexes = get_indexes_metadata(database, table_name, schema_name) keys += foreign_keys + indexes - payload_columns: List[Dict[str, Any]] = [] + payload_columns: list[dict[str, Any]] = [] table_comment = database.get_table_comment(table_name, schema_name) for col in columns: dtype = get_col_type(col) diff --git a/superset/dataframe.py b/superset/dataframe.py index 8abeedc095..8083993294 100644 --- a/superset/dataframe.py +++ b/superset/dataframe.py @@ -17,7 +17,7 @@ """ Superset utilities for pandas.DataFrame. """ import logging -from typing import Any, Dict, List +from typing import Any import pandas as pd @@ -37,7 +37,7 @@ def _convert_big_integers(val: Any) -> Any: return str(val) if isinstance(val, int) and abs(val) > JS_MAX_INTEGER else val -def df_to_records(dframe: pd.DataFrame) -> List[Dict[str, Any]]: +def df_to_records(dframe: pd.DataFrame) -> list[dict[str, Any]]: """ Convert a DataFrame to a set of records. diff --git a/superset/datasets/commands/bulk_delete.py b/superset/datasets/commands/bulk_delete.py index 643ac784ec..fd13351809 100644 --- a/superset/datasets/commands/bulk_delete.py +++ b/superset/datasets/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset import security_manager from superset.commands.base import BaseCommand @@ -34,9 +34,9 @@ logger = logging.getLogger(__name__) class BulkDeleteDatasetCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[SqlaTable]] = None + self._models: Optional[list[SqlaTable]] = None def run(self) -> None: self.validate() diff --git a/superset/datasets/commands/create.py b/superset/datasets/commands/create.py index 04f54339d0..1c864ad196 100644 --- a/superset/datasets/commands/create.py +++ b/superset/datasets/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -37,7 +37,7 @@ logger = logging.getLogger(__name__) class CreateDatasetCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> Model: @@ -55,12 +55,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): return dataset def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] database_id = self._properties["database"] table_name = self._properties["table_name"] schema = self._properties.get("schema", None) sql = self._properties.get("sql", None) - owner_ids: Optional[List[int]] = self._properties.get("owners") + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate uniqueness if not DatasetDAO.validate_uniqueness(database_id, schema, table_name): diff --git a/superset/datasets/commands/duplicate.py b/superset/datasets/commands/duplicate.py index 5fc642cbe3..5a4a84fdf9 100644 --- a/superset/datasets/commands/duplicate.py +++ b/superset/datasets/commands/duplicate.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from flask_appbuilder.models.sqla import Model from flask_babel import gettext as __ @@ -43,7 +43,7 @@ logger = logging.getLogger(__name__) class DuplicateDatasetCommand(CreateMixin, BaseCommand): - def __init__(self, data: Dict[str, Any]) -> None: + def __init__(self, data: dict[str, Any]) -> None: self._base_model: SqlaTable = SqlaTable() self._properties = data.copy() @@ -105,7 +105,7 @@ class DuplicateDatasetCommand(CreateMixin, BaseCommand): return table def validate(self) -> None: - exceptions: List[ValidationError] = [] + exceptions: list[ValidationError] = [] base_model_id = self._properties["base_model_id"] duplicate_name = self._properties["table_name"] diff --git a/superset/datasets/commands/export.py b/superset/datasets/commands/export.py index c6fe43c89d..8c02a23f29 100644 --- a/superset/datasets/commands/export.py +++ b/superset/datasets/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml @@ -43,7 +43,7 @@ class ExportDatasetsCommand(ExportModelsCommand): @staticmethod def _export( model: SqlaTable, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: db_file_name = get_filename( model.database.database_name, model.database.id, skip_id=True ) diff --git a/superset/datasets/commands/importers/dispatcher.py b/superset/datasets/commands/importers/dispatcher.py index 74f1129d23..6be8635da2 100644 --- a/superset/datasets/commands/importers/dispatcher.py +++ b/superset/datasets/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -43,7 +43,7 @@ class ImportDatasetsCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/datasets/commands/importers/v0.py b/superset/datasets/commands/importers/v0.py index f706ecf38b..c530be3c14 100644 --- a/superset/datasets/commands/importers/v0.py +++ b/superset/datasets/commands/importers/v0.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import yaml from flask_appbuilder import Model @@ -213,7 +213,7 @@ def import_simple_obj( def import_from_dict( - session: Session, data: Dict[str, Any], sync: Optional[List[str]] = None + session: Session, data: dict[str, Any], sync: Optional[list[str]] = None ) -> None: """Imports databases from dictionary""" if not sync: @@ -238,12 +238,12 @@ class ImportDatasetsCommand(BaseCommand): # pylint: disable=unused-argument def __init__( self, - contents: Dict[str, str], + contents: dict[str, str], *args: Any, **kwargs: Any, ): self.contents = contents - self._configs: Dict[str, Any] = {} + self._configs: dict[str, Any] = {} self.sync = [] if kwargs.get("sync_columns"): diff --git a/superset/datasets/commands/importers/v1/__init__.py b/superset/datasets/commands/importers/v1/__init__.py index e73213319d..e753138ab8 100644 --- a/superset/datasets/commands/importers/v1/__init__.py +++ b/superset/datasets/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -36,7 +36,7 @@ class ImportDatasetsCommand(ImportModelsCommand): dao = DatasetDAO model_name = "dataset" prefix = "datasets/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "datasets/": ImportV1DatasetSchema(), } @@ -44,16 +44,16 @@ class ImportDatasetsCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover databases associated with datasets - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("datasets/"): database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 52f46829b5..ae47fc411a 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -18,7 +18,7 @@ import gzip import json import logging import re -from typing import Any, Dict +from typing import Any from urllib import request import pandas as pd @@ -69,7 +69,7 @@ def get_sqla_type(native_type: str) -> VisitableType: raise Exception(f"Unknown type: {native_type}") -def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> Dict[str, VisitableType]: +def get_dtype(df: pd.DataFrame, dataset: SqlaTable) -> dict[str, VisitableType]: return { column.column_name: get_sqla_type(column.type) for column in dataset.columns @@ -101,7 +101,7 @@ def validate_data_uri(data_uri: str) -> None: def import_dataset( session: Session, - config: Dict[str, Any], + config: dict[str, Any], overwrite: bool = False, force_data: bool = False, ignore_permissions: bool = False, diff --git a/superset/datasets/commands/update.py b/superset/datasets/commands/update.py index cc9f480a41..be9625709f 100644 --- a/superset/datasets/commands/update.py +++ b/superset/datasets/commands/update.py @@ -16,7 +16,7 @@ # under the License. import logging from collections import Counter -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from flask_appbuilder.models.sqla import Model @@ -52,7 +52,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): def __init__( self, model_id: int, - data: Dict[str, Any], + data: dict[str, Any], override_columns: Optional[bool] = False, ): self._model_id = model_id @@ -76,8 +76,8 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): raise DatasetUpdateFailedError() def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") # Validate/populate model exists self._model = DatasetDAO.find_by_id(self._model_id) if not self._model: @@ -125,14 +125,14 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): raise DatasetInvalidError(exceptions=exceptions) def _validate_columns( - self, columns: List[Dict[str, Any]], exceptions: List[ValidationError] + self, columns: list[dict[str, Any]], exceptions: list[ValidationError] ) -> None: # Validate duplicates on data if self._get_duplicates(columns, "column_name"): exceptions.append(DatasetColumnsDuplicateValidationError()) else: # validate invalid id's - columns_ids: List[int] = [ + columns_ids: list[int] = [ column["id"] for column in columns if "id" in column ] if not DatasetDAO.validate_columns_exist(self._model_id, columns_ids): @@ -140,7 +140,7 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): # validate new column names uniqueness if not self.override_columns: - columns_names: List[str] = [ + columns_names: list[str] = [ column["column_name"] for column in columns if "id" not in column ] if not DatasetDAO.validate_columns_uniqueness( @@ -149,26 +149,26 @@ class UpdateDatasetCommand(UpdateMixin, BaseCommand): exceptions.append(DatasetColumnsExistsValidationError()) def _validate_metrics( - self, metrics: List[Dict[str, Any]], exceptions: List[ValidationError] + self, metrics: list[dict[str, Any]], exceptions: list[ValidationError] ) -> None: if self._get_duplicates(metrics, "metric_name"): exceptions.append(DatasetMetricsDuplicateValidationError()) else: # validate invalid id's - metrics_ids: List[int] = [ + metrics_ids: list[int] = [ metric["id"] for metric in metrics if "id" in metric ] if not DatasetDAO.validate_metrics_exist(self._model_id, metrics_ids): exceptions.append(DatasetMetricsNotFoundValidationError()) # validate new metric names uniqueness - metric_names: List[str] = [ + metric_names: list[str] = [ metric["metric_name"] for metric in metrics if "id" not in metric ] if not DatasetDAO.validate_metrics_uniqueness(self._model_id, metric_names): exceptions.append(DatasetMetricsExistsValidationError()) @staticmethod - def _get_duplicates(data: List[Dict[str, Any]], key: str) -> List[str]: + def _get_duplicates(data: list[dict[str, Any]], key: str) -> list[str]: duplicates = [ name for name, count in Counter([item[key] for item in data]).items() diff --git a/superset/datasets/dao.py b/superset/datasets/dao.py index b158fce1fe..f4d46be109 100644 --- a/superset/datasets/dao.py +++ b/superset/datasets/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -44,7 +44,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return None @staticmethod - def get_related_objects(database_id: int) -> Dict[str, Any]: + def get_related_objects(database_id: int) -> dict[str, Any]: charts = ( db.session.query(Slice) .filter( @@ -108,7 +108,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return not db.session.query(dataset_query.exists()).scalar() @staticmethod - def validate_columns_exist(dataset_id: int, columns_ids: List[int]) -> bool: + def validate_columns_exist(dataset_id: int, columns_ids: list[int]) -> bool: dataset_query = ( db.session.query(TableColumn.id).filter( TableColumn.table_id == dataset_id, TableColumn.id.in_(columns_ids) @@ -117,7 +117,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return len(columns_ids) == len(dataset_query) @staticmethod - def validate_columns_uniqueness(dataset_id: int, columns_names: List[str]) -> bool: + def validate_columns_uniqueness(dataset_id: int, columns_names: list[str]) -> bool: dataset_query = ( db.session.query(TableColumn.id).filter( TableColumn.table_id == dataset_id, @@ -127,7 +127,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return len(dataset_query) == 0 @staticmethod - def validate_metrics_exist(dataset_id: int, metrics_ids: List[int]) -> bool: + def validate_metrics_exist(dataset_id: int, metrics_ids: list[int]) -> bool: dataset_query = ( db.session.query(SqlMetric.id).filter( SqlMetric.table_id == dataset_id, SqlMetric.id.in_(metrics_ids) @@ -136,7 +136,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return len(metrics_ids) == len(dataset_query) @staticmethod - def validate_metrics_uniqueness(dataset_id: int, metrics_names: List[str]) -> bool: + def validate_metrics_uniqueness(dataset_id: int, metrics_names: list[str]) -> bool: dataset_query = ( db.session.query(SqlMetric.id).filter( SqlMetric.table_id == dataset_id, @@ -149,7 +149,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods def update( cls, model: SqlaTable, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> Optional[SqlaTable]: """ @@ -173,7 +173,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods def update_columns( cls, model: SqlaTable, - property_columns: List[Dict[str, Any]], + property_columns: list[dict[str, Any]], commit: bool = True, override_columns: bool = False, ) -> None: @@ -239,7 +239,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods def update_metrics( cls, model: SqlaTable, - property_metrics: List[Dict[str, Any]], + property_metrics: list[dict[str, Any]], commit: bool = True, ) -> None: """ @@ -304,14 +304,14 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods def update_column( cls, model: TableColumn, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> TableColumn: return DatasetColumnDAO.update(model, properties, commit=commit) @classmethod def create_column( - cls, properties: Dict[str, Any], commit: bool = True + cls, properties: dict[str, Any], commit: bool = True ) -> TableColumn: """ Creates a Dataset model on the metadata DB @@ -346,7 +346,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods def update_metric( cls, model: SqlMetric, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SqlMetric: return DatasetMetricDAO.update(model, properties, commit=commit) @@ -354,7 +354,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods @classmethod def create_metric( cls, - properties: Dict[str, Any], + properties: dict[str, Any], commit: bool = True, ) -> SqlMetric: """ @@ -363,7 +363,7 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods return DatasetMetricDAO.create(properties, commit=commit) @staticmethod - def bulk_delete(models: Optional[List[SqlaTable]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[SqlaTable]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] # bulk delete, first delete related data if models: diff --git a/superset/datasets/models.py b/superset/datasets/models.py index b433709f2c..50aeea7b51 100644 --- a/superset/datasets/models.py +++ b/superset/datasets/models.py @@ -24,7 +24,6 @@ dataset, new models for columns, metrics, and tables were also introduced. These models are not fully implemented, and shouldn't be used yet. """ -from typing import List import sqlalchemy as sa from flask_appbuilder import Model @@ -87,7 +86,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # The relationship between datasets and columns is 1:n, but we use a # many-to-many association table to avoid adding two mutually exclusive # columns(dataset_id and table_id) to Column - columns: List[Column] = relationship( + columns: list[Column] = relationship( "Column", secondary=dataset_column_association_table, cascade="all, delete-orphan", @@ -97,7 +96,7 @@ class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): owners = relationship( security_manager.user_model, secondary=dataset_user_association_table ) - tables: List[Table] = relationship( + tables: list[Table] = relationship( "Table", secondary=dataset_table_association_table, backref="datasets" ) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index f248fc70ff..9a2af98066 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -16,7 +16,7 @@ # under the License. import json import re -from typing import Any, Dict +from typing import Any from flask_babel import lazy_gettext as _ from marshmallow import fields, pre_load, Schema, ValidationError @@ -150,7 +150,7 @@ class DatasetRelatedObjectsResponse(Schema): class ImportV1ColumnSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ @@ -176,7 +176,7 @@ class ImportV1ColumnSchema(Schema): class ImportV1MetricSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ @@ -198,7 +198,7 @@ class ImportV1MetricSchema(Schema): class ImportV1DatasetSchema(Schema): # pylint: disable=no-self-use, unused-argument @pre_load - def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def fix_extra(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: """ Fix for extra initially being exported as a string. """ diff --git a/superset/datasource/dao.py b/superset/datasource/dao.py index 158a32c7fd..4682f070e2 100644 --- a/superset/datasource/dao.py +++ b/superset/datasource/dao.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Dict, Type, Union +from typing import Union from sqlalchemy.orm import Session @@ -34,7 +34,7 @@ Datasource = Union[Dataset, SqlaTable, Table, Query, SavedQuery] class DatasourceDAO(BaseDAO): - sources: Dict[Union[DatasourceType, str], Type[Datasource]] = { + sources: dict[Union[DatasourceType, str], type[Datasource]] = { DatasourceType.TABLE: SqlaTable, DatasourceType.QUERY: Query, DatasourceType.SAVEDQUERY: SavedQuery, diff --git a/superset/db_engine_specs/__init__.py b/superset/db_engine_specs/__init__.py index f19dffd4a3..20cdfcc51f 100644 --- a/superset/db_engine_specs/__init__.py +++ b/superset/db_engine_specs/__init__.py @@ -33,7 +33,7 @@ import pkgutil from collections import defaultdict from importlib import import_module from pathlib import Path -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Optional import sqlalchemy.databases import sqlalchemy.dialects @@ -58,11 +58,11 @@ def is_engine_spec(obj: Any) -> bool: ) -def load_engine_specs() -> List[Type[BaseEngineSpec]]: +def load_engine_specs() -> list[type[BaseEngineSpec]]: """ Load all engine specs, native and 3rd party. """ - engine_specs: List[Type[BaseEngineSpec]] = [] + engine_specs: list[type[BaseEngineSpec]] = [] # load standard engines db_engine_spec_dir = str(Path(__file__).parent) @@ -85,7 +85,7 @@ def load_engine_specs() -> List[Type[BaseEngineSpec]]: return engine_specs -def get_engine_spec(backend: str, driver: Optional[str] = None) -> Type[BaseEngineSpec]: +def get_engine_spec(backend: str, driver: Optional[str] = None) -> type[BaseEngineSpec]: """ Return the DB engine spec associated with a given SQLAlchemy URL. @@ -120,11 +120,11 @@ backend_replacements = { } -def get_available_engine_specs() -> Dict[Type[BaseEngineSpec], Set[str]]: +def get_available_engine_specs() -> dict[type[BaseEngineSpec], set[str]]: """ Return available engine specs and installed drivers for them. """ - drivers: Dict[str, Set[str]] = defaultdict(set) + drivers: dict[str, set[str]] = defaultdict(set) # native SQLAlchemy dialects for attr in sqlalchemy.databases.__all__: diff --git a/superset/db_engine_specs/athena.py b/superset/db_engine_specs/athena.py index 047952402d..ad6bed113d 100644 --- a/superset/db_engine_specs/athena.py +++ b/superset/db_engine_specs/athena.py @@ -16,7 +16,8 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from flask_babel import gettext as __ from sqlalchemy import types @@ -51,7 +52,7 @@ class AthenaEngineSpec(BaseEngineSpec): date_add('day', 1, CAST({col} AS TIMESTAMP))))", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( __( "Please check your query for syntax errors at or " @@ -64,7 +65,7 @@ class AthenaEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a7ff862272..ef922a5e63 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -22,22 +22,8 @@ import json import logging import re from datetime import datetime -from typing import ( - Any, - Callable, - ContextManager, - Dict, - List, - Match, - NamedTuple, - Optional, - Pattern, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from re import Match, Pattern +from typing import Any, Callable, ContextManager, NamedTuple, TYPE_CHECKING, Union import pandas as pd import sqlparse @@ -77,7 +63,7 @@ if TYPE_CHECKING: from superset.models.core import Database from superset.models.sql_lab import Query -ColumnTypeMapping = Tuple[ +ColumnTypeMapping = tuple[ Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]], GenericDataType, @@ -90,10 +76,10 @@ class TimeGrain(NamedTuple): name: str # TODO: redundant field, remove label: str function: str - duration: Optional[str] + duration: str | None -builtin_time_grains: Dict[Optional[str], str] = { +builtin_time_grains: dict[str | None, str] = { "PT1S": __("Second"), "PT5S": __("5 second"), "PT30S": __("30 second"), @@ -160,12 +146,12 @@ class MetricType(TypedDict, total=False): metric_name: str expression: str - verbose_name: Optional[str] - metric_type: Optional[str] - description: Optional[str] - d3format: Optional[str] - warning_text: Optional[str] - extra: Optional[str] + verbose_name: str | None + metric_type: str | None + description: str | None + d3format: str | None + warning_text: str | None + extra: str | None class BaseEngineSpec: # pylint: disable=too-many-public-methods @@ -182,19 +168,19 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods having to add the same aggregation in SELECT. """ - engine_name: Optional[str] = None # for user messages, overridden in child classes + engine_name: str | None = None # for user messages, overridden in child classes # These attributes map the DB engine spec to one or more SQLAlchemy dialects/drivers; # see the ``supports_url`` and ``supports_backend`` methods below. engine = "base" # str as defined in sqlalchemy.engine.engine - engine_aliases: Set[str] = set() - drivers: Dict[str, str] = {} - default_driver: Optional[str] = None + engine_aliases: set[str] = set() + drivers: dict[str, str] = {} + default_driver: str | None = None disable_ssh_tunneling = False - _date_trunc_functions: Dict[str, str] = {} - _time_grain_expressions: Dict[Optional[str], str] = {} - _default_column_type_mappings: Tuple[ColumnTypeMapping, ...] = ( + _date_trunc_functions: dict[str, str] = {} + _time_grain_expressions: dict[str | None, str] = {} + _default_column_type_mappings: tuple[ColumnTypeMapping, ...] = ( ( re.compile(r"^string", re.IGNORECASE), types.String(), @@ -312,7 +298,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ), ) # engine-specific type mappings to check prior to the defaults - column_type_mappings: Tuple[ColumnTypeMapping, ...] = () + column_type_mappings: tuple[ColumnTypeMapping, ...] = () # Does database support join-free timeslot grouping time_groupby_inline = False @@ -351,23 +337,23 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods allow_limit_clause = True # This set will give keywords for select statements # to consider for the engines with TOP SQL parsing - select_keywords: Set[str] = {"SELECT"} + select_keywords: set[str] = {"SELECT"} # This set will give the keywords for data limit statements # to consider for the engines with TOP SQL parsing - top_keywords: Set[str] = {"TOP"} + top_keywords: set[str] = {"TOP"} # A set of disallowed connection query parameters by driver name - disallow_uri_query_params: Dict[str, Set[str]] = {} + disallow_uri_query_params: dict[str, set[str]] = {} # A Dict of query parameters that will always be used on every connection # by driver name - enforce_uri_query_params: Dict[str, Dict[str, Any]] = {} + enforce_uri_query_params: dict[str, dict[str, Any]] = {} force_column_alias_quotes = False arraysize = 0 max_column_name_length = 0 try_remove_schema_from_table_name = True # pylint: disable=invalid-name run_multiple_statements_as_one = False - custom_errors: Dict[ - Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]] + custom_errors: dict[ + Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]] ] = {} # Whether the engine supports file uploads @@ -422,7 +408,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return cls.supports_backend(backend, driver) @classmethod - def supports_backend(cls, backend: str, driver: Optional[str] = None) -> bool: + def supports_backend(cls, backend: str, driver: str | None = None) -> bool: """ Returns true if the DB engine spec supports a given SQLAlchemy backend/driver. """ @@ -439,7 +425,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return driver in cls.drivers @classmethod - def get_default_schema(cls, database: Database) -> Optional[str]: + def get_default_schema(cls, database: Database) -> str | None: """ Return the default schema in a given database. """ @@ -450,8 +436,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def get_schema_from_engine_params( # pylint: disable=unused-argument cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the schema configured in a SQLALchemy URI and connection argments, if any. """ @@ -462,7 +448,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, query: Query, - ) -> Optional[str]: + ) -> str | None: """ Return the default schema for a given query. @@ -501,7 +487,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return cls.get_default_schema(database) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: """ Each engine can implement and converge its own specific exceptions into Superset DBAPI exceptions @@ -541,7 +527,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_allow_cost_estimate( # pylint: disable=unused-argument cls, - extra: Dict[str, Any], + extra: dict[str, Any], ) -> bool: return False @@ -561,8 +547,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def get_engine( cls, database: Database, - schema: Optional[str] = None, - source: Optional[utils.QuerySource] = None, + schema: str | None = None, + source: utils.QuerySource | None = None, ) -> ContextManager[Engine]: """ Return an engine context manager. @@ -578,8 +564,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def get_timestamp_expr( cls, col: ColumnClause, - pdf: Optional[str], - time_grain: Optional[str], + pdf: str | None, + time_grain: str | None, ) -> TimestampExpression: """ Construct a TimestampExpression to be used in a SQLAlchemy query. @@ -616,7 +602,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return TimestampExpression(time_expr, col, type_=col.type) @classmethod - def get_time_grains(cls) -> Tuple[TimeGrain, ...]: + def get_time_grains(cls) -> tuple[TimeGrain, ...]: """ Generate a tuple of supported time grains. @@ -634,8 +620,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def _sort_time_grains( - cls, val: Tuple[Optional[str], str], index: int - ) -> Union[float, int, str]: + cls, val: tuple[str | None, str], index: int + ) -> float | int | str: """ Return an ordered time-based value of a portion of a time grain for sorting @@ -695,7 +681,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return plist.get(index, 0) @classmethod - def get_time_grain_expressions(cls) -> Dict[Optional[str], str]: + def get_time_grain_expressions(cls) -> dict[str | None, str]: """ Return a dict of all supported time grains including any potential added grains but excluding any potentially disabled grains in the config file. @@ -706,7 +692,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods time_grain_expressions = cls._time_grain_expressions.copy() grain_addon_expressions = current_app.config["TIME_GRAIN_ADDON_EXPRESSIONS"] time_grain_expressions.update(grain_addon_expressions.get(cls.engine, {})) - denylist: List[str] = current_app.config["TIME_GRAIN_DENYLIST"] + denylist: list[str] = current_app.config["TIME_GRAIN_DENYLIST"] for key in denylist: time_grain_expressions.pop(key, None) @@ -723,9 +709,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods ) @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: """ :param cursor: Cursor instance @@ -743,9 +727,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def expand_data( - cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]] - ) -> Tuple[ - List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType] + cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]] + ) -> tuple[ + list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType] ]: """ Some engines support expanding nested fields. See implementation in Presto @@ -759,7 +743,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return columns, data, [] @classmethod - def alter_new_orm_column(cls, orm_col: "TableColumn") -> None: + def alter_new_orm_column(cls, orm_col: TableColumn) -> None: """Allow altering default column attributes when first detected/added For instance special column like `__time` for Druid can be @@ -789,7 +773,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return cls.epoch_to_dttm().replace("{col}", "({col}/1000)") @classmethod - def get_datatype(cls, type_code: Any) -> Optional[str]: + def get_datatype(cls, type_code: Any) -> str | None: """ Change column type code from cursor description to string representation. @@ -802,7 +786,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod @deprecated(deprecated_in="3.0") - def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -818,8 +802,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, table_name: str, - schema_name: Optional[str], - ) -> Dict[str, Any]: + schema_name: str | None, + ) -> dict[str, Any]: """ Returns engine-specific table metadata @@ -872,7 +856,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods sql_remainder = None sql = sql.strip(" \t\n;") sql_statement = sqlparse.format(sql, strip_comments=True) - query_limit: Optional[int] = sql_parse.extract_top_from_query( + query_limit: int | None = sql_parse.extract_top_from_query( sql_statement, cls.top_keywords ) if not limit: @@ -928,7 +912,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return True @classmethod - def get_limit_from_sql(cls, sql: str) -> Optional[int]: + def get_limit_from_sql(cls, sql: str) -> int | None: """ Extract limit from SQL query @@ -951,7 +935,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return parsed_query.set_or_update_query_limit(limit) @classmethod - def get_cte_query(cls, sql: str) -> Optional[str]: + def get_cte_query(cls, sql: str) -> str | None: """ Convert the input CTE based SQL to the SQL for virtual table conversion @@ -981,7 +965,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -1012,8 +996,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def convert_dttm( # pylint: disable=unused-argument - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: """ Convert a Python `datetime` object to a SQL expression. @@ -1044,8 +1028,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def extract_errors( - cls, ex: Exception, context: Optional[Dict[str, Any]] = None - ) -> List[SupersetError]: + cls, ex: Exception, context: dict[str, Any] | None = None + ) -> list[SupersetError]: raw_message = cls._extract_error_message(ex) context = context or {} @@ -1076,10 +1060,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def adjust_engine_params( # pylint: disable=unused-argument cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: """ Return a new URL and ``connect_args`` for a specific catalog/schema. @@ -1116,7 +1100,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs from database. @@ -1126,7 +1110,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return [] @classmethod - def get_schema_names(cls, inspector: Inspector) -> List[str]: + def get_schema_names(cls, inspector: Inspector) -> list[str]: """ Get all schemas from database @@ -1140,8 +1124,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the real table names within the specified schema. @@ -1168,8 +1152,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -1197,8 +1181,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods database: Database, # pylint: disable=unused-argument inspector: Inspector, table_name: str, - schema: Optional[str], - ) -> List[Dict[str, Any]]: + schema: str | None, + ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. @@ -1213,8 +1197,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_table_comment( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> Optional[str]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> str | None: """ Get comment of table from a given schema and table @@ -1237,8 +1221,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Get all columns from a given schema and table @@ -1255,8 +1239,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods database: Database, inspector: Inspector, table_name: str, - schema: Optional[str], - ) -> List[MetricType]: + schema: str | None, + ) -> list[MetricType]: """ Get all metrics from a given schema and table. """ @@ -1273,11 +1257,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def where_latest_partition( # pylint: disable=too-many-arguments,unused-argument cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: """ Add a where clause to a query to reference only the most recent partition @@ -1293,7 +1277,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return None @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: return [column(c["name"]) for c in cols] @classmethod @@ -1302,12 +1286,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: """ Generate a "SELECT * from [schema.]table_name" query with appropriate limit. @@ -1326,7 +1310,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :return: SQL query """ # pylint: disable=redefined-outer-name - fields: Union[str, List[Any]] = "*" + fields: str | list[Any] = "*" cols = cols or [] if (show_cols or latest_partition) and not cols: cols = database.get_columns(table_name, schema) @@ -1355,7 +1339,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return sql @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: """ Generate a SQL query that estimates the cost of a given statement. @@ -1367,8 +1351,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: """ Format cost estimate. @@ -1405,8 +1389,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods database: Database, schema: str, sql: str, - source: Optional[utils.QuerySource] = None, - ) -> List[Dict[str, Any]]: + source: utils.QuerySource | None = None, + ) -> list[dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. @@ -1433,7 +1417,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -1450,9 +1434,9 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -1490,7 +1474,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods raise cls.get_dbapi_mapped_exception(ex) from ex @classmethod - def make_label_compatible(cls, label: str) -> Union[str, quoted_name]: + def make_label_compatible(cls, label: str) -> str | quoted_name: """ Conditionally mutate and/or quote a sqlalchemy expression label. If force_column_alias_quotes is set to True, return the label as a @@ -1515,8 +1499,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_column_types( cls, - column_type: Optional[str], - ) -> Optional[Tuple[TypeEngine, GenericDataType]]: + column_type: str | None, + ) -> tuple[TypeEngine, GenericDataType] | None: """ Return a sqlalchemy native column type and generic data type that corresponds to the column type defined in the data source (return None to use default type @@ -1598,7 +1582,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods def get_function_names( # pylint: disable=unused-argument cls, database: Database, - ) -> List[str]: + ) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -1609,7 +1593,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return [] @staticmethod - def pyodbc_rows_to_tuples(data: List[Any]) -> List[Tuple[Any, ...]]: + def pyodbc_rows_to_tuples(data: list[Any]) -> list[tuple[Any, ...]]: """ Convert pyodbc.Row objects from `fetch_data` to tuples. @@ -1634,7 +1618,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return None @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -1642,7 +1626,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable """ - extra: Dict[str, Any] = {} + extra: dict[str, Any] = {} if database.extra: try: extra = json.loads(database.extra) @@ -1653,7 +1637,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @staticmethod def update_params_from_encrypted_extra( # pylint: disable=invalid-name - database: Database, params: Dict[str, Any] + database: Database, params: dict[str, Any] ) -> None: """ Some databases require some sensitive information which do not conform to @@ -1691,10 +1675,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_column_spec( # pylint: disable=unused-argument cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, + native_type: str | None, + db_extra: dict[str, Any] | None = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - ) -> Optional[ColumnSpec]: + ) -> ColumnSpec | None: """ Get generic type related specs regarding a native column type. @@ -1714,10 +1698,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods @classmethod def get_sqla_column_type( cls, - native_type: Optional[str], - db_extra: Optional[Dict[str, Any]] = None, + native_type: str | None, + db_extra: dict[str, Any] | None = None, source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE, - ) -> Optional[TypeEngine]: + ) -> TypeEngine | None: """ Converts native database type to sqlalchemy column type. @@ -1761,7 +1745,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, cursor: Any, query: Query, - ) -> Optional[str]: + ) -> str | None: """ Select identifiers from the database engine that uniquely identifies the queries to cancel. The identifier is typically a session id, process id @@ -1794,11 +1778,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return False @classmethod - def parse_sql(cls, sql: str) -> List[str]: + def parse_sql(cls, sql: str) -> list[str]: return [str(s).strip(" ;") for s in sqlparse.parse(sql)] @classmethod - def get_impersonation_key(cls, user: Optional[User]) -> Any: + def get_impersonation_key(cls, user: User | None) -> Any: """ Construct an impersonation key, by default it's the given username. @@ -1809,7 +1793,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return user.username if user else None @classmethod - def mask_encrypted_extra(cls, encrypted_extra: Optional[str]) -> Optional[str]: + def mask_encrypted_extra(cls, encrypted_extra: str | None) -> str | None: """ Mask ``encrypted_extra``. @@ -1822,9 +1806,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # pylint: disable=unused-argument @classmethod - def unmask_encrypted_extra( - cls, old: Optional[str], new: Optional[str] - ) -> Optional[str]: + def unmask_encrypted_extra(cls, old: str | None, new: str | None) -> str | None: """ Remove masks from ``encrypted_extra``. @@ -1835,7 +1817,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods return new @classmethod - def get_public_information(cls) -> Dict[str, Any]: + def get_public_information(cls) -> dict[str, Any]: """ Construct a Dict with properties we want to expose. @@ -1891,12 +1873,12 @@ class BasicParametersSchema(Schema): class BasicParametersType(TypedDict, total=False): - username: Optional[str] - password: Optional[str] + username: str | None + password: str | None host: str port: int database: str - query: Dict[str, Any] + query: dict[str, Any] encryption: bool @@ -1929,13 +1911,13 @@ class BasicParametersMixin: # query parameter to enable encryption in the database connection # for Postgres this would be `{"sslmode": "verify-ca"}`, eg. - encryption_parameters: Dict[str, str] = {} + encryption_parameters: dict[str, str] = {} @classmethod def build_sqlalchemy_uri( # pylint: disable=unused-argument cls, parameters: BasicParametersType, - encrypted_extra: Optional[Dict[str, str]] = None, + encrypted_extra: dict[str, str] | None = None, ) -> str: # make a copy so that we don't update the original query = parameters.get("query", {}).copy() @@ -1958,7 +1940,7 @@ class BasicParametersMixin: @classmethod def get_parameters_from_uri( # pylint: disable=unused-argument - cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None + cls, uri: str, encrypted_extra: dict[str, Any] | None = None ) -> BasicParametersType: url = make_url_safe(uri) query = { @@ -1982,14 +1964,14 @@ class BasicParametersMixin: @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: + ) -> list[SupersetError]: """ Validates any number of parameters, for progressive validation. If only the hostname is present it will check if the name is resolvable. As more parameters are present in the request, more validation is done. """ - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] required = {"host", "port", "username", "database"} parameters = properties.get("parameters", {}) diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 1f5068ad04..3b62f4bbb8 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -18,7 +18,8 @@ import json import re import urllib from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, Type, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING import pandas as pd from apispec import APISpec @@ -99,8 +100,8 @@ class BigQueryParametersSchema(Schema): class BigQueryParametersType(TypedDict): - credentials_info: Dict[str, Any] - query: Dict[str, Any] + credentials_info: dict[str, Any] + query: dict[str, Any] class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-methods @@ -173,7 +174,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met "P1Y": "{func}({col}, YEAR)", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_DATABASE_PERMISSIONS_REGEX: ( __( "Unable to connect. Verify that the following roles are set " @@ -219,7 +220,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -235,7 +236,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Support type BigQuery Row, introduced here PR #4071 # google.cloud.bigquery.table.Row @@ -280,7 +281,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod @deprecated(deprecated_in="3.0") - def normalize_indexes(cls, indexes: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def normalize_indexes(cls, indexes: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Normalizes indexes for more consistency across db engines @@ -305,7 +306,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met inspector: Inspector, table_name: str, schema: Optional[str], - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Get the indexes associated with the specified schema/table. @@ -321,7 +322,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def extra_table_metadata( cls, database: "Database", table_name: str, schema_name: Optional[str] - ) -> Dict[str, Any]: + ) -> dict[str, Any]: indexes = database.get_indexes(table_name, schema_name) if not indexes: return {} @@ -354,7 +355,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met database: "Database", table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -421,7 +422,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met schema: str, sql: str, source: Optional[utils.QuerySource] = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """ Estimate the cost of a multiple statement SQL query. @@ -448,7 +449,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs. @@ -462,11 +463,11 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met return sorted(project.project_id for project in projects) @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: with cls.get_engine(cursor) as engine: client = cls._get_client(engine) job_config = bigquery.QueryJobConfig(dry_run=True) @@ -503,15 +504,15 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: return [{k: str(v) for k, v in row.items()} for row in raw_cost] @classmethod def build_sqlalchemy_uri( cls, parameters: BigQueryParametersType, - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> str: query = parameters.get("query", {}) query_params = urllib.parse.urlencode(query) @@ -533,7 +534,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met def get_parameters_from_uri( cls, uri: str, - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> Any: value = make_url_safe(uri) @@ -592,7 +593,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met return json.dumps(new_config) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from google.auth.exceptions import DefaultCredentialsError @@ -602,7 +603,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met def validate_parameters( cls, properties: BasicPropertiesType, # pylint: disable=unused-argument - ) -> List[SupersetError]: + ) -> list[SupersetError]: return [] @classmethod @@ -636,7 +637,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: Optional[list[dict[str, Any]]] = None, ) -> str: """ Remove array structures from `SELECT *`. @@ -699,7 +700,7 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met ) @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[Any]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[Any]: """ Label columns using their fully qualified name. diff --git a/superset/db_engine_specs/clickhouse.py b/superset/db_engine_specs/clickhouse.py index a62087bc6a..af38c15e0b 100644 --- a/superset/db_engine_specs/clickhouse.py +++ b/superset/db_engine_specs/clickhouse.py @@ -19,7 +19,7 @@ from __future__ import annotations import logging import re from datetime import datetime -from typing import Any, cast, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from flask import current_app from flask_babel import gettext as __ @@ -124,8 +124,8 @@ class ClickHouseBaseEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -145,7 +145,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec): supports_file_upload = False @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {NewConnectionError: SupersetDBAPIDatabaseError} @classmethod @@ -159,7 +159,7 @@ class ClickHouseEngineSpec(ClickHouseBaseEngineSpec): @classmethod @cache_manager.cache.memoize() - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -256,7 +256,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): engine_name = "ClickHouse Connect (Superset)" default_driver = "connect" - _function_names: List[str] = [] + _function_names: list[str] = [] sqlalchemy_uri_placeholder = ( "clickhousedb://user:password@host[:port][/dbname][?secure=value&=value...]" @@ -265,7 +265,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): encryption_parameters = {"secure": "true"} @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: return {} @classmethod @@ -278,7 +278,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): return new_exception(str(exception)) @classmethod - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: # pylint: disable=import-outside-toplevel,import-error from clickhouse_connect.driver.exceptions import ClickHouseError @@ -304,7 +304,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): def build_sqlalchemy_uri( cls, parameters: BasicParametersType, - encrypted_extra: Optional[Dict[str, str]] = None, + encrypted_extra: dict[str, str] | None = None, ) -> str: url_params = parameters.copy() if url_params.get("encryption"): @@ -318,7 +318,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): @classmethod def get_parameters_from_uri( - cls, uri: str, encrypted_extra: Optional[Dict[str, Any]] = None + cls, uri: str, encrypted_extra: dict[str, Any] | None = None ) -> BasicParametersType: url = make_url_safe(uri) query = url.query @@ -340,7 +340,7 @@ class ClickHouseConnectEngineSpec(ClickHouseEngineSpec, BasicParametersMixin): @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: + ) -> list[SupersetError]: # pylint: disable=import-outside-toplevel,import-error from clickhouse_connect.driver import default_port diff --git a/superset/db_engine_specs/crate.py b/superset/db_engine_specs/crate.py index 6eafae829e..d8d91c6796 100644 --- a/superset/db_engine_specs/crate.py +++ b/superset/db_engine_specs/crate.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types @@ -53,8 +53,8 @@ class CrateEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.TIMESTAMP): diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 5f12f3174d..5df24be65d 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -17,7 +17,7 @@ import json from datetime import datetime -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -135,7 +135,7 @@ class DatabricksODBCEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: return HiveEngineSpec.convert_dttm(target_type, dttm, db_extra=db_extra) @@ -160,14 +160,14 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) encryption_parameters = {"ssl": "1"} @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ Add a user agent to be used in the requests. Trim whitespace from connect_args to avoid databricks driver errors """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("http_headers", [("User-Agent", USER_AGENT)]) connect_args.setdefault("_user_agent_entry", USER_AGENT) @@ -184,7 +184,7 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) database: "Database", inspector: Inspector, schema: Optional[str], - ) -> Set[str]: + ) -> set[str]: return super().get_table_names( database, inspector, schema ) - cls.get_view_names(database, inspector, schema) @@ -213,8 +213,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) @classmethod def extract_errors( - cls, ex: Exception, context: Optional[Dict[str, Any]] = None - ) -> List[SupersetError]: + cls, ex: Exception, context: Optional[dict[str, Any]] = None + ) -> list[SupersetError]: raw_message = cls._extract_error_message(ex) context = context or {} @@ -271,8 +271,8 @@ class DatabricksNativeEngineSpec(DatabricksODBCEngineSpec, BasicParametersMixin) def validate_parameters( # type: ignore cls, properties: DatabricksPropertiesType, - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] required = {"access_token", "host", "port", "database", "extra"} extra = json.loads(properties.get("extra", "{}")) engine_params = extra.get("engine_params", {}) diff --git a/superset/db_engine_specs/dremio.py b/superset/db_engine_specs/dremio.py index 7fae3014d6..7b4c0458cd 100644 --- a/superset/db_engine_specs/dremio.py +++ b/superset/db_engine_specs/dremio.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -46,7 +46,7 @@ class DremioEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/drill.py b/superset/db_engine_specs/drill.py index 16ac89212a..946544863d 100644 --- a/superset/db_engine_specs/drill.py +++ b/superset/db_engine_specs/drill.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional from urllib import parse from sqlalchemy import types @@ -59,7 +59,7 @@ class DrillEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -74,10 +74,10 @@ class DrillEngineSpec(BaseEngineSpec): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema.replace(".", "/"), safe="")) @@ -87,7 +87,7 @@ class DrillEngineSpec(BaseEngineSpec): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/druid.py b/superset/db_engine_specs/druid.py index 83829ec22a..43ce310a40 100644 --- a/superset/db_engine_specs/druid.py +++ b/superset/db_engine_specs/druid.py @@ -20,7 +20,7 @@ from __future__ import annotations import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types from sqlalchemy.engine.reflection import Inspector @@ -79,7 +79,7 @@ class DruidEngineSpec(BaseEngineSpec): orm_col.is_dttm = True @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ For Druid, the path to a SSL certificate is placed in `connect_args`. @@ -104,8 +104,8 @@ class DruidEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -130,15 +130,15 @@ class DruidEngineSpec(BaseEngineSpec): @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Update the Druid type map. """ return super().get_columns(inspector, table_name, schema) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions diff --git a/superset/db_engine_specs/duckdb.py b/superset/db_engine_specs/duckdb.py index 1248287b84..3bbf9ecc38 100644 --- a/superset/db_engine_specs/duckdb.py +++ b/superset/db_engine_specs/duckdb.py @@ -18,7 +18,8 @@ from __future__ import annotations import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy import types @@ -51,7 +52,7 @@ class DuckDBEngineSpec(BaseEngineSpec): "P1Y": "DATE_TRUNC('year', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __('We can\'t seem to resolve the column "%(column_name)s"'), SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, @@ -65,8 +66,8 @@ class DuckDBEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, (types.String, types.DateTime)): @@ -75,6 +76,6 @@ class DuckDBEngineSpec(BaseEngineSpec): @classmethod def get_table_names( - cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + cls, database: Database, inspector: Inspector, schema: str | None + ) -> set[str]: return set(inspector.get_table_names(schema)) diff --git a/superset/db_engine_specs/dynamodb.py b/superset/db_engine_specs/dynamodb.py index c398a9c1df..5f7a9e2b71 100644 --- a/superset/db_engine_specs/dynamodb.py +++ b/superset/db_engine_specs/dynamodb.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -55,7 +55,7 @@ class DynamoDBEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/elasticsearch.py b/superset/db_engine_specs/elasticsearch.py index 934aa0bb03..d717c52bf5 100644 --- a/superset/db_engine_specs/elasticsearch.py +++ b/superset/db_engine_specs/elasticsearch.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from packaging.version import Version from sqlalchemy import types @@ -50,10 +50,10 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho "P1Y": "HISTOGRAM({col}, INTERVAL 1 YEAR)", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-error,import-outside-toplevel import es.exceptions as es_exceptions @@ -65,7 +65,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: db_extra = db_extra or {} @@ -117,7 +117,7 @@ class OpenDistroEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/exasol.py b/superset/db_engine_specs/exasol.py index c06fbd826d..6da56e2fee 100644 --- a/superset/db_engine_specs/exasol.py +++ b/superset/db_engine_specs/exasol.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List, Optional, Tuple +from typing import Any, Optional from superset.db_engine_specs.base import BaseEngineSpec @@ -42,7 +42,7 @@ class ExasolEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/firebird.py b/superset/db_engine_specs/firebird.py index 306a642dc3..4448074157 100644 --- a/superset/db_engine_specs/firebird.py +++ b/superset/db_engine_specs/firebird.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -72,7 +72,7 @@ class FirebirdEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/firebolt.py b/superset/db_engine_specs/firebolt.py index 65cd714352..ace3d6b3b2 100644 --- a/superset/db_engine_specs/firebolt.py +++ b/superset/db_engine_specs/firebolt.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -43,7 +43,7 @@ class FireboltEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 73a66c464f..abf5bac48f 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -16,7 +16,8 @@ # under the License. import json import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin @@ -56,12 +57,12 @@ class GSheetsParametersSchema(Schema): class GSheetsParametersType(TypedDict): service_account_info: str - catalog: Optional[Dict[str, str]] + catalog: Optional[dict[str, str]] class GSheetsPropertiesType(TypedDict): parameters: GSheetsParametersType - catalog: Dict[str, str] + catalog: dict[str, str] class GSheetsEngineSpec(SqliteEngineSpec): @@ -77,7 +78,7 @@ class GSheetsEngineSpec(SqliteEngineSpec): default_driver = "apsw" sqlalchemy_uri_placeholder = "gsheets://" - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( __( 'Please check your query for syntax errors near "%(server_error)s". ' @@ -110,7 +111,7 @@ class GSheetsEngineSpec(SqliteEngineSpec): database: "Database", table_name: str, schema_name: Optional[str], - ) -> Dict[str, Any]: + ) -> dict[str, Any]: with database.get_raw_connection(schema=schema_name) as conn: cursor = conn.cursor() cursor.execute(f'SELECT GET_METADATA("{table_name}")') @@ -127,7 +128,7 @@ class GSheetsEngineSpec(SqliteEngineSpec): cls, _: GSheetsParametersType, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, Any] + dict[str, Any] ] = None, ) -> str: return "gsheets://" @@ -136,7 +137,7 @@ class GSheetsEngineSpec(SqliteEngineSpec): def get_parameters_from_uri( cls, uri: str, # pylint: disable=unused-argument - encrypted_extra: Optional[Dict[str, Any]] = None, + encrypted_extra: Optional[dict[str, Any]] = None, ) -> Any: # Building parameters from encrypted_extra and uri if encrypted_extra: @@ -214,8 +215,8 @@ class GSheetsEngineSpec(SqliteEngineSpec): def validate_parameters( cls, properties: GSheetsPropertiesType, - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] # backwards compatible just incase people are send data # via parameters for validation diff --git a/superset/db_engine_specs/hana.py b/superset/db_engine_specs/hana.py index e579550b2e..108838f9d2 100644 --- a/superset/db_engine_specs/hana.py +++ b/superset/db_engine_specs/hana.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -45,7 +45,7 @@ class HanaEngineSpec(PostgresBaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 6d8986c1c7..7601ebb2cd 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -22,7 +22,7 @@ import re import tempfile import time from datetime import datetime -from typing import Any, Dict, List, Optional, Set, Tuple, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from urllib import parse import numpy as np @@ -150,9 +150,7 @@ class HiveEngineSpec(PrestoEngineSpec): hive.Cursor.fetch_logs = fetch_logs @classmethod - def fetch_data( - cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + def fetch_data(cls, cursor: Any, limit: int | None = None) -> list[tuple[Any, ...]]: # pylint: disable=import-outside-toplevel import pyhive from TCLIService import ttypes @@ -168,10 +166,10 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def df_to_sql( cls, - database: "Database", + database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. @@ -248,8 +246,8 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, types.Date): @@ -263,10 +261,10 @@ class HiveEngineSpec(PrestoEngineSpec): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: if schema: uri = uri.set(database=parse.quote(schema, safe="")) @@ -276,8 +274,8 @@ class HiveEngineSpec(PrestoEngineSpec): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the configured schema. """ @@ -292,10 +290,10 @@ class HiveEngineSpec(PrestoEngineSpec): return msg @classmethod - def progress(cls, log_lines: List[str]) -> int: + def progress(cls, log_lines: list[str]) -> int: total_jobs = 1 # assuming there's at least 1 job current_job = 1 - stages: Dict[int, float] = {} + stages: dict[int, float] = {} for line in log_lines: match = cls.jobs_stats_r.match(line) if match: @@ -323,7 +321,7 @@ class HiveEngineSpec(PrestoEngineSpec): return int(progress) @classmethod - def get_tracking_url_from_logs(cls, log_lines: List[str]) -> Optional[str]: + def get_tracking_url_from_logs(cls, log_lines: list[str]) -> str | None: lkp = "Tracking URL = " for line in log_lines: if lkp in line: @@ -407,19 +405,19 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: return inspector.get_columns(table_name, schema) @classmethod def where_latest_partition( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], - database: "Database", + schema: str | None, + database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: try: col_names, values = cls.latest_partition( table_name, schema, database, show_first=True @@ -437,18 +435,18 @@ class HiveEngineSpec(PrestoEngineSpec): return None @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: return BaseEngineSpec._get_fields(cols) # pylint: disable=protected-access @classmethod def latest_sub_partition( # type: ignore - cls, table_name: str, schema: Optional[str], database: "Database", **kwargs: Any + cls, table_name: str, schema: str | None, database: Database, **kwargs: Any ) -> str: # TODO(bogdan): implement` pass @classmethod - def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: + def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None: """Hive partitions look like ds={partition name}/ds={partition name}""" if not df.empty: return [ @@ -461,12 +459,12 @@ class HiveEngineSpec(PrestoEngineSpec): def _partition_query( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], - indexes: List[Dict[str, Any]], - database: "Database", + schema: str | None, + indexes: list[dict[str, Any]], + database: Database, limit: int = 0, - order_by: Optional[List[Tuple[str, bool]]] = None, - filters: Optional[Dict[Any, Any]] = None, + order_by: list[tuple[str, bool]] | None = None, + filters: dict[Any, Any] | None = None, ) -> str: full_table_name = f"{schema}.{table_name}" if schema else table_name return f"SHOW PARTITIONS {full_table_name}" @@ -474,15 +472,15 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def select_star( # pylint: disable=too-many-arguments cls, - database: "Database", + database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: return super( # pylint: disable=bad-super-call PrestoEngineSpec, cls @@ -500,7 +498,7 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -516,9 +514,9 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -549,7 +547,7 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod @cache_manager.cache.memoize() - def get_function_names(cls, database: "Database") -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -600,10 +598,10 @@ class HiveEngineSpec(PrestoEngineSpec): @classmethod def get_view_names( cls, - database: "Database", + database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -635,9 +633,9 @@ class HiveEngineSpec(PrestoEngineSpec): # TODO: contribute back to pyhive. def fetch_logs( # pylint: disable=protected-access - self: "Cursor", + self: Cursor, _max_rows: int = 1024, - orientation: Optional["TFetchOrientation"] = None, + orientation: TFetchOrientation | None = None, ) -> str: """Mocked. Retrieve the logs produced by the execution of the query. Can be called multiple times to fetch the logs produced after diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index e59c2b74fb..cd1c9e4732 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -18,7 +18,7 @@ import logging import re import time from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import current_app from sqlalchemy import types @@ -57,7 +57,7 @@ class ImpalaEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -68,7 +68,7 @@ class ImpalaEngineSpec(BaseEngineSpec): return None @classmethod - def get_schema_names(cls, inspector: Inspector) -> List[str]: + def get_schema_names(cls, inspector: Inspector) -> list[str]: schemas = [ row[0] for row in inspector.engine.execute("SHOW SCHEMAS") diff --git a/superset/db_engine_specs/kusto.py b/superset/db_engine_specs/kusto.py index 9fddb23d26..17147d5cc0 100644 --- a/superset/db_engine_specs/kusto.py +++ b/superset/db_engine_specs/kusto.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from sqlalchemy import types from sqlalchemy.dialects.mssql.base import SMALLDATETIME @@ -61,7 +61,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method " DATEDIFF(week, 0, DATEADD(day, -1, {col})), 0)", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed column_type_mappings = ( ( @@ -72,7 +72,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method ) @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel,import-error import sqlalchemy_kusto.errors as kusto_exceptions @@ -84,7 +84,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -128,10 +128,10 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "P1Y": "datetime_diff('year',CreateDate, datetime(0001-01-01 00:00:00))+1", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel,import-error import sqlalchemy_kusto.errors as kusto_exceptions @@ -143,7 +143,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -168,7 +168,7 @@ class KustoKqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method return not parsed_query.sql.startswith(".") @classmethod - def parse_sql(cls, sql: str) -> List[str]: + def parse_sql(cls, sql: str) -> list[str]: """ Kusto supports a single query statement, but it could include sub queries and variables declared via let keyword. diff --git a/superset/db_engine_specs/kylin.py b/superset/db_engine_specs/kylin.py index e340daea51..f522602a48 100644 --- a/superset/db_engine_specs/kylin.py +++ b/superset/db_engine_specs/kylin.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy import types @@ -42,7 +42,7 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/mssql.py b/superset/db_engine_specs/mssql.py index 8b38ec7421..3e0879b904 100644 --- a/superset/db_engine_specs/mssql.py +++ b/superset/db_engine_specs/mssql.py @@ -17,7 +17,8 @@ import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from flask_babel import gettext as __ from sqlalchemy import types @@ -80,7 +81,7 @@ class MssqlEngineSpec(BaseEngineSpec): ), ) - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __( 'Either the username "%(username)s", password, ' @@ -115,7 +116,7 @@ class MssqlEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -132,7 +133,7 @@ class MssqlEngineSpec(BaseEngineSpec): @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: data = super().fetch_data(cursor, limit) # Lists of `pyodbc.Row` need to be unpacked further return cls.pyodbc_rows_to_tuples(data) diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 6258f6b21a..9f853d577c 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -16,7 +16,8 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional from urllib import parse from flask_babel import gettext as __ @@ -143,9 +144,9 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): "INTERVAL 1 DAY)) - 1 DAY))", } - type_code_map: Dict[int, str] = {} # loaded from get_datatype only if needed + type_code_map: dict[int, str] = {} # loaded from get_datatype only if needed - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -186,7 +187,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -201,10 +202,10 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: uri, new_connect_args = super().adjust_engine_params( uri, connect_args, @@ -221,7 +222,7 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/ocient.py b/superset/db_engine_specs/ocient.py index 4b8a59117e..59fa52a656 100644 --- a/superset/db_engine_specs/ocient.py +++ b/superset/db_engine_specs/ocient.py @@ -17,7 +17,8 @@ import re import threading -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Pattern, Set, Tuple +from re import Pattern +from typing import Any, Callable, List, NamedTuple, Optional from flask_babel import gettext as __ from sqlalchemy.engine.reflection import Inspector @@ -178,15 +179,13 @@ def _polygon_to_geo_json( # Sanitization function for column values SanitizeFunc = Callable[[Any], Any] + # Represents a pair of a column index and the sanitization function # to apply to its values. -PlacedSanitizeFunc = NamedTuple( - "PlacedSanitizeFunc", - [ - ("column_index", int), - ("sanitize_func", SanitizeFunc), - ], -) +class PlacedSanitizeFunc(NamedTuple): + column_index: int + sanitize_func: SanitizeFunc + # This map contains functions used to sanitize values for column types # that cannot be processed natively by Superset. @@ -199,7 +198,7 @@ PlacedSanitizeFunc = NamedTuple( try: from pyocient import TypeCodes - _sanitized_ocient_type_codes: Dict[int, SanitizeFunc] = { + _sanitized_ocient_type_codes: dict[int, SanitizeFunc] = { TypeCodes.BINARY: _to_hex, TypeCodes.ST_POINT: _point_to_geo_json, TypeCodes.IP: str, @@ -211,7 +210,7 @@ except ImportError as e: _sanitized_ocient_type_codes = {} -def _find_columns_to_sanitize(cursor: Any) -> List[PlacedSanitizeFunc]: +def _find_columns_to_sanitize(cursor: Any) -> list[PlacedSanitizeFunc]: """ Cleans the column value for consumption by Superset. @@ -238,10 +237,10 @@ class OcientEngineSpec(BaseEngineSpec): # Store mapping of superset Query id -> Ocient ID # These are inserted into the cache when executing the query # They are then removed, either upon cancellation or query completion - query_id_mapping: Dict[str, str] = dict() + query_id_mapping: dict[str, str] = dict() query_id_mapping_lock = threading.Lock() - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_INVALID_USERNAME_REGEX: ( __('The username "%(username)s" does not exist.'), SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR, @@ -309,15 +308,15 @@ class OcientEngineSpec(BaseEngineSpec): @classmethod def get_table_names( cls, database: Database, inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: return inspector.get_table_names(schema) @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: try: - rows: List[Tuple[Any, ...]] = super().fetch_data(cursor, limit) + rows: list[tuple[Any, ...]] = super().fetch_data(cursor, limit) except Exception as exception: with OcientEngineSpec.query_id_mapping_lock: del OcientEngineSpec.query_id_mapping[ @@ -329,7 +328,7 @@ class OcientEngineSpec(BaseEngineSpec): if len(rows) > 0 and type(rows[0]).__name__ == "Row": # Peek at the schema to determine which column values, if any, # require sanitization. - columns_to_sanitize: List[PlacedSanitizeFunc] = _find_columns_to_sanitize( + columns_to_sanitize: list[PlacedSanitizeFunc] = _find_columns_to_sanitize( cursor ) @@ -341,7 +340,7 @@ class OcientEngineSpec(BaseEngineSpec): # Use the identity function if the column type doesn't need to be # sanitized. - sanitization_functions: List[SanitizeFunc] = [ + sanitization_functions: list[SanitizeFunc] = [ identity for _ in range(len(cursor.description)) ] for info in columns_to_sanitize: diff --git a/superset/db_engine_specs/oracle.py b/superset/db_engine_specs/oracle.py index 4a219919bb..1199b74406 100644 --- a/superset/db_engine_specs/oracle.py +++ b/superset/db_engine_specs/oracle.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from sqlalchemy import types @@ -43,7 +43,7 @@ class OracleEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -68,7 +68,7 @@ class OracleEngineSpec(BaseEngineSpec): @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: """ :param cursor: Cursor instance :param limit: Maximum number of rows to be returned by the cursor diff --git a/superset/db_engine_specs/pinot.py b/superset/db_engine_specs/pinot.py index cebdd693a4..bfec8b2947 100644 --- a/superset/db_engine_specs/pinot.py +++ b/superset/db_engine_specs/pinot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional +from typing import Optional from sqlalchemy.sql.expression import ColumnClause @@ -30,7 +30,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method allows_alias_in_orderby = False # Pinot does its own conversion below - _time_grain_expressions: Dict[Optional[str], str] = { + _time_grain_expressions: dict[Optional[str], str] = { "PT1S": "1:SECONDS", "PT1M": "1:MINUTES", "PT5M": "5:MINUTES", @@ -45,7 +45,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "P1Y": "year", } - _python_to_java_time_patterns: Dict[str, str] = { + _python_to_java_time_patterns: dict[str, str] = { "%Y": "yyyy", "%m": "MM", "%d": "dd", @@ -54,7 +54,7 @@ class PinotEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method "%S": "ss", } - _use_date_trunc_function: Dict[str, bool] = { + _use_date_trunc_function: dict[str, bool] = { "PT1S": False, "PT1M": False, "PT5M": False, diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index e809187af6..2088782f83 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -18,7 +18,8 @@ import json import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON @@ -73,7 +74,7 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile( SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P.*?)"') -def parse_options(connect_args: Dict[str, Any]) -> Dict[str, str]: +def parse_options(connect_args: dict[str, Any]) -> dict[str, str]: """ Parse ``options`` from ``connect_args`` into a dictionary. """ @@ -109,7 +110,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): "P1Y": "DATE_TRUNC('year', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_INVALID_USERNAME_REGEX: ( __('The username "%(username)s" does not exist.'), SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR, @@ -169,7 +170,7 @@ class PostgresBaseEngineSpec(BaseEngineSpec): @classmethod def fetch_data( cls, cursor: Any, limit: Optional[int] = None - ) -> List[Tuple[Any, ...]]: + ) -> list[tuple[Any, ...]]: if not cursor.description: return [] return super().fetch_data(cursor, limit) @@ -221,7 +222,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. @@ -253,10 +254,10 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: if not schema: return uri, connect_args @@ -269,11 +270,11 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): return uri, connect_args @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: sql = f"EXPLAIN {statement}" cursor.execute(sql) @@ -289,8 +290,8 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: return [{k: str(v) for k, v in row.items()} for row in raw_cost] @classmethod @@ -298,7 +299,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Return all catalogs. @@ -317,7 +318,7 @@ WHERE datistemplate = false; @classmethod def get_table_names( cls, database: "Database", inspector: PGInspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: """Need to consider foreign tables for PostgreSQL""" return set(inspector.get_table_names(schema)) | set( inspector.get_foreign_table_names(schema) @@ -325,7 +326,7 @@ WHERE datistemplate = false; @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -337,7 +338,7 @@ WHERE datistemplate = false; return None @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ For Postgres, the path to a SSL certificate is placed in `connect_args`. diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 82b05e53e3..d5a2ab7605 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -23,19 +23,9 @@ import time from abc import ABCMeta from collections import defaultdict, deque from datetime import datetime +from re import Pattern from textwrap import dedent -from typing import ( - Any, - cast, - Dict, - List, - Optional, - Pattern, - Set, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, Optional, TYPE_CHECKING from urllib import parse import pandas as pd @@ -78,7 +68,7 @@ if TYPE_CHECKING: # need try/catch because pyhive may not be installed try: - from pyhive.presto import Cursor # pylint: disable=unused-import + from pyhive.presto import Cursor except ImportError: pass @@ -107,7 +97,7 @@ CONNECTION_UNKNOWN_DATABASE_ERROR = re.compile( logger = logging.getLogger(__name__) -def get_children(column: ResultSetColumnType) -> List[ResultSetColumnType]: +def get_children(column: ResultSetColumnType) -> list[ResultSetColumnType]: """ Get the children of a complex Presto type (row or array). @@ -276,8 +266,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None - ) -> Optional[str]: + cls, target_type: str, dttm: datetime, db_extra: dict[str, Any] | None = None + ) -> str | None: """ Convert a Python `datetime` object to a SQL expression. :param target_type: The target type of expression @@ -304,10 +294,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + connect_args: dict[str, Any], + catalog: str | None = None, + schema: str | None = None, + ) -> tuple[URL, dict[str, Any]]: database = uri.database if schema and database: schema = parse.quote(schema, safe="") @@ -323,8 +313,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], - ) -> Optional[str]: + connect_args: dict[str, Any], + ) -> str | None: """ Return the configured schema. @@ -341,7 +331,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return parse.unquote(database.split("/")[1]) @classmethod - def estimate_statement_cost(cls, statement: str, cursor: Any) -> Dict[str, Any]: + def estimate_statement_cost(cls, statement: str, cursor: Any) -> dict[str, Any]: """ Run a SQL query that estimates the cost of a given statement. :param statement: A single SQL statement @@ -369,8 +359,8 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): @classmethod def query_cost_formatter( - cls, raw_cost: List[Dict[str, Any]] - ) -> List[Dict[str, str]]: + cls, raw_cost: list[dict[str, Any]] + ) -> list[dict[str, str]]: """ Format cost estimate. :param raw_cost: JSON estimate from Trino @@ -401,7 +391,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): ("networkCost", "Network cost", ""), ] for row in raw_cost: - estimate: Dict[str, float] = row.get("estimate", {}) + estimate: dict[str, float] = row.get("estimate", {}) statement_cost = {} for key, label, suffix in columns: if key in estimate: @@ -412,7 +402,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): @classmethod @cache_manager.data_cache.memoize() - def get_function_names(cls, database: Database) -> List[str]: + def get_function_names(cls, database: Database) -> list[str]: """ Get a list of function names that are able to be called on the database. Used for SQL Lab autocomplete. @@ -426,12 +416,12 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def _partition_query( # pylint: disable=too-many-arguments,too-many-locals,unused-argument cls, table_name: str, - schema: Optional[str], - indexes: List[Dict[str, Any]], + schema: str | None, + indexes: list[dict[str, Any]], database: Database, limit: int = 0, - order_by: Optional[List[Tuple[str, bool]]] = None, - filters: Optional[Dict[Any, Any]] = None, + order_by: list[tuple[str, bool]] | None = None, + filters: dict[Any, Any] | None = None, ) -> str: """ Return a partition query. @@ -449,7 +439,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): order :param filters: dict of field name and filter value combinations """ - limit_clause = "LIMIT {}".format(limit) if limit else "" + limit_clause = f"LIMIT {limit}" if limit else "" order_by_clause = "" if order_by: l = [] @@ -492,11 +482,11 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def where_latest_partition( # pylint: disable=too-many-arguments cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, query: Select, - columns: Optional[List[Dict[str, Any]]] = None, - ) -> Optional[Select]: + columns: list[dict[str, Any]] | None = None, + ) -> Select | None: try: col_names, values = cls.latest_partition( table_name, schema, database, show_first=True @@ -525,7 +515,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): return query @classmethod - def _latest_partition_from_df(cls, df: pd.DataFrame) -> Optional[List[str]]: + def _latest_partition_from_df(cls, df: pd.DataFrame) -> list[str] | None: if not df.empty: return df.to_records(index=False)[0].item() return None @@ -535,10 +525,10 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): def latest_partition( cls, table_name: str, - schema: Optional[str], + schema: str | None, database: Database, show_first: bool = False, - ) -> Tuple[List[str], Optional[List[str]]]: + ) -> tuple[list[str], list[str] | None]: """Returns col name and the latest (max) partition value for a table :param table_name: the name of the table @@ -589,7 +579,7 @@ class PrestoBaseEngineSpec(BaseEngineSpec, metaclass=ABCMeta): @classmethod def latest_sub_partition( - cls, table_name: str, schema: Optional[str], database: Database, **kwargs: Any + cls, table_name: str, schema: str | None, database: Database, **kwargs: Any ) -> Any: """Returns the latest (max) partition value for a table @@ -652,7 +642,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): engine_name = "Presto" allows_alias_to_source_column = False - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __( 'We can\'t seem to resolve the column "%(column_name)s" at ' @@ -708,16 +698,16 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): } @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: version = extra.get("version") return version is not None and Version(version) >= Version("0.319") @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -741,8 +731,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the real table names within the specified schema. @@ -769,8 +759,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - schema: Optional[str], - ) -> Set[str]: + schema: str | None, + ) -> set[str]: """ Get all the view names within the specified schema. @@ -817,7 +807,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Get all catalogs. """ @@ -826,7 +816,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create column info object :param name: column name @@ -836,7 +826,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return {"name": name, "type": f"{data_type}"} @classmethod - def _get_full_name(cls, names: List[Tuple[str, str]]) -> str: + def _get_full_name(cls, names: list[tuple[str, str]]) -> str: """ Get the full column name :param names: list of all individual column names @@ -860,7 +850,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): ) @classmethod - def _split_data_type(cls, data_type: str, delimiter: str) -> List[str]: + def _split_data_type(cls, data_type: str, delimiter: str) -> list[str]: """ Split data type based on given delimiter. Do not split the string if the delimiter is enclosed in quotes @@ -869,16 +859,14 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): comma, whitespace) :return: list of strings after breaking it by the delimiter """ - return re.split( - r"{}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)".format(delimiter), data_type - ) + return re.split(rf"{delimiter}(?=(?:[^\"]*\"[^\"]*\")*[^\"]*$)", data_type) @classmethod def _parse_structural_column( # pylint: disable=too-many-locals cls, parent_column_name: str, parent_data_type: str, - result: List[Dict[str, Any]], + result: list[dict[str, Any]], ) -> None: """ Parse a row or array column @@ -893,7 +881,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): # split on open parenthesis ( to get the structural # data type and its component types data_types = cls._split_data_type(full_data_type, r"\(") - stack: List[Tuple[str, str]] = [] + stack: list[tuple[str, str]] = [] for data_type in data_types: # split on closed parenthesis ) to track which component # types belong to what structural data type @@ -962,8 +950,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def _show_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[ResultRow]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[ResultRow]: """ Show presto column names :param inspector: object that performs database schema inspection @@ -974,13 +962,13 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): quote = inspector.engine.dialect.identifier_preparer.quote_identifier full_table = quote(table_name) if schema: - full_table = "{}.{}".format(quote(schema), full_table) + full_table = f"{quote(schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def get_columns( - cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + cls, inspector: Inspector, table_name: str, schema: str | None + ) -> list[dict[str, Any]]: """ Get columns from a Presto data source. This includes handling row and array data types @@ -991,7 +979,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): (i.e. column name and data type) """ columns = cls._show_columns(inspector, table_name, schema) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for column in columns: # parse column if it is a row or array if is_feature_enabled("PRESTO_EXPAND_DATA") and ( @@ -1031,7 +1019,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return column_name.startswith('"') and column_name.endswith('"') @classmethod - def _get_fields(cls, cols: List[Dict[str, Any]]) -> List[ColumnClause]: + def _get_fields(cls, cols: list[dict[str, Any]]) -> list[ColumnClause]: """ Format column clauses where names are in quotes and labels are specified :param cols: columns @@ -1053,7 +1041,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): # quote each column name if it is not already quoted for index, col_name in enumerate(col_names): if not cls._is_column_name_quoted(col_name): - col_names[index] = '"{}"'.format(col_name) + col_names[index] = f'"{col_name}"' quoted_col_name = ".".join( col_name if cls._is_column_name_quoted(col_name) else f'"{col_name}"' for col_name in col_names @@ -1069,12 +1057,12 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): database: Database, table_name: str, engine: Engine, - schema: Optional[str] = None, + schema: str | None = None, limit: int = 100, show_cols: bool = False, indent: bool = True, latest_partition: bool = True, - cols: Optional[List[Dict[str, Any]]] = None, + cols: list[dict[str, Any]] | None = None, ) -> str: """ Include selecting properties of row objects. We cannot easily break arrays into @@ -1102,9 +1090,9 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def expand_data( # pylint: disable=too-many-locals - cls, columns: List[ResultSetColumnType], data: List[Dict[Any, Any]] - ) -> Tuple[ - List[ResultSetColumnType], List[Dict[Any, Any]], List[ResultSetColumnType] + cls, columns: list[ResultSetColumnType], data: list[dict[Any, Any]] + ) -> tuple[ + list[ResultSetColumnType], list[dict[Any, Any]], list[ResultSetColumnType] ]: """ We do not immediately display rows and arrays clearly in the data grid. This @@ -1133,7 +1121,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): # process each column, unnesting ARRAY types and # expanding ROW types into new columns to_process = deque((column, 0) for column in columns) - all_columns: List[ResultSetColumnType] = [] + all_columns: list[ResultSetColumnType] = [] expanded_columns = [] current_array_level = None while to_process: @@ -1147,11 +1135,11 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): # added by the first. every time we change a level in the nested arrays # we reinitialize this. if level != current_array_level: - unnested_rows: Dict[int, int] = defaultdict(int) + unnested_rows: dict[int, int] = defaultdict(int) current_array_level = level name = column["name"] - values: Optional[Union[str, List[Any]]] + values: str | list[Any] | None if column["type"] and column["type"].startswith("ARRAY("): # keep processing array children; we append to the right so that @@ -1198,7 +1186,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): for row in data: values = row.get(name) or [] if isinstance(values, str): - values = cast(Optional[List[Any]], destringify(values)) + values = cast(Optional[list[Any]], destringify(values)) row[name] = values for value, col in zip(values or [], expanded): row[col["name"]] = value @@ -1211,8 +1199,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def extra_table_metadata( - cls, database: Database, table_name: str, schema_name: Optional[str] - ) -> Dict[str, Any]: + cls, database: Database, table_name: str, schema_name: str | None + ) -> dict[str, Any]: metadata = {} if indexes := database.get_indexes(table_name, schema_name): @@ -1243,8 +1231,8 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): @classmethod def get_create_view( - cls, database: Database, schema: Optional[str], table: str - ) -> Optional[str]: + cls, database: Database, schema: str | None, table: str + ) -> str | None: """ Return a CREATE VIEW statement, or `None` if not a view. @@ -1267,7 +1255,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return rows[0][0] @classmethod - def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]: + def get_tracking_url(cls, cursor: Cursor) -> str | None: try: if cursor.last_query_id: # pylint: disable=protected-access, line-too-long @@ -1277,7 +1265,7 @@ class PrestoEngineSpec(PrestoBaseEngineSpec): return None @classmethod - def handle_cursor(cls, cursor: "Cursor", query: Query, session: Session) -> None: + def handle_cursor(cls, cursor: Cursor, query: Query, session: Session) -> None: """Updates progress information""" if tracking_url := cls.get_tracking_url(cursor): query.tracking_url = tracking_url diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index 27b749e418..2e746a6349 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -16,7 +16,8 @@ # under the License. import logging import re -from typing import Any, Dict, Optional, Pattern, Tuple +from re import Pattern +from typing import Any, Optional import pandas as pd from flask_babel import gettext as __ @@ -66,7 +67,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): encryption_parameters = {"sslmode": "verify-ca"} - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -106,7 +107,7 @@ class RedshiftEngineSpec(PostgresBaseEngineSpec, BasicParametersMixin): database: Database, table: Table, df: pd.DataFrame, - to_sql_kwargs: Dict[str, Any], + to_sql_kwargs: dict[str, Any], ) -> None: """ Upload data from a Pandas DataFrame to a database. diff --git a/superset/db_engine_specs/rockset.py b/superset/db_engine_specs/rockset.py index cc215054be..71adca0b10 100644 --- a/superset/db_engine_specs/rockset.py +++ b/superset/db_engine_specs/rockset.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from sqlalchemy import types @@ -51,7 +51,7 @@ class RocksetEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) diff --git a/superset/db_engine_specs/snowflake.py b/superset/db_engine_specs/snowflake.py index 69ccf55931..32ade649b0 100644 --- a/superset/db_engine_specs/snowflake.py +++ b/superset/db_engine_specs/snowflake.py @@ -18,7 +18,8 @@ import json import logging import re from datetime import datetime -from typing import Any, Dict, List, Optional, Pattern, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from urllib import parse from apispec import APISpec @@ -107,7 +108,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): "P1Y": "DATE_TRUNC('YEAR', {col})", } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { OBJECT_DOES_NOT_EXIST_REGEX: ( __("%(object)s does not exist in this database."), SupersetErrorType.OBJECT_DOES_NOT_EXIST_ERROR, @@ -124,13 +125,13 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): } @staticmethod - def get_extra_params(database: "Database") -> Dict[str, Any]: + def get_extra_params(database: "Database") -> dict[str, Any]: """ Add a user agent to be used in the requests. """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("application", USER_AGENT) @@ -140,10 +141,10 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: database = uri.database if "/" in database: database = database.split("/")[0] @@ -157,7 +158,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. @@ -174,7 +175,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, database: "Database", inspector: Inspector, - ) -> List[str]: + ) -> list[str]: """ Return all catalogs. @@ -197,7 +198,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) @@ -261,7 +262,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, parameters: SnowflakeParametersType, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, Any] + dict[str, Any] ] = None, ) -> str: return str( @@ -283,7 +284,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): cls, uri: str, encrypted_extra: Optional[ # pylint: disable=unused-argument - Dict[str, str] + dict[str, str] ] = None, ) -> Any: url = make_url_safe(uri) @@ -300,8 +301,8 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): @classmethod def validate_parameters( cls, properties: BasicPropertiesType - ) -> List[SupersetError]: - errors: List[SupersetError] = [] + ) -> list[SupersetError]: + errors: list[SupersetError] = [] required = { "warehouse", "username", @@ -346,7 +347,7 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec): @staticmethod def update_params_from_encrypted_extra( database: "Database", - params: Dict[str, Any], + params: dict[str, Any], ) -> None: if not database.encrypted_extra: return diff --git a/superset/db_engine_specs/sqlite.py b/superset/db_engine_specs/sqlite.py index a414143296..767d0a20ad 100644 --- a/superset/db_engine_specs/sqlite.py +++ b/superset/db_engine_specs/sqlite.py @@ -16,7 +16,8 @@ # under the License. import re from datetime import datetime -from typing import Any, Dict, Optional, Pattern, Set, Tuple, TYPE_CHECKING +from re import Pattern +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy import types @@ -60,7 +61,7 @@ class SqliteEngineSpec(BaseEngineSpec): ), } - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { COLUMN_DOES_NOT_EXIST_REGEX: ( __('We can\'t seem to resolve the column "%(column_name)s"'), SupersetErrorType.COLUMN_DOES_NOT_EXIST_ERROR, @@ -74,7 +75,7 @@ class SqliteEngineSpec(BaseEngineSpec): @classmethod def convert_dttm( - cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: sqla_type = cls.get_sqla_column_type(target_type) if isinstance(sqla_type, (types.String, types.DateTime)): @@ -84,6 +85,6 @@ class SqliteEngineSpec(BaseEngineSpec): @classmethod def get_table_names( cls, database: "Database", inspector: Inspector, schema: Optional[str] - ) -> Set[str]: + ) -> set[str]: """Need to disregard the schema for Sqlite""" return set(inspector.get_table_names()) diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py index f687fdbdb3..63269439af 100644 --- a/superset/db_engine_specs/starrocks.py +++ b/superset/db_engine_specs/starrocks.py @@ -17,7 +17,8 @@ import logging import re -from typing import Any, Dict, List, Optional, Pattern, Tuple, Type +from re import Pattern +from typing import Any, Optional from urllib import parse from flask_babel import gettext as __ @@ -40,11 +41,11 @@ CONNECTION_UNKNOWN_DATABASE_REGEX = re.compile("Unknown database '(?P. logger = logging.getLogger(__name__) -class TINYINT(Integer): # pylint: disable=no-init +class TINYINT(Integer): __visit_name__ = "TINYINT" -class DOUBLE(Numeric): # pylint: disable=no-init +class DOUBLE(Numeric): __visit_name__ = "DOUBLE" @@ -52,7 +53,7 @@ class ARRAY(TypeEngine): # pylint: disable=no-init __visit_name__ = "ARRAY" @property - def python_type(self) -> Optional[Type[List[Any]]]: + def python_type(self) -> Optional[type[list[Any]]]: return list @@ -60,7 +61,7 @@ class MAP(TypeEngine): # pylint: disable=no-init __visit_name__ = "MAP" @property - def python_type(self) -> Optional[Type[Dict[Any, Any]]]: + def python_type(self) -> Optional[type[dict[Any, Any]]]: return dict @@ -68,7 +69,7 @@ class STRUCT(TypeEngine): # pylint: disable=no-init __visit_name__ = "STRUCT" @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @@ -117,7 +118,7 @@ class StarRocksEngineSpec(MySQLEngineSpec): (re.compile(r"^struct.*", re.IGNORECASE), STRUCT(), GenericDataType.STRING), ) - custom_errors: Dict[Pattern[str], Tuple[str, SupersetErrorType, Dict[str, Any]]] = { + custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { CONNECTION_ACCESS_DENIED_REGEX: ( __('Either the username "%(username)s" or the password is incorrect.'), SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, @@ -134,10 +135,10 @@ class StarRocksEngineSpec(MySQLEngineSpec): def adjust_engine_params( cls, uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], catalog: Optional[str] = None, schema: Optional[str] = None, - ) -> Tuple[URL, Dict[str, Any]]: + ) -> tuple[URL, dict[str, Any]]: database = uri.database if schema and database: schema = parse.quote(schema, safe="") @@ -152,9 +153,9 @@ class StarRocksEngineSpec(MySQLEngineSpec): @classmethod def get_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: columns = cls._show_columns(inspector, table_name, schema) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for column in columns: column_spec = cls.get_column_spec(column.Type) column_type = column_spec.sqla_type if column_spec else None @@ -174,7 +175,7 @@ class StarRocksEngineSpec(MySQLEngineSpec): @classmethod def _show_columns( cls, inspector: Inspector, table_name: str, schema: Optional[str] - ) -> List[ResultRow]: + ) -> list[ResultRow]: """ Show starrocks column names :param inspector: object that performs database schema inspection @@ -185,13 +186,13 @@ class StarRocksEngineSpec(MySQLEngineSpec): quote = inspector.engine.dialect.identifier_preparer.quote_identifier full_table = quote(table_name) if schema: - full_table = "{}.{}".format(quote(schema), full_table) + full_table = f"{quote(schema)}.{full_table}" return inspector.bind.execute(f"SHOW COLUMNS FROM {full_table}").fetchall() @classmethod def _create_column_info( cls, name: str, data_type: types.TypeEngine - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """ Create column info object :param name: column name @@ -204,7 +205,7 @@ class StarRocksEngineSpec(MySQLEngineSpec): def get_schema_from_engine_params( cls, sqlalchemy_uri: URL, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], ) -> Optional[str]: """ Return the configured schema. diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index 0fa4d05cbc..f05bd67ec3 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import simplejson as json from flask import current_app @@ -53,8 +53,8 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): cls, database: Database, table_name: str, - schema_name: Optional[str], - ) -> Dict[str, Any]: + schema_name: str | None, + ) -> dict[str, Any]: metadata = {} if indexes := database.get_indexes(table_name, schema_name): @@ -68,12 +68,12 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): metadata["partitions"] = { "cols": sorted( list( - set( + { column_name for index in indexes if index.get("name") == "partition" for column_name in index.get("column_names", []) - ) + } ) ), "latest": dict(zip(col_names, latest_parts)), @@ -95,9 +95,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): @classmethod def update_impersonation_config( cls, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], uri: str, - username: Optional[str], + username: str | None, ) -> None: """ Update a configuration dictionary @@ -118,7 +118,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): @classmethod def get_url_for_impersonation( - cls, url: URL, impersonate_user: bool, username: Optional[str] + cls, url: URL, impersonate_user: bool, username: str | None ) -> URL: """ Return a modified URL with the username set. @@ -131,11 +131,11 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): return url @classmethod - def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool: + def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: return True @classmethod - def get_tracking_url(cls, cursor: Cursor) -> Optional[str]: + def get_tracking_url(cls, cursor: Cursor) -> str | None: try: return cursor.info_uri except AttributeError: @@ -199,7 +199,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): return True @staticmethod - def get_extra_params(database: Database) -> Dict[str, Any]: + def get_extra_params(database: Database) -> dict[str, Any]: """ Some databases require adding elements to connection parameters, like passing certificates to `extra`. This can be done here. @@ -207,9 +207,9 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): :param database: database instance from which to extract extras :raises CertificateException: If certificate is not valid/unparseable """ - extra: Dict[str, Any] = BaseEngineSpec.get_extra_params(database) - engine_params: Dict[str, Any] = extra.setdefault("engine_params", {}) - connect_args: Dict[str, Any] = engine_params.setdefault("connect_args", {}) + extra: dict[str, Any] = BaseEngineSpec.get_extra_params(database) + engine_params: dict[str, Any] = extra.setdefault("engine_params", {}) + connect_args: dict[str, Any] = engine_params.setdefault("connect_args", {}) connect_args.setdefault("source", USER_AGENT) @@ -222,7 +222,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): @staticmethod def update_params_from_encrypted_extra( database: Database, - params: Dict[str, Any], + params: dict[str, Any], ) -> None: if not database.encrypted_extra: return @@ -262,7 +262,7 @@ class TrinoEngineSpec(PrestoBaseEngineSpec): raise ex @classmethod - def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]: + def get_dbapi_exception_mapping(cls) -> dict[type[Exception], type[Exception]]: # pylint: disable=import-outside-toplevel from requests import exceptions as requests_exceptions diff --git a/superset/embedded/dao.py b/superset/embedded/dao.py index 957a7242a7..27ca338502 100644 --- a/superset/embedded/dao.py +++ b/superset/embedded/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from superset.dao.base import BaseDAO from superset.extensions import db @@ -31,7 +31,7 @@ class EmbeddedDAO(BaseDAO): id_column_name = "uuid" @staticmethod - def upsert(dashboard: Dashboard, allowed_domains: List[str]) -> EmbeddedDashboard: + def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboard: """ Sets up a dashboard to be embeddable. Upsert is used to preserve the embedded_dashboard uuid across updates. @@ -45,7 +45,7 @@ class EmbeddedDAO(BaseDAO): return embedded @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> Any: + def create(cls, properties: dict[str, Any], commit: bool = True) -> Any: """ Use EmbeddedDAO.upsert() instead. At least, until we are ok with more than one embedded instance per dashboard. diff --git a/superset/errors.py b/superset/errors.py index 5261848687..6f68f2466c 100644 --- a/superset/errors.py +++ b/superset/errors.py @@ -16,7 +16,7 @@ # under the License. from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, Optional +from typing import Any, Optional from flask_babel import lazy_gettext as _ @@ -204,7 +204,7 @@ class SupersetError: message: str error_type: SupersetErrorType level: ErrorLevel - extra: Optional[Dict[str, Any]] = None + extra: Optional[dict[str, Any]] = None def __post_init__(self) -> None: """ @@ -227,7 +227,7 @@ class SupersetError: } ) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = {"message": self.message, "error_type": self.error_type} if self.extra: rv["extra"] = self.extra # type: ignore diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 5d167b02d0..e18f6e4632 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -55,7 +55,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/big_data.py b/superset/examples/big_data.py index 8c0f2e267c..ed738d2c96 100644 --- a/superset/examples/big_data.py +++ b/superset/examples/big_data.py @@ -16,7 +16,6 @@ # under the License. import random import string -from typing import List import sqlalchemy.sql.sqltypes @@ -36,7 +35,7 @@ COLUMN_TYPES = [ def load_big_data() -> None: print("Creating table `wide_table` with 100 columns") - columns: List[ColumnInfo] = [] + columns: list[ColumnInfo] = [] for i in range(100): column: ColumnInfo = { "name": f"col{i}", diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 8da041550e..45a3b39eb3 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -16,7 +16,7 @@ # under the License. import json import textwrap -from typing import Dict, List, Tuple, Union +from typing import Union import pandas as pd from sqlalchemy import DateTime, inspect, String @@ -42,7 +42,7 @@ from .helpers import ( def gen_filter( subject: str, comparator: str, operator: str = "==" -) -> Dict[str, Union[bool, str]]: +) -> dict[str, Union[bool, str]]: return { "clause": "WHERE", "comparator": comparator, @@ -152,7 +152,7 @@ def _add_table_metrics(datasource: SqlaTable) -> None: datasource.metrics = metrics -def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]: +def create_slices(tbl: SqlaTable) -> tuple[list[Slice], list[Slice]]: metrics = [ { "expressionType": "SIMPLE", @@ -529,7 +529,7 @@ def create_slices(tbl: SqlaTable) -> Tuple[List[Slice], List[Slice]]: return slices, misc_slices -def create_dashboard(slices: List[Slice]) -> Dashboard: +def create_dashboard(slices: list[Slice]) -> Dashboard: print("Creating a dashboard") dash = db.session.query(Dashboard).filter_by(slug="births").first() if not dash: diff --git a/superset/examples/countries.py b/superset/examples/countries.py index 8f1d5466ae..2ea12baae7 100644 --- a/superset/examples/countries.py +++ b/superset/examples/countries.py @@ -16,9 +16,9 @@ # under the License. """This module contains data related to countries and is used for geo mapping""" # pylint: disable=too-many-lines -from typing import Any, Dict, List, Optional +from typing import Any, Optional -countries: List[Dict[str, Any]] = [ +countries: list[dict[str, Any]] = [ { "name": "Angola", "area": 1246700, @@ -2491,7 +2491,7 @@ countries: List[Dict[str, Any]] = [ }, ] -all_lookups: Dict[str, Dict[str, Dict[str, Any]]] = {} +all_lookups: dict[str, dict[str, dict[str, Any]]] = {} lookups = ["cioc", "cca2", "cca3", "name"] for lookup in lookups: all_lookups[lookup] = {} @@ -2499,7 +2499,7 @@ for lookup in lookups: all_lookups[lookup][country[lookup].lower()] = country -def get(field: str, symbol: str) -> Optional[Dict[str, Any]]: +def get(field: str, symbol: str) -> Optional[dict[str, Any]]: """ Get country data based on a standard code and a symbol """ diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py index e26e05e497..9f893f1ccc 100644 --- a/superset/examples/helpers.py +++ b/superset/examples/helpers.py @@ -17,7 +17,7 @@ """Loads datasets, dashboards and slices in a new superset instance""" import json import os -from typing import Any, Dict, List, Set +from typing import Any from superset import app, db from superset.connectors.sqla.models import SqlaTable @@ -25,7 +25,7 @@ from superset.models.slice import Slice BASE_URL = "https://github.com/apache-superset/examples-data/blob/master/" -misc_dash_slices: Set[str] = set() # slices assembled in a 'Misc Chart' dashboard +misc_dash_slices: set[str] = set() # slices assembled in a 'Misc Chart' dashboard def get_table_connector_registry() -> Any: @@ -36,7 +36,7 @@ def get_examples_folder() -> str: return os.path.join(app.config["BASE_DIR"], "examples") -def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]: +def update_slice_ids(pos: dict[Any, Any]) -> list[Slice]: """Update slice ids in position_json and return the slices found.""" slice_components = [ component @@ -44,7 +44,7 @@ def update_slice_ids(pos: Dict[Any, Any]) -> List[Slice]: if isinstance(component, dict) and component.get("type") == "CHART" ] slices = {} - for name in set(component["meta"]["sliceName"] for component in slice_components): + for name in {component["meta"]["sliceName"] for component in slice_components}: slc = db.session.query(Slice).filter_by(slice_name=name).first() if slc: slices[name] = slc @@ -64,7 +64,7 @@ def merge_slice(slc: Slice) -> None: db.session.commit() -def get_slice_json(defaults: Dict[Any, Any], **kwargs: Any) -> str: +def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str: defaults_copy = defaults.copy() defaults_copy.update(kwargs) return json.dumps(defaults_copy, indent=4, sort_keys=True) diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index de9630ef58..6bad2a7ac2 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional, Tuple +from typing import Optional import pandas as pd from sqlalchemy import BigInteger, Date, DateTime, inspect, String @@ -85,7 +85,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals obj.main_dttm_col = "ds" obj.database = database obj.filter_select_enabled = True - dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = { + dttm_and_expr_dict: dict[str, tuple[Optional[str], None]] = { "ds": (None, None), "ds2": (None, None), "epoch_s": ("epoch_s", None), diff --git a/superset/examples/paris.py b/superset/examples/paris.py index a54a3706b1..1180c428fe 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -52,7 +52,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index 6011b82b09..76c039afb8 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -54,7 +54,7 @@ def load_sf_population_polygons( index=False, ) - print("Creating table {} reference".format(tbl_name)) + print(f"Creating table {tbl_name} reference") table = get_table_connector_registry() tbl = db.session.query(table).filter_by(table_name=tbl_name).first() if not tbl: diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index e6d7557deb..bce50c854e 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -19,7 +19,6 @@ import json import textwrap -from typing import List from sqlalchemy import inspect @@ -40,7 +39,7 @@ from .helpers import ( DASH_SLUG = "supported_charts_dash" -def create_slices(tbl: SqlaTable) -> List[Slice]: +def create_slices(tbl: SqlaTable) -> list[Slice]: slice_kwargs = { "datasource_id": tbl.id, "datasource_type": DatasourceType.TABLE, diff --git a/superset/examples/utils.py b/superset/examples/utils.py index 8c2cfea23c..52d58e3e4a 100644 --- a/superset/examples/utils.py +++ b/superset/examples/utils.py @@ -17,7 +17,7 @@ import logging import re from pathlib import Path -from typing import Any, Dict +from typing import Any import yaml from pkg_resources import resource_isdir, resource_listdir, resource_stream @@ -42,13 +42,13 @@ def load_examples_from_configs( command.run() -def load_contents(load_test_data: bool = False) -> Dict[str, Any]: +def load_contents(load_test_data: bool = False) -> dict[str, Any]: """Traverse configs directory and load contents""" root = Path("examples/configs") resource_names = resource_listdir("superset", str(root)) queue = [root / resource_name for resource_name in resource_names] - contents: Dict[Path, str] = {} + contents: dict[Path, str] = {} while queue: path_name = queue.pop() test_re = re.compile(r"\.test\.|metadata\.yaml$") @@ -74,7 +74,7 @@ def load_configs_from_directory( """ Load all the examples from a given directory. """ - contents: Dict[str, str] = {} + contents: dict[str, str] = {} queue = [root] while queue: path_name = queue.pop() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 2972188e02..9f9f6bb700 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -17,7 +17,6 @@ """Loads datasets, dashboards and slices in a new superset instance""" import json import os -from typing import List import pandas as pd from sqlalchemy import DateTime, inspect, String @@ -139,7 +138,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s db.session.commit() -def create_slices(tbl: BaseDatasource) -> List[Slice]: +def create_slices(tbl: BaseDatasource) -> list[Slice]: metric = "sum__SP_POP_TOTL" metrics = ["sum__SP_POP_TOTL"] secondary_metric = { diff --git a/superset/exceptions.py b/superset/exceptions.py index 32b06203cd..018d1b6dfb 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from collections import defaultdict -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from marshmallow import ValidationError @@ -47,7 +47,7 @@ class SupersetException(Exception): def error_type(self) -> Optional[SupersetErrorType]: return self._error_type - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = {} if hasattr(self, "message"): rv["message"] = self.message @@ -67,7 +67,7 @@ class SupersetErrorException(SupersetException): if status is not None: self.status = status - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return self.error.to_dict() @@ -94,7 +94,7 @@ class SupersetErrorFromParamsException(SupersetErrorException): error_type: SupersetErrorType, message: str, level: ErrorLevel, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( SupersetError( @@ -107,7 +107,7 @@ class SupersetErrorsException(SupersetException): """Exceptions with multiple SupersetErrorType associated with them""" def __init__( - self, errors: List[SupersetError], status: Optional[int] = None + self, errors: list[SupersetError], status: Optional[int] = None ) -> None: super().__init__(str(errors)) self.errors = errors @@ -119,7 +119,7 @@ class SupersetSyntaxErrorException(SupersetErrorsException): status = 422 error_type = SupersetErrorType.SYNTAX_ERROR - def __init__(self, errors: List[SupersetError]) -> None: + def __init__(self, errors: list[SupersetError]) -> None: super().__init__(errors) @@ -134,7 +134,7 @@ class SupersetGenericDBErrorException(SupersetErrorFromParamsException): self, message: str, level: ErrorLevel = ErrorLevel.ERROR, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( SupersetErrorType.GENERIC_DB_ENGINE_ERROR, @@ -152,7 +152,7 @@ class SupersetTemplateParamsErrorException(SupersetErrorFromParamsException): message: str, error: SupersetErrorType, level: ErrorLevel = ErrorLevel.ERROR, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, ) -> None: super().__init__( error, @@ -166,7 +166,7 @@ class SupersetSecurityException(SupersetErrorException): status = 403 def __init__( - self, error: SupersetError, payload: Optional[Dict[str, Any]] = None + self, error: SupersetError, payload: Optional[dict[str, Any]] = None ) -> None: super().__init__(error) self.payload = payload diff --git a/superset/explore/commands/get.py b/superset/explore/commands/get.py index fb690a9d75..490d198360 100644 --- a/superset/explore/commands/get.py +++ b/superset/explore/commands/get.py @@ -16,7 +16,7 @@ # under the License. import logging from abc import ABC -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Optional import simplejson as json from flask import current_app, request @@ -60,7 +60,7 @@ class GetExploreCommand(BaseCommand, ABC): self._slice_id = params.slice_id # pylint: disable=too-many-locals,too-many-branches,too-many-statements - def run(self) -> Optional[Dict[str, Any]]: + def run(self) -> Optional[dict[str, Any]]: initial_form_data = {} if self._permalink_key is not None: @@ -147,7 +147,7 @@ class GetExploreCommand(BaseCommand, ABC): utils.merge_request_params(form_data, request.args) # TODO: this is a dummy placeholder - should be refactored to being just `None` - datasource_data: Dict[str, Any] = { + datasource_data: dict[str, Any] = { "type": self._datasource_type, "name": datasource_name, "columns": [], diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 90e64f6df7..97a8bcbf09 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class CreateExplorePermalinkCommand(BaseExplorePermalinkCommand): - def __init__(self, state: Dict[str, Any]): + def __init__(self, state: dict[str, Any]): self.chart_id: Optional[int] = state["formData"].get("slice_id") self.datasource: str = state["formData"]["datasource"] self.state = state diff --git a/superset/explore/permalink/types.py b/superset/explore/permalink/types.py index 393f0ed8d5..7eb4a7cb6b 100644 --- a/superset/explore/permalink/types.py +++ b/superset/explore/permalink/types.py @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional, Tuple, TypedDict +from typing import Any, Optional, TypedDict class ExplorePermalinkState(TypedDict, total=False): - formData: Dict[str, Any] - urlParams: Optional[List[Tuple[str, str]]] + formData: dict[str, Any] + urlParams: Optional[list[tuple[str, str]]] class ExplorePermalinkValue(TypedDict): diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index f633385972..c2e84f700f 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -16,7 +16,7 @@ # under the License. import json import os -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Optional import celery from cachelib.base import BaseCache @@ -58,7 +58,7 @@ class ResultsBackendManager: class UIManifestProcessor: def __init__(self, app_dir: str) -> None: self.app: Optional[Flask] = None - self.manifest: Dict[str, Dict[str, List[str]]] = {} + self.manifest: dict[str, dict[str, list[str]]] = {} self.manifest_file = f"{app_dir}/static/assets/manifest.json" def init_app(self, app: Flask) -> None: @@ -70,10 +70,10 @@ class UIManifestProcessor: def register_processor(self, app: Flask) -> None: app.template_context_processors[None].append(self.get_manifest) - def get_manifest(self) -> Dict[str, Callable[[str], List[str]]]: + def get_manifest(self) -> dict[str, Callable[[str], list[str]]]: loaded_chunks = set() - def get_files(bundle: str, asset_type: str = "js") -> List[str]: + def get_files(bundle: str, asset_type: str = "js") -> list[str]: files = self.get_manifest_files(bundle, asset_type) filtered_files = [f for f in files if f not in loaded_chunks] for f in filtered_files: @@ -88,7 +88,7 @@ class UIManifestProcessor: def parse_manifest_json(self) -> None: try: - with open(self.manifest_file, "r") as f: + with open(self.manifest_file) as f: # the manifest includes non-entry files we only need entries in # templates full_manifest = json.load(f) @@ -96,7 +96,7 @@ class UIManifestProcessor: except Exception: # pylint: disable=broad-except pass - def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]: + def get_manifest_files(self, bundle: str, asset_type: str) -> list[str]: if self.app and self.app.debug: self.parse_manifest_json() return self.manifest.get(bundle, {}).get(asset_type, []) @@ -117,7 +117,7 @@ cache_manager = CacheManager() celery_app = celery.Celery() csrf = CSRFProtect() db = SQLA() -_event_logger: Dict[str, Any] = {} +_event_logger: dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index f69276c908..6e928c0d5f 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional +from typing import Any, Optional from uuid import UUID, uuid3 from flask import Flask @@ -37,7 +37,7 @@ class SupersetMetastoreCache(BaseCache): @classmethod def factory( - cls, app: Flask, config: Dict[str, Any], args: List[Any], kwargs: Dict[str, Any] + cls, app: Flask, config: dict[str, Any], args: list[Any], kwargs: dict[str, Any] ) -> BaseCache: seed = config.get("CACHE_KEY_PREFIX", "") kwargs["namespace"] = get_uuid_namespace(seed) diff --git a/superset/forms.py b/superset/forms.py index c9b29dfcd0..f1e220ba95 100644 --- a/superset/forms.py +++ b/superset/forms.py @@ -16,7 +16,7 @@ # under the License. """Contains the logic to create cohesive forms on the explore view""" import json -from typing import Any, List, Optional +from typing import Any, Optional from flask_appbuilder.fieldwidgets import BS3TextFieldWidget from wtforms import Field @@ -24,12 +24,12 @@ from wtforms import Field class JsonListField(Field): widget = BS3TextFieldWidget() - data: List[str] = [] + data: list[str] = [] def _value(self) -> str: return json.dumps(self.data) - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist and valuelist[0]: self.data = json.loads(valuelist[0]) else: @@ -38,7 +38,7 @@ class JsonListField(Field): class CommaSeparatedListField(Field): widget = BS3TextFieldWidget() - data: List[str] = [] + data: list[str] = [] def _value(self) -> str: if self.data: @@ -46,14 +46,14 @@ class CommaSeparatedListField(Field): return "" - def process_formdata(self, valuelist: List[str]) -> None: + def process_formdata(self, valuelist: list[str]) -> None: if valuelist: self.data = [x.strip() for x in valuelist[0].split(",")] else: self.data = [] -def filter_not_empty_values(values: Optional[List[Any]]) -> Optional[List[Any]]: +def filter_not_empty_values(values: Optional[list[Any]]) -> Optional[list[Any]]: """Returns a list of non empty values or None""" if not values: return None diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index c489cc323c..bbe25f498b 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -19,7 +19,7 @@ from __future__ import annotations import logging import os import sys -from typing import Any, Callable, Dict, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING import wtforms_json from deprecation import deprecated @@ -68,7 +68,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.superset_app = app self.config = app.config - self.manifest: Dict[Any, Any] = {} + self.manifest: dict[Any, Any] = {} @deprecated(details="use self.superset_app instead of self.flask_app") # type: ignore @property @@ -597,7 +597,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.app = app def __call__( - self, environ: Dict[str, Any], start_response: Callable[..., Any] + self, environ: dict[str, Any], start_response: Callable[..., Any] ) -> Any: # Setting wsgi.input_terminated tells werkzeug.wsgi to ignore # content-length and read the stream till the end. diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 360c4fc1f4..f096b65cd1 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -18,17 +18,7 @@ import json import re from functools import lru_cache, partial -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Optional, - Tuple, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union from flask import current_app, g, has_request_context, request from flask_babel import gettext as _ @@ -71,14 +61,14 @@ COLLECTION_TYPES = ("list", "dict", "tuple", "set") @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def context_addons() -> Dict[str, Any]: +def context_addons() -> dict[str, Any]: return current_app.config.get("JINJA_CONTEXT_ADDONS", {}) class Filter(TypedDict): op: str # pylint: disable=C0103 col: str - val: Union[None, Any, List[Any]] + val: Union[None, Any, list[Any]] class ExtraCache: @@ -100,9 +90,9 @@ class ExtraCache: def __init__( self, - extra_cache_keys: Optional[List[Any]] = None, - applied_filters: Optional[List[str]] = None, - removed_filters: Optional[List[str]] = None, + extra_cache_keys: Optional[list[Any]] = None, + applied_filters: Optional[list[str]] = None, + removed_filters: Optional[list[str]] = None, dialect: Optional[Dialect] = None, ): self.extra_cache_keys = extra_cache_keys @@ -206,7 +196,7 @@ class ExtraCache: def filter_values( self, column: str, default: Optional[str] = None, remove_filter: bool = False - ) -> List[Any]: + ) -> list[Any]: """Gets a values for a particular filter as a list This is useful if: @@ -230,7 +220,7 @@ class ExtraCache: only apply to the inner query :return: returns a list of filter values """ - return_val: List[Any] = [] + return_val: list[Any] = [] filters = self.get_filters(column, remove_filter) for flt in filters: val = flt.get("val") @@ -245,7 +235,7 @@ class ExtraCache: return return_val - def get_filters(self, column: str, remove_filter: bool = False) -> List[Filter]: + def get_filters(self, column: str, remove_filter: bool = False) -> list[Filter]: """Get the filters applied to the given column. In addition to returning values like the filter_values function the get_filters function returns the operator specified in the explorer UI. @@ -316,10 +306,10 @@ class ExtraCache: convert_legacy_filters_into_adhoc(form_data) merge_extra_filters(form_data) - filters: List[Filter] = [] + filters: list[Filter] = [] for flt in form_data.get("adhoc_filters", []): - val: Union[Any, List[Any]] = flt.get("comparator") + val: Union[Any, list[Any]] = flt.get("comparator") op: str = flt["operator"].upper() if flt.get("operator") else None # fltOpName: str = flt.get("filterOptionName") if ( @@ -370,7 +360,7 @@ def safe_proxy(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any: return return_value -def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]: +def validate_context_types(context: dict[str, Any]) -> dict[str, Any]: for key in context: arg_type = type(context[key]).__name__ if arg_type not in ALLOWED_TYPES and key not in context_addons(): @@ -395,8 +385,8 @@ def validate_context_types(context: Dict[str, Any]) -> Dict[str, Any]: def validate_template_context( - engine: Optional[str], context: Dict[str, Any] -) -> Dict[str, Any]: + engine: Optional[str], context: dict[str, Any] +) -> dict[str, Any]: if engine and engine in context: # validate engine context separately to allow for engine-specific methods engine_context = validate_context_types(context.pop(engine)) @@ -407,7 +397,7 @@ def validate_template_context( return validate_context_types(context) -def where_in(values: List[Any], mark: str = "'") -> str: +def where_in(values: list[Any], mark: str = "'") -> str: """ Given a list of values, build a parenthesis list suitable for an IN expression. @@ -439,9 +429,9 @@ class BaseTemplateProcessor: database: "Database", query: Optional["Query"] = None, table: Optional["SqlaTable"] = None, - extra_cache_keys: Optional[List[Any]] = None, - removed_filters: Optional[List[str]] = None, - applied_filters: Optional[List[str]] = None, + extra_cache_keys: Optional[list[Any]] = None, + removed_filters: Optional[list[str]] = None, + applied_filters: Optional[list[str]] = None, **kwargs: Any, ) -> None: self._database = database @@ -454,7 +444,7 @@ class BaseTemplateProcessor: self._extra_cache_keys = extra_cache_keys self._applied_filters = applied_filters self._removed_filters = removed_filters - self._context: Dict[str, Any] = {} + self._context: dict[str, Any] = {} self._env = SandboxedEnvironment(undefined=DebugUndefined) self.set_context(**kwargs) @@ -530,7 +520,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): @staticmethod def _schema_table( table_name: str, schema: Optional[str] - ) -> Tuple[str, Optional[str]]: + ) -> tuple[str, Optional[str]]: if "." in table_name: schema, table_name = table_name.split(".") return table_name, schema @@ -547,7 +537,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): latest_partitions = self.latest_partitions(table_name) return latest_partitions[0] if latest_partitions else None - def latest_partitions(self, table_name: str) -> Optional[List[str]]: + def latest_partitions(self, table_name: str) -> Optional[list[str]]: """ Gets the array of all latest partitions @@ -603,7 +593,7 @@ DEFAULT_PROCESSORS = { @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) -def get_template_processors() -> Dict[str, Any]: +def get_template_processors() -> dict[str, Any]: processors = current_app.config.get("CUSTOM_TEMPLATE_PROCESSORS", {}) for engine, processor in DEFAULT_PROCESSORS.items(): # do not overwrite engine-specific CUSTOM_TEMPLATE_PROCESSORS @@ -631,7 +621,7 @@ def get_template_processor( def dataset_macro( dataset_id: int, include_metrics: bool = False, - columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, ) -> str: """ Given a dataset ID, return the SQL that represents it. diff --git a/superset/key_value/types.py b/superset/key_value/types.py index fb9c31899f..b2a47336c3 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -21,7 +21,7 @@ import pickle from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Any, Optional, TypedDict +from typing import Any, TypedDict from uuid import UUID from marshmallow import Schema, ValidationError @@ -34,14 +34,14 @@ from superset.key_value.exceptions import ( @dataclass class Key: - id: Optional[int] - uuid: Optional[UUID] + id: int | None + uuid: UUID | None class KeyValueFilter(TypedDict, total=False): resource: str - id: Optional[int] - uuid: Optional[UUID] + id: int | None + uuid: UUID | None class KeyValueResource(str, Enum): diff --git a/superset/key_value/utils.py b/superset/key_value/utils.py index 2468618a81..6b487c278c 100644 --- a/superset/key_value/utils.py +++ b/superset/key_value/utils.py @@ -18,7 +18,7 @@ from __future__ import annotations from hashlib import md5 from secrets import token_urlsafe -from typing import Any, Union +from typing import Any from uuid import UUID, uuid3 import hashids @@ -35,7 +35,7 @@ def random_key() -> str: return token_urlsafe(48) -def get_filter(resource: KeyValueResource, key: Union[int, UUID]) -> KeyValueFilter: +def get_filter(resource: KeyValueResource, key: int | UUID) -> KeyValueFilter: try: filter_: KeyValueFilter = {"resource": resource.value} if isinstance(key, UUID): diff --git a/superset/legacy.py b/superset/legacy.py index 168b9c0b60..03c1eff7dd 100644 --- a/superset/legacy.py +++ b/superset/legacy.py @@ -15,10 +15,10 @@ # specific language governing permissions and limitations # under the License. """Code related with dealing with legacy / change management""" -from typing import Any, Dict +from typing import Any -def update_time_range(form_data: Dict[str, Any]) -> None: +def update_time_range(form_data: dict[str, Any]) -> None: """Move since and until to time_range.""" if "since" in form_data or "until" in form_data: form_data["time_range"] = "{} : {}".format( diff --git a/superset/migrations/env.py b/superset/migrations/env.py index e3779bb65b..130fb367fb 100755 --- a/superset/migrations/env.py +++ b/superset/migrations/env.py @@ -17,7 +17,6 @@ import logging import urllib.parse from logging.config import fileConfig -from typing import List from alembic import context from alembic.operations.ops import MigrationScript @@ -85,7 +84,7 @@ def run_migrations_online() -> None: # when there are no changes to the schema # reference: https://alembic.sqlalchemy.org/en/latest/cookbook.html def process_revision_directives( # pylint: disable=redefined-outer-name, unused-argument - context: MigrationContext, revision: str, directives: List[MigrationScript] + context: MigrationContext, revision: str, directives: list[MigrationScript] ) -> None: if getattr(config.cmd_opts, "autogenerate", False): script = directives[0] diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index 5ea23551ea..d3b2efa7a0 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -18,7 +18,7 @@ from __future__ import annotations import copy import json -from typing import Any, Dict, Set +from typing import Any from alembic import op from sqlalchemy import and_, Column, Integer, String, Text @@ -44,8 +44,8 @@ FORM_DATA_BAK_FIELD_NAME = "form_data_bak" class MigrateViz: - remove_keys: Set[str] = set() - rename_keys: Dict[str, str] = {} + remove_keys: set[str] = set() + rename_keys: dict[str, str] = {} source_viz_type: str target_viz_type: str has_x_axis_control: bool = False @@ -85,7 +85,7 @@ class MigrateViz: def _post_action(self) -> None: """Some actions after migrate""" - def _migrate_temporal_filter(self, rv_data: Dict[str, Any]) -> None: + def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None: """Adds a temporal filter.""" granularity_sqla = rv_data.pop("granularity_sqla", None) time_range = rv_data.pop("time_range", None) or conf.get("DEFAULT_TIME_FILTER") diff --git a/superset/migrations/shared/security_converge.py b/superset/migrations/shared/security_converge.py index 19caa3932b..9b1730a2a1 100644 --- a/superset/migrations/shared/security_converge.py +++ b/superset/migrations/shared/security_converge.py @@ -16,7 +16,6 @@ # under the License. import logging from dataclasses import dataclass -from typing import Dict, List, Tuple from sqlalchemy import ( Column, @@ -41,7 +40,7 @@ class Pvm: permission: str -PvmMigrationMapType = Dict[Pvm, Tuple[Pvm, ...]] +PvmMigrationMapType = dict[Pvm, tuple[Pvm, ...]] # Partial freeze of the current metadata db schema @@ -162,8 +161,8 @@ def _find_pvm(session: Session, view_name: str, permission_name: str) -> Permiss def add_pvms( - session: Session, pvm_data: Dict[str, Tuple[str, ...]], commit: bool = False -) -> List[PermissionView]: + session: Session, pvm_data: dict[str, tuple[str, ...]], commit: bool = False +) -> list[PermissionView]: """ Checks if exists and adds new Permissions, Views and PermissionView's """ @@ -181,7 +180,7 @@ def add_pvms( def _delete_old_permissions( - session: Session, pvm_map: Dict[PermissionView, List[PermissionView]] + session: Session, pvm_map: dict[PermissionView, list[PermissionView]] ) -> None: """ Delete old permissions: @@ -222,7 +221,7 @@ def migrate_roles( Migrates all existing roles that have the permissions to be migrated """ # Collect a map of PermissionView objects for migration - pvm_map: Dict[PermissionView, List[PermissionView]] = {} + pvm_map: dict[PermissionView, list[PermissionView]] = {} for old_pvm_key, new_pvms_ in pvm_key_map.items(): old_pvm = _find_pvm(session, old_pvm_key.view, old_pvm_key.permission) if old_pvm: @@ -252,8 +251,8 @@ def migrate_roles( session.commit() -def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> Dict[str, Tuple[str, ...]]: - reversed_pvms: Dict[str, Tuple[str, ...]] = {} +def get_reversed_new_pvms(pvm_map: PvmMigrationMapType) -> dict[str, tuple[str, ...]]: + reversed_pvms: dict[str, tuple[str, ...]] = {} for old_pvm, new_pvms in pvm_map.items(): if old_pvm.view not in reversed_pvms: reversed_pvms[old_pvm.view] = (old_pvm.permission,) diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index e05b1d357f..32e7dc1a39 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -18,7 +18,8 @@ import json import logging import os import time -from typing import Any, Callable, Dict, Iterator, Optional, Union +from collections.abc import Iterator +from typing import Any, Callable, Optional, Union from uuid import uuid4 from alembic import op @@ -127,7 +128,7 @@ def paginated_update( print_page_progress(processed, total) -def try_load_json(data: Optional[str]) -> Dict[str, Any]: +def try_load_json(data: Optional[str]) -> dict[str, Any]: try: return data and json.loads(data) or {} except json.decoder.JSONDecodeError: diff --git a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py index 56d5f887b3..1f3dbab636 100644 --- a/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py +++ b/superset/migrations/versions/2017-01-24_12-31_db0c65b146bd_update_slice_model_json.py @@ -59,7 +59,7 @@ def upgrade(): slc.params = json.dumps(d, indent=2, sort_keys=True) session.merge(slc) session.commit() - print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name)) + print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}") except Exception as ex: print(slc.slice_name + " error: " + str(ex)) diff --git a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py index 04a39a31f5..8e97ada3cd 100644 --- a/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py +++ b/superset/migrations/versions/2017-02-08_14-16_a99f2f7c195a_rewriting_url_from_shortner_with_new_.py @@ -82,7 +82,7 @@ def upgrade(): url.url = newurl session.merge(url) session.commit() - print("Updating url ({}/{})".format(i, urls_len)) + print(f"Updating url ({i}/{urls_len})") session.close() diff --git a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py index 26cfb93b99..f6d5610d97 100644 --- a/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py +++ b/superset/migrations/versions/2017-10-03_14-37_4736ec66ce19_.py @@ -69,7 +69,7 @@ def upgrade(): batch_op.add_column(sa.Column("datasource_id", sa.Integer)) batch_op.create_foreign_key( - "fk_{}_datasource_id_datasources".format(foreign), + f"fk_{foreign}_datasource_id_datasources", "datasources", ["datasource_id"], ["id"], @@ -102,7 +102,7 @@ def upgrade(): for name in names: batch_op.drop_constraint( - name or "fk_{}_datasource_name_datasources".format(foreign), + name or f"fk_{foreign}_datasource_name_datasources", type_="foreignkey", ) @@ -148,7 +148,7 @@ def downgrade(): batch_op.add_column(sa.Column("datasource_name", sa.String(255))) batch_op.create_foreign_key( - "fk_{}_datasource_name_datasources".format(foreign), + f"fk_{foreign}_datasource_name_datasources", "datasources", ["datasource_name"], ["datasource_name"], @@ -174,7 +174,7 @@ def downgrade(): with op.batch_alter_table(foreign, naming_convention=conv) as batch_op: # Drop the datasource_id column and associated constraint. batch_op.drop_constraint( - "fk_{}_datasource_id_datasources".format(foreign), type_="foreignkey" + f"fk_{foreign}_datasource_id_datasources", type_="foreignkey" ) batch_op.drop_column("datasource_id") @@ -201,7 +201,7 @@ def downgrade(): # Re-create the foreign key associated with the cluster_name column. batch_op.create_foreign_key( - "fk_{}_datasource_id_datasources".format(foreign), + f"fk_{foreign}_datasource_id_datasources", "clusters", ["cluster_name"], ["cluster_name"], diff --git a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py index 5593af0eb6..4b1b807a6f 100644 --- a/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py +++ b/superset/migrations/versions/2017-12-17_11-06_21e88bc06c02_annotation_migration.py @@ -59,7 +59,7 @@ def upgrade(): { "annotationType": "INTERVAL", "style": "solid", - "name": "Layer {}".format(layer), + "name": f"Layer {layer}", "show": True, "overrides": {"since": None, "until": None}, "value": layer, diff --git a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py index 286be8a5fc..bf6276d702 100644 --- a/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py +++ b/superset/migrations/versions/2018-02-13_08-07_e866bd2d4976_smaller_grid.py @@ -51,7 +51,7 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("Upgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id)) + print(f"Upgrading ({i}/{len(dashboards)}): {dashboard.id}") positions = json.loads(dashboard.position_json or "{}") for pos in positions: if pos.get("v", 0) == 0: @@ -74,7 +74,7 @@ def downgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("Downgrading ({}/{}): {}".format(i, len(dashboards), dashboard.id)) + print(f"Downgrading ({i}/{len(dashboards)}): {dashboard.id}") positions = json.loads(dashboard.position_json or "{}") for pos in positions: if pos.get("v", 0) == 1: diff --git a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py index dbe3f0ace4..c73399fb92 100644 --- a/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py +++ b/superset/migrations/versions/2018-03-20_19-47_f231d82b9b26_.py @@ -49,7 +49,7 @@ def upgrade(): for table, column in names.items(): with op.batch_alter_table(table, naming_convention=conv) as batch_op: batch_op.create_unique_constraint( - "uq_{}_{}".format(table, column), [column, "datasource_id"] + f"uq_{table}_{column}", [column, "datasource_id"] ) @@ -71,6 +71,6 @@ def downgrade(): with op.batch_alter_table(table, naming_convention=conv) as batch_op: batch_op.drop_constraint( generic_find_uq_constraint_name(table, {column, "datasource_id"}, insp) - or "uq_{}_{}".format(table, column), + or f"uq_{table}_{column}", type_="unique", ) diff --git a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py index 3e2b81c17a..49b19b9c69 100644 --- a/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py +++ b/superset/migrations/versions/2018-04-10_11-19_bf706ae5eb46_cal_heatmap_metric_to_metrics.py @@ -61,7 +61,7 @@ def upgrade(): slc.params = json.dumps(params, indent=2, sort_keys=True) session.merge(slc) session.commit() - print("Upgraded ({}/{}): {}".format(i, slice_len, slc.slice_name)) + print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}") except Exception as ex: print(slc.slice_name + " error: " + str(ex)) diff --git a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py index ec03328271..6292e2860a 100644 --- a/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py +++ b/superset/migrations/versions/2018-06-13_14-54_bddc498dd179_adhoc_filters.py @@ -28,8 +28,6 @@ down_revision = "80a67c5192fa" import json -import uuid -from collections import defaultdict from alembic import op from sqlalchemy import Column, Integer, Text diff --git a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py index 2e491e9303..a2dd50bf9c 100644 --- a/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py +++ b/superset/migrations/versions/2018-07-05_15-19_3dda56f1c4c6_migrate_num_period_compare_and_period_.py @@ -77,24 +77,24 @@ def isodate_duration_to_string(obj): if obj.tdelta: if not obj.months and not obj.years: return format_seconds(obj.tdelta.total_seconds()) - raise Exception("Unable to convert: {0}".format(obj)) + raise Exception(f"Unable to convert: {obj}") if obj.months % 12 != 0: months = obj.months + 12 * obj.years - return "{0} months".format(months) + return f"{months} months" - return "{0} years".format(obj.years + obj.months // 12) + return f"{obj.years + obj.months // 12} years" def timedelta_to_string(obj): if obj.microseconds: - raise Exception("Unable to convert: {0}".format(obj)) + raise Exception(f"Unable to convert: {obj}") elif obj.seconds: return format_seconds(obj.total_seconds()) elif obj.days % 7 == 0: - return "{0} weeks".format(obj.days // 7) + return f"{obj.days // 7} weeks" else: - return "{0} days".format(obj.days) + return f"{obj.days} days" def format_seconds(value): @@ -106,7 +106,7 @@ def format_seconds(value): else: period = "second" - return "{0} {1}{2}".format(value, period, "s" if value > 1 else "") + return "{} {}{}".format(value, period, "s" if value > 1 else "") def compute_time_compare(granularity, periods): @@ -120,11 +120,11 @@ def compute_time_compare(granularity, periods): obj = isodate.parse_duration(granularity) * periods except isodate.isoerror.ISO8601Error: # if parse_human_timedelta can parse it, return it directly - delta = "{0} {1}{2}".format(periods, granularity, "s" if periods > 1 else "") + delta = "{} {}{}".format(periods, granularity, "s" if periods > 1 else "") obj = parse_human_timedelta(delta) if obj: return delta - raise Exception("Unable to parse: {0}".format(granularity)) + raise Exception(f"Unable to parse: {granularity}") if isinstance(obj, isodate.duration.Duration): return isodate_duration_to_string(obj) diff --git a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py index 13c4e61718..620e2c5008 100644 --- a/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py +++ b/superset/migrations/versions/2018-07-22_11-59_bebcf3fed1fe_convert_dashboard_v1_positions.py @@ -173,7 +173,7 @@ def get_header_component(title): def get_row_container(): return { "type": ROW_TYPE, - "id": "DASHBOARD_ROW_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_ROW_TYPE-{generate_id()}", "children": [], "meta": {"background": BACKGROUND_TRANSPARENT}, } @@ -182,7 +182,7 @@ def get_row_container(): def get_col_container(): return { "type": COLUMN_TYPE, - "id": "DASHBOARD_COLUMN_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_COLUMN_TYPE-{generate_id()}", "children": [], "meta": {"background": BACKGROUND_TRANSPARENT}, } @@ -203,18 +203,18 @@ def get_chart_holder(position): if len(code): markdown_content = code elif slice_name.strip(): - markdown_content = "##### {}".format(slice_name) + markdown_content = f"##### {slice_name}" return { "type": MARKDOWN_TYPE, - "id": "DASHBOARD_MARKDOWN_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_MARKDOWN_TYPE-{generate_id()}", "children": [], "meta": {"width": width, "height": height, "code": markdown_content}, } return { "type": CHART_TYPE, - "id": "DASHBOARD_CHART_TYPE-{}".format(generate_id()), + "id": f"DASHBOARD_CHART_TYPE-{generate_id()}", "children": [], "meta": {"width": width, "height": height, "chartId": int(slice_id)}, } @@ -584,10 +584,10 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards))) + print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>") position_json = json.loads(dashboard.position_json or "[]") if not is_v2_dash(position_json): - print("Converting dashboard... dash_id: {}".format(dashboard.id)) + print(f"Converting dashboard... dash_id: {dashboard.id}") position_dict = {} positions = [] slices = dashboard.slices @@ -650,7 +650,7 @@ def upgrade(): session.merge(dashboard) session.commit() else: - print("Skip converted dash_id: {}".format(dashboard.id)) + print(f"Skip converted dash_id: {dashboard.id}") session.close() diff --git a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py index bfb7a66161..3c6979f961 100644 --- a/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py +++ b/superset/migrations/versions/2018-11-12_13-31_4ce8df208545_migrate_time_range_for_default_filters.py @@ -51,7 +51,7 @@ def upgrade(): dashboards = session.query(Dashboard).all() for i, dashboard in enumerate(dashboards): - print("scanning dashboard ({}/{}) >>>>".format(i + 1, len(dashboards))) + print(f"scanning dashboard ({i + 1}/{len(dashboards)}) >>>>") if dashboard.json_metadata: json_metadata = json.loads(dashboard.json_metadata) has_update = False @@ -74,7 +74,7 @@ def upgrade(): # if user already defined __time_range, # just abandon __from and __to if "__time_range" not in val: - val["__time_range"] = "{} : {}".format(__from, __to) + val["__time_range"] = f"{__from} : {__to}" json_metadata["default_filters"] = json.dumps(filters) has_update = True except Exception: diff --git a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py index 1d0690c5e0..073bfdc474 100644 --- a/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py +++ b/superset/migrations/versions/2019-11-06_15-23_78ee127d0d1d_reconvert_legacy_filters_into_adhoc.py @@ -29,8 +29,6 @@ down_revision = "c2acd2cf3df2" import copy import json import logging -import uuid -from collections import defaultdict from alembic import op from sqlalchemy import Column, Integer, Text diff --git a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py index 404ea96e44..3b7c3951cd 100644 --- a/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py +++ b/superset/migrations/versions/2020-03-25_10-49_b5998378c225_add_certificate_to_dbs.py @@ -26,14 +26,13 @@ Create Date: 2020-03-25 10:49:10.883065 revision = "b5998378c225" down_revision = "72428d1ea401" -from typing import Dict import sqlalchemy as sa from alembic import op def upgrade(): - kwargs: Dict[str, str] = {} + kwargs: dict[str, str] = {} bind = op.get_bind() op.add_column( "dbs", diff --git a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py index 6b63c468ec..4202de4560 100644 --- a/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py +++ b/superset/migrations/versions/2020-08-12_00-24_978245563a02_migrate_iframe_to_dash_markdown.py @@ -21,7 +21,6 @@ Revises: f2672aa8350a Create Date: 2020-08-12 00:24:39.617899 """ -import collections import json import logging import uuid @@ -77,7 +76,7 @@ class Dashboard(Base): def create_new_markdown_component(chart_position, url): return { "type": "MARKDOWN", - "id": "MARKDOWN-{}".format(uuid.uuid4().hex[:8]), + "id": f"MARKDOWN-{uuid.uuid4().hex[:8]}", "children": [], "parents": chart_position["parents"], "meta": { diff --git a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py index a6db4c2cb6..45f091c38e 100644 --- a/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py +++ b/superset/migrations/versions/2020-09-24_12-04_3fbbc6e8d654_fix_data_access_permissions_for_virtual_.py @@ -167,7 +167,7 @@ def upgrade(): orphaned_faulty_view_menus = [] for faulty_view_menu in faulty_view_menus: # Get the dataset id from the view_menu name - match_ds_id = re.match("\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name) + match_ds_id = re.match(r"\[None\]\.\[.*\]\(id:(\d+)\)", faulty_view_menu.name) if match_ds_id: dataset_id = int(match_ds_id.group(1)) dataset = session.query(SqlaTable).get(dataset_id) diff --git a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py index 79b032894f..64396b6abe 100644 --- a/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py +++ b/superset/migrations/versions/2021-04-12_12-38_fc3a3a8ff221_migrate_filter_sets_to_new_format.py @@ -27,7 +27,8 @@ revision = "fc3a3a8ff221" down_revision = "085f06488938" import json -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any from alembic import op from sqlalchemy import Column, Integer, Text @@ -77,7 +78,7 @@ EXTRA_FORM_DATA_OVERRIDE_KEYS = ( ) -def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None: +def upgrade_select_filters(native_filters: Iterable[dict[str, Any]]) -> None: """ Add `defaultToFirstItem` to `controlValues` of `select_filter` components """ @@ -89,7 +90,7 @@ def upgrade_select_filters(native_filters: Iterable[Dict[str, Any]]) -> None: control_values["defaultToFirstItem"] = value -def upgrade_filter_set(filter_set: Dict[str, Any]) -> int: +def upgrade_filter_set(filter_set: dict[str, Any]) -> int: changed_filters = 0 upgrade_select_filters(filter_set.get("nativeFilters", {}).values()) data_mask = filter_set.get("dataMask", {}) @@ -124,7 +125,7 @@ def upgrade_filter_set(filter_set: Dict[str, Any]) -> int: return changed_filters -def downgrade_filter_set(filter_set: Dict[str, Any]) -> int: +def downgrade_filter_set(filter_set: dict[str, Any]) -> int: changed_filters = 0 old_data_mask = filter_set.pop("dataMask", {}) native_filters = {} diff --git a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py index ec8f8e1cc0..42368ce896 100644 --- a/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py +++ b/superset/migrations/versions/2021-04-29_15-32_f1410ed7ec95_migrate_native_filters_to_new_schema.py @@ -27,7 +27,8 @@ revision = "f1410ed7ec95" down_revision = "d416d0d715cc" import json -from typing import Any, Dict, Iterable, Tuple +from collections.abc import Iterable +from typing import Any from alembic import op from sqlalchemy import Column, Integer, Text @@ -46,7 +47,7 @@ class Dashboard(Base): json_metadata = Column(Text) -def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: +def upgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int: """ Move `defaultValue` into `defaultDataMask.filterState` """ @@ -61,7 +62,7 @@ def upgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: return changed_filters -def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: +def downgrade_filters(native_filters: Iterable[dict[str, Any]]) -> int: """ Move `defaultDataMask.filterState` into `defaultValue` """ @@ -76,7 +77,7 @@ def downgrade_filters(native_filters: Iterable[Dict[str, Any]]) -> int: return changed_filters -def upgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: +def upgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata # upgrade native select filter metadata @@ -119,7 +120,7 @@ def upgrade(): print(f"Upgraded {changed_filters} filters and {changed_filter_sets} filter sets.") -def downgrade_dashboard(dashboard: Dict[str, Any]) -> Tuple[int, int]: +def downgrade_dashboard(dashboard: dict[str, Any]) -> tuple[int, int]: changed_filters, changed_filter_sets = 0, 0 # upgrade native select filter metadata if native_filters := dashboard.get("native_filter_configuration"): diff --git a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py index 888925a888..8be11d3cf6 100644 --- a/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py +++ b/superset/migrations/versions/2021-08-03_15-36_143b6f2815da_migrate_pivot_table_v2_heatmaps_to_new_.py @@ -27,7 +27,6 @@ revision = "143b6f2815da" down_revision = "e323605f370a" import json -from typing import Any, Dict, List, Tuple from alembic import op from sqlalchemy import and_, Column, Integer, String, Text diff --git a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py index e44bdae782..ab852c324b 100644 --- a/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py +++ b/superset/migrations/versions/2021-09-27_11-31_60dc453f4e2e_migrate_timeseries_limit_metric_to_.py @@ -27,7 +27,6 @@ revision = "60dc453f4e2e" down_revision = "3ebe0993c770" import json -import re from alembic import op from sqlalchemy import and_, Column, Integer, String, Text diff --git a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py index db1b87e546..b85e9397e6 100644 --- a/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py +++ b/superset/migrations/versions/2021-10-12_11-15_32646df09c64_update_time_grain_sqla.py @@ -27,7 +27,6 @@ revision = "32646df09c64" down_revision = "60dc453f4e2e" import json -from typing import Dict from alembic import op from sqlalchemy import Column, Integer, Text @@ -45,7 +44,7 @@ class Slice(Base): params = Column(Text) -def migrate(mapping: Dict[str, str]) -> None: +def migrate(mapping: dict[str, str]) -> None: bind = op.get_bind() session = db.Session(bind=bind) diff --git a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py index 286a0731fc..b51f6c78ac 100644 --- a/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py +++ b/superset/migrations/versions/2022-04-01_14-38_a9422eeaae74_new_dataset_models_take_2.py @@ -29,7 +29,7 @@ down_revision = "ad07e4fdbaba" import json import os from datetime import datetime -from typing import List, Optional, Set, Type, Union +from typing import Optional, Union from uuid import uuid4 import sqlalchemy as sa @@ -86,7 +86,7 @@ class AuxiliaryColumnsMixin(UUIDMixin): def insert_from_select( - target: Union[str, sa.Table, Type[Base]], source: sa.sql.expression.Select + target: Union[str, sa.Table, type[Base]], source: sa.sql.expression.Select ) -> None: """ Execute INSERT FROM SELECT to copy data from a SELECT query to the target table. @@ -274,8 +274,8 @@ def find_tables( session: Session, database_id: int, default_schema: Optional[str], - tables: Set[Table], -) -> List[int]: + tables: set[Table], +) -> list[int]: """ Look for NewTable's of from a specific database """ diff --git a/superset/models/annotations.py b/superset/models/annotations.py index 3185460bf5..54de94e7f6 100644 --- a/superset/models/annotations.py +++ b/superset/models/annotations.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """a collection of Annotation-related models""" -from typing import Any, Dict +from typing import Any from flask_appbuilder import Model from sqlalchemy import Column, DateTime, ForeignKey, Index, Integer, String, Text @@ -54,7 +54,7 @@ class Annotation(Model, AuditMixinNullable): __table_args__ = (Index("ti_dag_state", layer_id, start_dttm, end_dttm),) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: return { "layer_id": self.layer_id, "start_dttm": self.start_dttm, diff --git a/superset/models/core.py b/superset/models/core.py index ee50f06345..3c2b12d378 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=line-too-long """A collection of ORM sqlalchemy models for Superset""" +import builtins import enum import json import logging @@ -25,7 +26,7 @@ from contextlib import closing, contextmanager, nullcontext from copy import deepcopy from datetime import datetime from functools import lru_cache -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -194,7 +195,7 @@ class Database( return self.db_engine_spec.allows_subqueries @property - def function_names(self) -> List[str]: + def function_names(self) -> list[str]: try: return self.db_engine_spec.get_function_names(self) except Exception as ex: # pylint: disable=broad-except @@ -234,7 +235,7 @@ class Database( return True @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: return { "id": self.id, "name": self.database_name, @@ -271,7 +272,7 @@ class Database( return self.db_engine_spec.mask_encrypted_extra(self.encrypted_extra) @property - def parameters(self) -> Dict[str, Any]: + def parameters(self) -> dict[str, Any]: # Database parameters are a dictionary of values that are used to make up # the sqlalchemy_uri # When returning the parameters we should use the masked SQLAlchemy URI and the @@ -296,7 +297,7 @@ class Database( return parameters @property - def parameters_schema(self) -> Dict[str, Any]: + def parameters_schema(self) -> dict[str, Any]: try: parameters_schema = self.db_engine_spec.parameters_json_schema() # type: ignore except Exception: # pylint: disable=broad-except @@ -304,7 +305,7 @@ class Database( return parameters_schema @property - def metadata_cache_timeout(self) -> Dict[str, Any]: + def metadata_cache_timeout(self) -> dict[str, Any]: return self.get_extra().get("metadata_cache_timeout", {}) @property @@ -324,15 +325,15 @@ class Database( return self.metadata_cache_timeout.get("table_cache_timeout") @property - def default_schemas(self) -> List[str]: + def default_schemas(self) -> list[str]: return self.get_extra().get("default_schemas", []) @property - def connect_args(self) -> Dict[str, Any]: + def connect_args(self) -> dict[str, Any]: return self.get_extra().get("engine_params", {}).get("connect_args", {}) @property - def engine_information(self) -> Dict[str, Any]: + def engine_information(self) -> dict[str, Any]: try: engine_information = self.db_engine_spec.get_public_information() except Exception: # pylint: disable=broad-except @@ -540,7 +541,7 @@ class Database( """Add quotes to potential identifiter expressions if needed""" return self.get_dialect().identifier_preparer.quote - def get_reserved_words(self) -> Set[str]: + def get_reserved_words(self) -> set[str]: return self.get_dialect().preparer.reserved_words def get_df( # pylint: disable=too-many-locals @@ -629,7 +630,7 @@ class Database( show_cols: bool = False, indent: bool = True, latest_partition: bool = False, - cols: Optional[List[Dict[str, Any]]] = None, + cols: Optional[list[dict[str, Any]]] = None, ) -> str: """Generates a ``select *`` statement in the proper dialect""" eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB) @@ -670,7 +671,7 @@ class Database( cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> Set[Tuple[str, str]]: + ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -706,7 +707,7 @@ class Database( cache: bool = False, cache_timeout: Optional[int] = None, force: bool = False, - ) -> Set[Tuple[str, str]]: + ) -> set[tuple[str, str]]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -750,7 +751,7 @@ class Database( cache_timeout: Optional[int] = None, force: bool = False, ssh_tunnel: Optional["SSHTunnel"] = None, - ) -> List[str]: + ) -> list[str]: """Parameters need to be passed as keyword arguments. For unused parameters, they are referenced in @@ -768,13 +769,15 @@ class Database( raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex @property - def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: + def db_engine_spec(self) -> builtins.type[db_engine_specs.BaseEngineSpec]: url = make_url_safe(self.sqlalchemy_uri_decrypted) return self.get_db_engine_spec(url) @classmethod @lru_cache(maxsize=LRU_CACHE_MAX_SIZE) - def get_db_engine_spec(cls, url: URL) -> Type[db_engine_specs.BaseEngineSpec]: + def get_db_engine_spec( + cls, url: URL + ) -> builtins.type[db_engine_specs.BaseEngineSpec]: backend = url.get_backend_name() try: driver = url.get_driver_name() @@ -784,7 +787,7 @@ class Database( return db_engine_specs.get_engine_spec(backend, driver) - def grains(self) -> Tuple[TimeGrain, ...]: + def grains(self) -> tuple[TimeGrain, ...]: """Defines time granularity database-specific expressions. The idea here is to make it easy for users to change the time grain @@ -795,10 +798,10 @@ class Database( """ return self.db_engine_spec.get_time_grains() - def get_extra(self) -> Dict[str, Any]: + def get_extra(self) -> dict[str, Any]: return self.db_engine_spec.get_extra_params(self) - def get_encrypted_extra(self) -> Dict[str, Any]: + def get_encrypted_extra(self) -> dict[str, Any]: encrypted_extra = {} if self.encrypted_extra: try: @@ -809,7 +812,7 @@ class Database( return encrypted_extra # pylint: disable=invalid-name - def update_params_from_encrypted_extra(self, params: Dict[str, Any]) -> None: + def update_params_from_encrypted_extra(self, params: dict[str, Any]) -> None: self.db_engine_spec.update_params_from_encrypted_extra(self, params) def get_table(self, table_name: str, schema: Optional[str] = None) -> Table: @@ -832,7 +835,7 @@ class Database( def get_columns( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_columns(inspector, table_name, schema) @@ -840,19 +843,19 @@ class Database( self, table_name: str, schema: Optional[str] = None, - ) -> List[MetricType]: + ) -> list[MetricType]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_metrics(self, inspector, table_name, schema) def get_indexes( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return self.db_engine_spec.get_indexes(self, inspector, table_name, schema) def get_pk_constraint( self, table_name: str, schema: Optional[str] = None - ) -> Dict[str, Any]: + ) -> dict[str, Any]: with self.get_inspector_with_context() as inspector: pk_constraint = inspector.get_pk_constraint(table_name, schema) or {} @@ -866,13 +869,13 @@ class Database( def get_foreign_keys( self, table_name: str, schema: Optional[str] = None - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: with self.get_inspector_with_context() as inspector: return inspector.get_foreign_keys(table_name, schema) def get_schema_access_for_file_upload( # pylint: disable=invalid-name self, - ) -> List[str]: + ) -> list[str]: allowed_databases = self.get_extra().get("schemas_allowed_for_file_upload", []) if isinstance(allowed_databases, str): @@ -932,7 +935,7 @@ class Database( view_name: str, schema: Optional[str] = None, ) -> bool: - view_names: List[str] = [] + view_names: list[str] = [] try: view_names = dialect.get_view_names(connection=conn, schema=schema) except Exception: # pylint: disable=broad-except diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index 9afd74f5e3..f3b9c08794 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -21,7 +21,7 @@ import logging import uuid from collections import defaultdict from functools import partial -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable import sqlalchemy as sqla from flask import current_app @@ -68,9 +68,7 @@ config = app.config logger = logging.getLogger(__name__) -def copy_dashboard( - _mapper: Mapper, connection: Connection, target: "Dashboard" -) -> None: +def copy_dashboard(_mapper: Mapper, connection: Connection, target: Dashboard) -> None: dashboard_id = config["DASHBOARD_TEMPLATE_ID"] if dashboard_id is None: return @@ -146,7 +144,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): certification_details = Column(Text) json_metadata = Column(Text) slug = Column(String(255), unique=True) - slices: List[Slice] = relationship( + slices: list[Slice] = relationship( Slice, secondary=dashboard_slices, backref="dashboards" ) owners = relationship(security_manager.user_model, secondary=dashboard_user) @@ -187,14 +185,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): return f"/superset/dashboard/{self.slug or self.id}/" @staticmethod - def get_url(id_: int, slug: Optional[str] = None) -> str: + def get_url(id_: int, slug: str | None = None) -> str: # To be able to generate URL's without instanciating a Dashboard object return f"/superset/dashboard/{slug or id_}/" @property - def datasources(self) -> Set[BaseDatasource]: + def datasources(self) -> set[BaseDatasource]: # Verbose but efficient database enumeration of dashboard datasources. - datasources_by_cls_model: Dict[Type["BaseDatasource"], Set[int]] = defaultdict( + datasources_by_cls_model: dict[type[BaseDatasource], set[int]] = defaultdict( set ) @@ -210,14 +208,14 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): } @property - def filter_sets(self) -> Dict[int, FilterSet]: + def filter_sets(self) -> dict[int, FilterSet]: return {fs.id: fs for fs in self._filter_sets} @property - def filter_sets_lst(self) -> Dict[int, FilterSet]: + def filter_sets_lst(self) -> dict[int, FilterSet]: if security_manager.is_admin(): return self._filter_sets - filter_sets_by_owner_type: Dict[str, List[Any]] = {"Dashboard": [], "User": []} + filter_sets_by_owner_type: dict[str, list[Any]] = {"Dashboard": [], "User": []} for fs in self._filter_sets: filter_sets_by_owner_type[fs.owner_type].append(fs) user_filter_sets = list( @@ -232,7 +230,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): } @property - def charts(self) -> List[str]: + def charts(self) -> list[str]: return [slc.chart for slc in self.slices] @property @@ -281,7 +279,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): return f"/superset/profile/{self.changed_by.username}" @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: positions = self.position_json if positions: positions = json.loads(positions) @@ -305,16 +303,16 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): make_name=lambda fname: f"{fname}-v1.0", unless=lambda: not is_feature_enabled("DASHBOARD_CACHE"), ) - def datasets_trimmed_for_slices(self) -> List[Dict[str, Any]]: + def datasets_trimmed_for_slices(self) -> list[dict[str, Any]]: # Verbose but efficient database enumeration of dashboard datasources. - slices_by_datasource: Dict[ - Tuple[Type["BaseDatasource"], int], Set[Slice] + slices_by_datasource: dict[ + tuple[type[BaseDatasource], int], set[Slice] ] = defaultdict(set) for slc in self.slices: slices_by_datasource[(slc.cls_model, slc.datasource_id)].add(slc) - result: List[Dict[str, Any]] = [] + result: list[dict[str, Any]] = [] for (cls_model, datasource_id), slices in slices_by_datasource.items(): datasource = ( @@ -336,7 +334,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): self.json_metadata = value @property - def position(self) -> Dict[str, Any]: + def position(self) -> dict[str, Any]: if self.position_json: return json.loads(self.position_json) return {} @@ -380,7 +378,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): @classmethod def export_dashboards( # pylint: disable=too-many-locals - cls, dashboard_ids: List[int] + cls, dashboard_ids: list[int] ) -> str: copied_dashboards = [] datasource_ids = set() @@ -413,7 +411,7 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): slices.append(copied_slc) json_metadata = json.loads(dashboard.json_metadata) - native_filter_configuration: List[Dict[str, Any]] = json_metadata.get( + native_filter_configuration: list[dict[str, Any]] = json_metadata.get( "native_filter_configuration", [] ) for native_filter in native_filter_configuration: @@ -449,12 +447,12 @@ class Dashboard(Model, AuditMixinNullable, ImportExportMixin): ) @classmethod - def get(cls, id_or_slug: Union[str, int]) -> Dashboard: + def get(cls, id_or_slug: str | int) -> Dashboard: qry = db.session.query(Dashboard).filter(id_or_slug_filter(id_or_slug)) return qry.one_or_none() -def is_uuid(value: Union[str, int]) -> bool: +def is_uuid(value: str | int) -> bool: try: uuid.UUID(str(value)) return True @@ -462,7 +460,7 @@ def is_uuid(value: Union[str, int]) -> bool: return False -def is_int(value: Union[str, int]) -> bool: +def is_int(value: str | int) -> bool: try: int(value) return True @@ -470,7 +468,7 @@ def is_int(value: Union[str, int]) -> bool: return False -def id_or_slug_filter(id_or_slug: Union[int, str]) -> BinaryExpression: +def id_or_slug_filter(id_or_slug: int | str) -> BinaryExpression: if is_int(id_or_slug): return Dashboard.id == int(id_or_slug) if is_uuid(id_or_slug): @@ -490,7 +488,7 @@ if is_feature_enabled("DASHBOARD_CACHE"): def clear_dashboard_cache( _mapper: Mapper, _connection: Connection, - obj: Union[Slice, BaseDatasource, Dashboard], + obj: Slice | BaseDatasource | Dashboard, check_modified: bool = True, ) -> None: if check_modified and not object_session(obj).is_modified(obj): diff --git a/superset/models/datasource_access_request.py b/superset/models/datasource_access_request.py index 1f286f96d8..23df4cffae 100644 --- a/superset/models/datasource_access_request.py +++ b/superset/models/datasource_access_request.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, Type, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from flask import Markup from flask_appbuilder import Model @@ -41,7 +41,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): ROLES_DENYLIST = set(config["ROBOT_PERMISSION_ROLES"]) @property - def cls_model(self) -> Type["BaseDatasource"]: + def cls_model(self) -> type["BaseDatasource"]: # pylint: disable=import-outside-toplevel from superset.datasource.dao import DatasourceDAO @@ -77,7 +77,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): f"datasource_id={self.datasource_id}&" f"created_by={self.created_by.username}&role_to_grant={role.name}" ) - link = 'Grant {} Role'.format(href, role.name) + link = f'Grant {role.name} Role' action_list = action_list + "
  • " + link + "
  • " return "
      " + action_list + "
    " @@ -90,8 +90,8 @@ class DatasourceAccessRequest(Model, AuditMixinNullable): f"datasource_id={self.datasource_id}&" f"created_by={self.created_by.username}&role_to_extend={role.name}" ) - link = 'Extend {} Role'.format(href, role.name) + link = f'Extend {role.name} Role' if role.name in self.ROLES_DENYLIST: - link = "{} Role".format(role.name) + link = f"{role.name} Role" action_list = action_list + "
  • " + link + "
  • " return "
      " + action_list + "
    " diff --git a/superset/models/embedded_dashboard.py b/superset/models/embedded_dashboard.py index 7718bc886f..32a8e4abce 100644 --- a/superset/models/embedded_dashboard.py +++ b/superset/models/embedded_dashboard.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import uuid -from typing import List from flask_appbuilder import Model from sqlalchemy import Column, ForeignKey, Integer, Text @@ -49,7 +48,7 @@ class EmbeddedDashboard(Model, AuditMixinNullable): ) @property - def allowed_domains(self) -> List[str]: + def allowed_domains(self) -> list[str]: """ A list of domains which are allowed to embed the dashboard. An empty list means any domain can embed. diff --git a/superset/models/filter_set.py b/superset/models/filter_set.py index 1ace5bca32..ac25b114ff 100644 --- a/superset/models/filter_set.py +++ b/superset/models/filter_set.py @@ -18,7 +18,7 @@ from __future__ import annotations import json import logging -from typing import Any, Dict +from typing import Any from flask import current_app from flask_appbuilder import Model @@ -75,7 +75,7 @@ class FilterSet(Model, AuditMixinNullable): return "" return f"/superset/profile/{self.changed_by.username}" - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "name": self.name, @@ -105,7 +105,7 @@ class FilterSet(Model, AuditMixinNullable): return qry.all() @property - def params(self) -> Dict[str, Any]: + def params(self) -> dict[str, Any]: if self.json_metadata: return json.loads(self.json_metadata) return {} diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 32c6f5ff6a..42d5a24174 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -16,29 +16,17 @@ # under the License. # pylint: disable=too-many-lines """a collection of model-related helper classes and functions""" +import builtins import dataclasses import json import logging import re import uuid from collections import defaultdict +from collections.abc import Hashable from datetime import datetime, timedelta from json.decoder import JSONDecodeError -from typing import ( - Any, - cast, - Dict, - Hashable, - List, - NamedTuple, - Optional, - Set, - Text, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, cast, NamedTuple, Optional, TYPE_CHECKING, Union import dateutil.parser import humanize @@ -145,7 +133,7 @@ def validate_adhoc_subquery( return ";\n".join(str(statement) for statement in statements) -def json_to_dict(json_str: str) -> Dict[Any, Any]: +def json_to_dict(json_str: str) -> dict[Any, Any]: if json_str: val = re.sub(",[ \t\r\n]+}", "}", json_str) val = re.sub(",[ \t\r\n]+\\]", "]", val) @@ -179,22 +167,22 @@ class ImportExportMixin: # The name of the attribute # with the SQL Alchemy back reference - export_children: List[str] = [] + export_children: list[str] = [] # List of (str) names of attributes # with the SQL Alchemy forward references - export_fields: List[str] = [] + export_fields: list[str] = [] # The names of the attributes # that are available for import and export - extra_import_fields: List[str] = [] + extra_import_fields: list[str] = [] # Additional fields that should be imported, # even though they were not exported __mapper__: Mapper @classmethod - def _unique_constrains(cls) -> List[Set[str]]: + def _unique_constrains(cls) -> list[set[str]]: """Get all (single column and multi column) unique constraints""" unique = [ {c.name for c in u.columns} @@ -207,7 +195,7 @@ class ImportExportMixin: return unique @classmethod - def parent_foreign_key_mappings(cls) -> Dict[str, str]: + def parent_foreign_key_mappings(cls) -> dict[str, str]: """Get a mapping of foreign name to the local name of foreign keys""" parent_rel = cls.__mapper__.relationships.get(cls.export_parent) if parent_rel: @@ -217,7 +205,7 @@ class ImportExportMixin: @classmethod def export_schema( cls, recursive: bool = True, include_parent_ref: bool = False - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Export schema as a dictionary""" parent_excludes = set() if not include_parent_ref: @@ -227,12 +215,12 @@ class ImportExportMixin: def formatter(column: sa.Column) -> str: return ( - "{0} Default ({1})".format(str(column.type), column.default.arg) + f"{str(column.type)} Default ({column.default.arg})" if column.default else str(column.type) ) - schema: Dict[str, Any] = { + schema: dict[str, Any] = { column.name: formatter(column) for column in cls.__table__.columns # type: ignore if (column.name in cls.export_fields and column.name not in parent_excludes) @@ -252,10 +240,10 @@ class ImportExportMixin: # pylint: disable=too-many-arguments,too-many-branches,too-many-locals cls, session: Session, - dict_rep: Dict[Any, Any], + dict_rep: dict[Any, Any], parent: Optional[Any] = None, recursive: bool = True, - sync: Optional[List[str]] = None, + sync: Optional[list[str]] = None, ) -> Any: """Import obj from a dictionary""" if sync is None: @@ -281,9 +269,7 @@ class ImportExportMixin: if cls.export_parent: for prnt in parent_refs.keys(): if prnt not in dict_rep: - raise RuntimeError( - "{0}: Missing field {1}".format(cls.__name__, prnt) - ) + raise RuntimeError(f"{cls.__name__}: Missing field {prnt}") else: # Set foreign keys to parent obj for k, v in parent_refs.items(): @@ -371,7 +357,7 @@ class ImportExportMixin: include_parent_ref: bool = False, include_defaults: bool = False, export_uuids: bool = False, - ) -> Dict[Any, Any]: + ) -> dict[Any, Any]: """Export obj to dictionary""" export_fields = set(self.export_fields) if export_uuids: @@ -457,18 +443,18 @@ class ImportExportMixin: self.owners = [g.user] @property - def params_dict(self) -> Dict[Any, Any]: + def params_dict(self) -> dict[Any, Any]: return json_to_dict(self.params) @property - def template_params_dict(self) -> Dict[Any, Any]: + def template_params_dict(self) -> dict[Any, Any]: return json_to_dict(self.template_params) # type: ignore def _user_link(user: User) -> Union[Markup, str]: if not user: return "" - url = "/superset/profile/{}/".format(user.username) + url = f"/superset/profile/{user.username}/" return Markup('{}'.format(url, escape(user) or "")) @@ -505,13 +491,13 @@ class AuditMixinNullable(AuditMixin): @property def created_by_name(self) -> str: if self.created_by: - return escape("{}".format(self.created_by)) + return escape(f"{self.created_by}") return "" @property def changed_by_name(self) -> str: if self.changed_by: - return escape("{}".format(self.changed_by)) + return escape(f"{self.changed_by}") return "" @renders("created_by") @@ -565,12 +551,12 @@ class QueryResult: # pylint: disable=too-few-public-methods df: pd.DataFrame, query: str, duration: timedelta, - applied_template_filters: Optional[List[str]] = None, - applied_filter_columns: Optional[List[ColumnTyping]] = None, - rejected_filter_columns: Optional[List[ColumnTyping]] = None, + applied_template_filters: Optional[list[str]] = None, + applied_filter_columns: Optional[list[ColumnTyping]] = None, + rejected_filter_columns: Optional[list[ColumnTyping]] = None, status: str = QueryStatus.SUCCESS, error_message: Optional[str] = None, - errors: Optional[List[Dict[str, Any]]] = None, + errors: Optional[list[dict[str, Any]]] = None, from_dttm: Optional[datetime] = None, to_dttm: Optional[datetime] = None, ) -> None: @@ -593,7 +579,7 @@ class ExtraJSONMixin: extra_json = sa.Column(sa.Text, default="{}") @property - def extra(self) -> Dict[str, Any]: + def extra(self) -> dict[str, Any]: try: return json.loads(self.extra_json or "{}") or {} except (TypeError, JSONDecodeError) as exc: @@ -603,7 +589,7 @@ class ExtraJSONMixin: return {} @extra.setter - def extra(self, extras: Dict[str, Any]) -> None: + def extra(self, extras: dict[str, Any]) -> None: self.extra_json = json.dumps(extras) def set_extra_json_key(self, key: str, value: Any) -> None: @@ -615,7 +601,7 @@ class ExtraJSONMixin: def ensure_extra_json_is_not_none( # pylint: disable=no-self-use self, _: str, - value: Optional[Dict[str, Any]], + value: Optional[dict[str, Any]], ) -> Any: if value is None: return "{}" @@ -627,7 +613,7 @@ class CertificationMixin: extra = sa.Column(sa.Text, default="{}") - def get_extra_dict(self) -> Dict[str, Any]: + def get_extra_dict(self) -> dict[str, Any]: try: return json.loads(self.extra) except (TypeError, json.JSONDecodeError): @@ -652,8 +638,8 @@ class CertificationMixin: def clone_model( target: Model, - ignore: Optional[List[str]] = None, - keep_relations: Optional[List[str]] = None, + ignore: Optional[list[str]] = None, + keep_relations: Optional[list[str]] = None, **kwargs: Any, ) -> Model: """ @@ -676,22 +662,22 @@ def clone_model( # todo(hugh): centralize where this code lives class QueryStringExtended(NamedTuple): - applied_template_filters: Optional[List[str]] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] - labels_expected: List[str] - prequeries: List[str] + applied_template_filters: Optional[list[str]] + applied_filter_columns: list[ColumnTyping] + rejected_filter_columns: list[ColumnTyping] + labels_expected: list[str] + prequeries: list[str] sql: str class SqlaQuery(NamedTuple): - applied_template_filters: List[str] - applied_filter_columns: List[ColumnTyping] - rejected_filter_columns: List[ColumnTyping] + applied_template_filters: list[str] + applied_filter_columns: list[ColumnTyping] + rejected_filter_columns: list[ColumnTyping] cte: Optional[str] - extra_cache_keys: List[Any] - labels_expected: List[str] - prequeries: List[str] + extra_cache_keys: list[Any] + labels_expected: list[str] + prequeries: list[str] sqla_query: Select @@ -719,7 +705,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def db_extra(self) -> Optional[Dict[str, Any]]: + def db_extra(self) -> Optional[dict[str, Any]]: raise NotImplementedError() def query(self, query_obj: QueryObjectDict) -> QueryResult: @@ -730,11 +716,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def owners_data(self) -> List[Any]: + def owners_data(self) -> list[Any]: raise NotImplementedError() @property - def metrics(self) -> List[Any]: + def metrics(self) -> list[Any]: return [] @property @@ -750,7 +736,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def column_names(self) -> List[str]: + def column_names(self) -> list[str]: raise NotImplementedError() @property @@ -762,15 +748,15 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def dttm_cols(self) -> List[str]: + def dttm_cols(self) -> list[str]: raise NotImplementedError() @property - def db_engine_spec(self) -> Type["BaseEngineSpec"]: + def db_engine_spec(self) -> builtins.type["BaseEngineSpec"]: raise NotImplementedError() @property - def database(self) -> Type["Database"]: + def database(self) -> builtins.type["Database"]: raise NotImplementedError() @property @@ -782,7 +768,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods raise NotImplementedError() @property - def columns(self) -> List[Any]: + def columns(self) -> list[Any]: raise NotImplementedError() def get_fetch_values_predicate( @@ -790,7 +776,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) -> TextClause: raise NotImplementedError() - def get_extra_cache_keys(self, query_obj: Dict[str, Any]) -> List[Hashable]: + def get_extra_cache_keys(self, query_obj: dict[str, Any]) -> list[Hashable]: raise NotImplementedError() def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: @@ -799,7 +785,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_sqla_row_level_filters( self, template_processor: BaseTemplateProcessor, - ) -> List[TextClause]: + ) -> list[TextClause]: """ Return the appropriate row level security filters for this table and the current user. A custom username can be passed when the user is not present in the @@ -808,8 +794,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods :param template_processor: The template processor to apply to the filters. :returns: A list of SQL clauses to be ANDed together. """ - all_filters: List[TextClause] = [] - filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list) + all_filters: list[TextClause] = [] + filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list) try: for filter_ in security_manager.get_rls_filters(self): clause = self.text( @@ -923,8 +909,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods self, row: pd.Series, dimension: str, - columns_by_name: Dict[str, "TableColumn"], - ) -> Union[str, int, float, bool, Text]: + columns_by_name: dict[str, "TableColumn"], + ) -> Union[str, int, float, bool, str]: """ Convert a prequery result type to its equivalent Python type. @@ -944,7 +930,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods value = value.item() column_ = columns_by_name[dimension] - db_extra: Dict[str, Any] = self.database.get_extra() # type: ignore + db_extra: dict[str, Any] = self.database.get_extra() # type: ignore if isinstance(column_, dict): if ( @@ -969,7 +955,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods return value def make_orderby_compatible( - self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement] + self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement] ) -> None: """ If needed, make sure aliases for selected columns are not used in @@ -1088,7 +1074,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_from_clause( self, template_processor: Optional[BaseTemplateProcessor] = None - ) -> Tuple[Union[TableClause, Alias], Optional[str]]: + ) -> tuple[Union[TableClause, Alias], Optional[str]]: """ Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a @@ -1117,7 +1103,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def adhoc_metric_to_sqla( self, metric: AdhocMetric, - columns_by_name: Dict[str, "TableColumn"], # pylint: disable=unused-argument + columns_by_name: dict[str, "TableColumn"], # pylint: disable=unused-argument template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ @@ -1151,7 +1137,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods return self.make_sqla_column_compatible(sqla_metric, label) @property - def template_params_dict(self) -> Dict[Any, Any]: + def template_params_dict(self) -> dict[Any, Any]: return {} @staticmethod @@ -1162,9 +1148,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods target_native_type: Optional[str] = None, is_list_target: bool = False, db_engine_spec: Optional[ - Type["BaseEngineSpec"] + builtins.type["BaseEngineSpec"] ] = None, # fix(hughhh): Optional[Type[BaseEngineSpec]] - db_extra: Optional[Dict[str, Any]] = None, + db_extra: Optional[dict[str, Any]] = None, ) -> Optional[FilterValues]: if values is None: return None @@ -1217,8 +1203,8 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def _get_series_orderby( self, series_limit_metric: Metric, - metrics_by_name: Dict[str, "SqlMetric"], - columns_by_name: Dict[str, "TableColumn"], + metrics_by_name: dict[str, "SqlMetric"], + columns_by_name: dict[str, "TableColumn"], template_processor: Optional[BaseTemplateProcessor] = None, ) -> Column: if utils.is_adhoc_metric(series_limit_metric): @@ -1248,9 +1234,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def _get_top_groups( self, df: pd.DataFrame, - dimensions: List[str], - groupby_exprs: Dict[str, Any], - columns_by_name: Dict[str, "TableColumn"], + dimensions: list[str], + groupby_exprs: dict[str, Any], + columns_by_name: dict[str, "TableColumn"], ) -> ColumnElement: groups = [] for _unused, row in df.iterrows(): @@ -1335,7 +1321,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) return and_(*l) - def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]: + def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]: """Runs query against sqla to retrieve some sample values for the given column. """ @@ -1369,7 +1355,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_timestamp_expression( self, - column: Dict[str, Any], + column: dict[str, Any], time_grain: Optional[str], label: Optional[str] = None, template_processor: Optional[BaseTemplateProcessor] = None, @@ -1417,23 +1403,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-many-branches,too-many-statements self, apply_fetch_values_predicate: bool = False, - columns: Optional[List[Column]] = None, - extras: Optional[Dict[str, Any]] = None, + columns: Optional[list[Column]] = None, + extras: Optional[dict[str, Any]] = None, filter: Optional[ # pylint: disable=redefined-builtin - List[utils.QueryObjectFilterClause] + list[utils.QueryObjectFilterClause] ] = None, from_dttm: Optional[datetime] = None, granularity: Optional[str] = None, - groupby: Optional[List[Column]] = None, + groupby: Optional[list[Column]] = None, inner_from_dttm: Optional[datetime] = None, inner_to_dttm: Optional[datetime] = None, is_rowcount: bool = False, is_timeseries: bool = True, - metrics: Optional[List[Metric]] = None, - orderby: Optional[List[OrderBy]] = None, + metrics: Optional[list[Metric]] = None, + orderby: Optional[list[OrderBy]] = None, order_desc: bool = True, to_dttm: Optional[datetime] = None, - series_columns: Optional[List[Column]] = None, + series_columns: Optional[list[Column]] = None, series_limit: Optional[int] = None, series_limit_metric: Optional[Metric] = None, row_limit: Optional[int] = None, @@ -1464,23 +1450,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods } columns = columns or [] groupby = groupby or [] - rejected_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] - applied_adhoc_filters_columns: List[Union[str, ColumnTyping]] = [] + rejected_adhoc_filters_columns: list[Union[str, ColumnTyping]] = [] + applied_adhoc_filters_columns: list[Union[str, ColumnTyping]] = [] series_column_names = utils.get_column_names(series_columns or []) # deprecated, to be removed in 2.0 if is_timeseries and timeseries_limit: series_limit = timeseries_limit series_limit_metric = series_limit_metric or timeseries_limit_metric template_kwargs.update(self.template_params_dict) - extra_cache_keys: List[Any] = [] + extra_cache_keys: list[Any] = [] template_kwargs["extra_cache_keys"] = extra_cache_keys - removed_filters: List[str] = [] - applied_template_filters: List[str] = [] + removed_filters: list[str] = [] + applied_template_filters: list[str] = [] template_kwargs["removed_filters"] = removed_filters template_kwargs["applied_filters"] = applied_template_filters template_processor = self.get_template_processor(**template_kwargs) db_engine_spec = self.db_engine_spec - prequeries: List[str] = [] + prequeries: list[str] = [] orderby = orderby or [] need_groupby = bool(metrics is not None or groupby) metrics = metrics or [] @@ -1489,11 +1475,11 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if granularity not in self.dttm_cols and granularity is not None: granularity = self.main_dttm_col - columns_by_name: Dict[str, "TableColumn"] = { + columns_by_name: dict[str, "TableColumn"] = { col.column_name: col for col in self.columns } - metrics_by_name: Dict[str, "SqlMetric"] = { + metrics_by_name: dict[str, "SqlMetric"] = { m.metric_name: m for m in self.metrics } @@ -1507,7 +1493,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods if not metrics and not columns and not groupby: raise QueryObjectValidationError(_("Empty query?")) - metrics_exprs: List[ColumnElement] = [] + metrics_exprs: list[ColumnElement] = [] for metric in metrics: if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) @@ -1542,7 +1528,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods metrics_exprs_by_expr = {str(m): m for m in metrics_exprs} # Since orderby may use adhoc metrics, too; we need to process them first - orderby_exprs: List[ColumnElement] = [] + orderby_exprs: list[ColumnElement] = [] for orig_col, ascending in orderby: col: Union[AdhocMetric, ColumnElement] = orig_col if isinstance(col, dict): @@ -1582,7 +1568,7 @@ class ExploreMixin: # pylint: disable=too-many-public-methods _("Unknown column used in orderby: %(col)s", col=orig_col) ) - select_exprs: List[Union[Column, Label]] = [] + select_exprs: list[Union[Column, Label]] = [] groupby_all_columns = {} groupby_series_columns = {} diff --git a/superset/models/slice.py b/superset/models/slice.py index 6835215338..15dddfc7e1 100644 --- a/superset/models/slice.py +++ b/superset/models/slice.py @@ -18,7 +18,7 @@ from __future__ import annotations import json import logging -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from urllib import parse import sqlalchemy as sqla @@ -70,7 +70,7 @@ class Slice( # pylint: disable=too-many-public-methods ): """A slice is essentially a report or a view on data""" - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None __tablename__ = "slices" id = Column(Integer, primary_key=True) @@ -134,17 +134,17 @@ class Slice( # pylint: disable=too-many-public-methods return self.slice_name or str(self.id) @property - def cls_model(self) -> Type["BaseDatasource"]: + def cls_model(self) -> type[BaseDatasource]: # pylint: disable=import-outside-toplevel from superset.datasource.dao import DatasourceDAO return DatasourceDAO.sources[self.datasource_type] @property - def datasource(self) -> Optional["BaseDatasource"]: + def datasource(self) -> BaseDatasource | None: return self.get_datasource - def clone(self) -> "Slice": + def clone(self) -> Slice: return Slice( slice_name=self.slice_name, datasource_id=self.datasource_id, @@ -158,7 +158,7 @@ class Slice( # pylint: disable=too-many-public-methods # pylint: disable=using-constant-test @datasource.getter # type: ignore - def get_datasource(self) -> Optional["BaseDatasource"]: + def get_datasource(self) -> BaseDatasource | None: return ( db.session.query(self.cls_model) .filter_by(id=self.datasource_id) @@ -166,20 +166,20 @@ class Slice( # pylint: disable=too-many-public-methods ) @renders("datasource_name") - def datasource_link(self) -> Optional[Markup]: + def datasource_link(self) -> Markup | None: # pylint: disable=no-member datasource = self.datasource return datasource.link if datasource else None @renders("datasource_url") - def datasource_url(self) -> Optional[str]: + def datasource_url(self) -> str | None: # pylint: disable=no-member if self.table: return self.table.explore_url datasource = self.datasource return datasource.explore_url if datasource else None - def datasource_name_text(self) -> Optional[str]: + def datasource_name_text(self) -> str | None: # pylint: disable=no-member if self.table: if self.table.schema: @@ -192,7 +192,7 @@ class Slice( # pylint: disable=too-many-public-methods return None @property - def datasource_edit_url(self) -> Optional[str]: + def datasource_edit_url(self) -> str | None: # pylint: disable=no-member datasource = self.datasource return datasource.url if datasource else None @@ -200,7 +200,7 @@ class Slice( # pylint: disable=too-many-public-methods # pylint: enable=using-constant-test @property - def viz(self) -> Optional[BaseViz]: + def viz(self) -> BaseViz | None: form_data = json.loads(self.params) viz_class = viz_types.get(self.viz_type) datasource = self.datasource @@ -213,9 +213,9 @@ class Slice( # pylint: disable=too-many-public-methods return utils.markdown(self.description) @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """Data used to render slice in templates""" - data: Dict[str, Any] = {} + data: dict[str, Any] = {} self.token = "" try: viz = self.viz @@ -261,8 +261,8 @@ class Slice( # pylint: disable=too-many-public-methods return json.dumps(self.data) @property - def form_data(self) -> Dict[str, Any]: - form_data: Dict[str, Any] = {} + def form_data(self) -> dict[str, Any]: + form_data: dict[str, Any] = {} try: form_data = json.loads(self.params) except Exception as ex: # pylint: disable=broad-except @@ -272,7 +272,7 @@ class Slice( # pylint: disable=too-many-public-methods { "slice_id": self.id, "viz_type": self.viz_type, - "datasource": "{}__{}".format(self.datasource_id, self.datasource_type), + "datasource": f"{self.datasource_id}__{self.datasource_type}", } ) @@ -281,7 +281,7 @@ class Slice( # pylint: disable=too-many-public-methods update_time_range(form_data) return form_data - def get_query_context(self) -> Optional[QueryContext]: + def get_query_context(self) -> QueryContext | None: if self.query_context: try: return self.get_query_context_factory().create( @@ -295,13 +295,13 @@ class Slice( # pylint: disable=too-many-public-methods def get_explore_url( self, base_url: str = "/explore", - overrides: Optional[Dict[str, Any]] = None, + overrides: dict[str, Any] | None = None, ) -> str: return self.build_explore_url(self.id, base_url, overrides) @staticmethod def build_explore_url( - id_: int, base_url: str = "/explore", overrides: Optional[Dict[str, Any]] = None + id_: int, base_url: str = "/explore", overrides: dict[str, Any] | None = None ) -> str: overrides = overrides or {} form_data = {"slice_id": id_} diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index b2f0c8c1ed..b9ab153798 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -15,11 +15,13 @@ # specific language governing permissions and limitations # under the License. """A collection of ORM sqlalchemy models for SQL Lab""" +import builtins import inspect import logging import re +from collections.abc import Hashable from datetime import datetime -from typing import Any, Dict, Hashable, List, Optional, Type, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import simplejson as json import sqlalchemy as sqla @@ -131,7 +133,7 @@ class Query( def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor: return get_template_processor(query=self, database=self.database, **kwargs) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "changedOn": self.changed_on, "changed_on": self.changed_on.isoformat(), @@ -181,11 +183,11 @@ class Query( return self.user.username @property - def sql_tables(self) -> List[Table]: + def sql_tables(self) -> list[Table]: return list(ParsedQuery(self.sql).tables) @property - def columns(self) -> List["TableColumn"]: + def columns(self) -> list["TableColumn"]: from superset.connectors.sqla.models import ( # pylint: disable=import-outside-toplevel TableColumn, ) @@ -204,11 +206,11 @@ class Query( return columns @property - def db_extra(self) -> Optional[Dict[str, Any]]: + def db_extra(self) -> Optional[dict[str, Any]]: return None @property - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: order_by_choices = [] for col in self.columns: column_name = str(col.column_name or "") @@ -247,11 +249,13 @@ class Query( security_manager.raise_for_access(query=self) @property - def db_engine_spec(self) -> Type["BaseEngineSpec"]: + def db_engine_spec( + self, + ) -> builtins.type["BaseEngineSpec"]: # pylint: disable=unsubscriptable-object return self.database.db_engine_spec @property - def owners_data(self) -> List[Dict[str, Any]]: + def owners_data(self) -> list[dict[str, Any]]: return [] @property @@ -267,7 +271,7 @@ class Query( return 0 @property - def column_names(self) -> List[Any]: + def column_names(self) -> list[Any]: return [col.column_name for col in self.columns] @property @@ -282,7 +286,7 @@ class Query( return None @property - def dttm_cols(self) -> List[Any]: + def dttm_cols(self) -> list[Any]: return [col.column_name for col in self.columns if col.is_dttm] @property @@ -298,7 +302,7 @@ class Query( return "" @staticmethod - def get_extra_cache_keys(query_obj: Dict[str, Any]) -> List[Hashable]: + def get_extra_cache_keys(query_obj: dict[str, Any]) -> list[Hashable]: return [] @property @@ -322,7 +326,7 @@ class Query( def tracking_url(self, value: str) -> None: self.tracking_url_raw = value - def get_column(self, column_name: Optional[str]) -> Optional[Dict[str, Any]]: + def get_column(self, column_name: Optional[str]) -> Optional[dict[str, Any]]: if not column_name: return None for col in self.columns: @@ -397,7 +401,7 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): def __repr__(self) -> str: return str(self.label) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, } @@ -421,10 +425,10 @@ class SavedQuery(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): return self.database.sqlalchemy_uri def url(self) -> str: - return "/superset/sqllab?savedQueryId={0}".format(self.id) + return f"/superset/sqllab?savedQueryId={self.id}" @property - def sql_tables(self) -> List[Table]: + def sql_tables(self) -> list[Table]: return list(ParsedQuery(self.sql).tables) @property @@ -483,7 +487,7 @@ class TabState(Model, AuditMixinNullable, ExtraJSONMixin): ) saved_query = relationship("SavedQuery", foreign_keys=[saved_query_id]) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: return { "id": self.id, "user_id": self.user_id, @@ -520,7 +524,7 @@ class TableSchema(Model, AuditMixinNullable, ExtraJSONMixin): expanded = Column(Boolean, default=False) - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: try: description = json.loads(self.description) except json.JSONDecodeError: diff --git a/superset/models/sql_types/presto_sql_types.py b/superset/models/sql_types/presto_sql_types.py index c496f75039..234581dfb4 100644 --- a/superset/models/sql_types/presto_sql_types.py +++ b/superset/models/sql_types/presto_sql_types.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=abstract-method, no-init -from typing import Any, Dict, List, Optional, Type +from typing import Any, Optional from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.sql.sqltypes import DATE, Integer, TIMESTAMP @@ -33,7 +33,7 @@ class TinyInteger(Integer): """ @property - def python_type(self) -> Type[int]: + def python_type(self) -> type[int]: return int @classmethod @@ -47,7 +47,7 @@ class Interval(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @classmethod @@ -61,7 +61,7 @@ class Array(TypeEngine): """ @property - def python_type(self) -> Optional[Type[List[Any]]]: + def python_type(self) -> Optional[type[list[Any]]]: return list @classmethod @@ -75,7 +75,7 @@ class Map(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Dict[Any, Any]]]: + def python_type(self) -> Optional[type[dict[Any, Any]]]: return dict @classmethod @@ -89,7 +89,7 @@ class Row(TypeEngine): """ @property - def python_type(self) -> Optional[Type[Any]]: + def python_type(self) -> Optional[type[Any]]: return None @classmethod diff --git a/superset/queries/dao.py b/superset/queries/dao.py index 642a5dd4cb..e9fe15cac5 100644 --- a/superset/queries/dao.py +++ b/superset/queries/dao.py @@ -16,7 +16,7 @@ # under the License. import logging from datetime import datetime -from typing import Any, Dict, List, Union +from typing import Any, Union from superset import sql_lab from superset.common.db_query_status import QueryStatus @@ -56,14 +56,14 @@ class QueryDAO(BaseDAO): db.session.commit() @staticmethod - def save_metadata(query: Query, payload: Dict[str, Any]) -> None: + def save_metadata(query: Query, payload: dict[str, Any]) -> None: # pull relevant data from payload and store in extra_json columns = payload.get("columns", {}) db.session.add(query) query.set_extra_json_key("columns", columns) @staticmethod - def get_queries_changed_after(last_updated_ms: Union[float, int]) -> List[Query]: + def get_queries_changed_after(last_updated_ms: Union[float, int]) -> list[Query]: # UTC date time, same that is stored in the DB. last_updated_dt = datetime.utcfromtimestamp(last_updated_ms / 1000) diff --git a/superset/queries/saved_queries/commands/bulk_delete.py b/superset/queries/saved_queries/commands/bulk_delete.py index c96afd31e5..fb230180c8 100644 --- a/superset/queries/saved_queries/commands/bulk_delete.py +++ b/superset/queries/saved_queries/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -30,9 +30,9 @@ logger = logging.getLogger(__name__) class BulkDeleteSavedQueryCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[Dashboard]] = None + self._models: Optional[list[Dashboard]] = None def run(self) -> None: self.validate() diff --git a/superset/queries/saved_queries/commands/export.py b/superset/queries/saved_queries/commands/export.py index 8c5357159e..323256306a 100644 --- a/superset/queries/saved_queries/commands/export.py +++ b/superset/queries/saved_queries/commands/export.py @@ -18,7 +18,7 @@ import json import logging -from typing import Iterator, Tuple +from collections.abc import Iterator import yaml from werkzeug.utils import secure_filename @@ -39,7 +39,7 @@ class ExportSavedQueriesCommand(ExportModelsCommand): @staticmethod def _export( model: SavedQuery, export_related: bool = True - ) -> Iterator[Tuple[str, str]]: + ) -> Iterator[tuple[str, str]]: # build filename based on database, optional schema, and label database_slug = secure_filename(model.database.database_name) schema_slug = secure_filename(model.schema) diff --git a/superset/queries/saved_queries/commands/importers/dispatcher.py b/superset/queries/saved_queries/commands/importers/dispatcher.py index 8283202225..c2208f0e2a 100644 --- a/superset/queries/saved_queries/commands/importers/dispatcher.py +++ b/superset/queries/saved_queries/commands/importers/dispatcher.py @@ -16,7 +16,7 @@ # under the License. import logging -from typing import Any, Dict +from typing import Any from marshmallow.exceptions import ValidationError @@ -40,7 +40,7 @@ class ImportSavedQueriesCommand(BaseCommand): until it finds one that matches. """ - def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self.contents = contents self.args = args self.kwargs = kwargs diff --git a/superset/queries/saved_queries/commands/importers/v1/__init__.py b/superset/queries/saved_queries/commands/importers/v1/__init__.py index 1412dbd356..79ec04f54b 100644 --- a/superset/queries/saved_queries/commands/importers/v1/__init__.py +++ b/superset/queries/saved_queries/commands/importers/v1/__init__.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Set +from typing import Any from marshmallow import Schema from sqlalchemy.orm import Session @@ -38,7 +38,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand): dao = SavedQueryDAO model_name = "saved_queries" prefix = "queries/" - schemas: Dict[str, Schema] = { + schemas: dict[str, Schema] = { "databases/": ImportV1DatabaseSchema(), "queries/": ImportV1SavedQuerySchema(), } @@ -46,16 +46,16 @@ class ImportSavedQueriesCommand(ImportModelsCommand): @staticmethod def _import( - session: Session, configs: Dict[str, Any], overwrite: bool = False + session: Session, configs: dict[str, Any], overwrite: bool = False ) -> None: # discover databases associated with saved queries - database_uuids: Set[str] = set() + database_uuids: set[str] = set() for file_name, config in configs.items(): if file_name.startswith("queries/"): database_uuids.add(config["database_uuid"]) # import related databases - database_ids: Dict[str, int] = {} + database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: database = import_database(session, config, overwrite=False) diff --git a/superset/queries/saved_queries/commands/importers/v1/utils.py b/superset/queries/saved_queries/commands/importers/v1/utils.py index f2d090bf11..813f3c2295 100644 --- a/superset/queries/saved_queries/commands/importers/v1/utils.py +++ b/superset/queries/saved_queries/commands/importers/v1/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -23,7 +23,7 @@ from superset.models.sql_lab import SavedQuery def import_saved_query( - session: Session, config: Dict[str, Any], overwrite: bool = False + session: Session, config: dict[str, Any], overwrite: bool = False ) -> SavedQuery: existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first() if existing: diff --git a/superset/queries/saved_queries/dao.py b/superset/queries/saved_queries/dao.py index c6bcfa035c..daae1de8f5 100644 --- a/superset/queries/saved_queries/dao.py +++ b/superset/queries/saved_queries/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from sqlalchemy.exc import SQLAlchemyError @@ -33,7 +33,7 @@ class SavedQueryDAO(BaseDAO): base_filter = SavedQueryFilter @staticmethod - def bulk_delete(models: Optional[List[SavedQuery]], commit: bool = True) -> None: + def bulk_delete(models: Optional[list[SavedQuery]], commit: bool = True) -> None: item_ids = [model.id for model in models] if models else [] try: db.session.query(SavedQuery).filter(SavedQuery.id.in_(item_ids)).delete( diff --git a/superset/queries/schemas.py b/superset/queries/schemas.py index b139784c5b..850664e92f 100644 --- a/superset/queries/schemas.py +++ b/superset/queries/schemas.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List from marshmallow import fields, Schema @@ -73,7 +72,7 @@ class QuerySchema(Schema): include_relationships = True # pylint: disable=no-self-use - def get_sql_tables(self, obj: Query) -> List[Table]: + def get_sql_tables(self, obj: Query) -> list[Table]: return obj.sql_tables diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py index c5b4709447..41163dc064 100644 --- a/superset/reports/commands/alert.py +++ b/superset/reports/commands/alert.py @@ -20,7 +20,7 @@ import json import logging from operator import eq, ge, gt, le, lt, ne from timeit import default_timer -from typing import Any, Optional +from typing import Any import numpy as np import pandas as pd @@ -54,7 +54,7 @@ OPERATOR_FUNCTIONS = {">=": ge, ">": gt, "<=": le, "<": lt, "==": eq, "!=": ne} class AlertCommand(BaseCommand): def __init__(self, report_schedule: ReportSchedule): self._report_schedule = report_schedule - self._result: Optional[float] = None + self._result: float | None = None def run(self) -> bool: """ diff --git a/superset/reports/commands/base.py b/superset/reports/commands/base.py index 4fee6a8824..598370576b 100644 --- a/superset/reports/commands/base.py +++ b/superset/reports/commands/base.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List +from typing import Any from marshmallow import ValidationError @@ -36,7 +36,7 @@ logger = logging.getLogger(__name__) class BaseReportScheduleCommand(BaseCommand): - _properties: Dict[str, Any] + _properties: dict[str, Any] def run(self) -> Any: pass @@ -45,7 +45,7 @@ class BaseReportScheduleCommand(BaseCommand): pass def validate_chart_dashboard( - self, exceptions: List[ValidationError], update: bool = False + self, exceptions: list[ValidationError], update: bool = False ) -> None: """Validate chart or dashboard relation""" chart_id = self._properties.get("chart") diff --git a/superset/reports/commands/bulk_delete.py b/superset/reports/commands/bulk_delete.py index 28a39a2fb6..7d6e1ed791 100644 --- a/superset/reports/commands/bulk_delete.py +++ b/superset/reports/commands/bulk_delete.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from superset import security_manager from superset.commands.base import BaseCommand @@ -33,9 +33,9 @@ logger = logging.getLogger(__name__) class BulkDeleteReportScheduleCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: Optional[List[ReportSchedule]] = None + self._models: Optional[list[ReportSchedule]] = None def run(self) -> None: self.validate() diff --git a/superset/reports/commands/create.py b/superset/reports/commands/create.py index 27626170d6..04cf6ef43f 100644 --- a/superset/reports/commands/create.py +++ b/superset/reports/commands/create.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from marshmallow import ValidationError @@ -46,7 +46,7 @@ logger = logging.getLogger(__name__) class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() def run(self) -> ReportSchedule: @@ -59,8 +59,8 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): return report_schedule def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") name = self._properties.get("name", "") report_type = self._properties.get("type") creation_method = self._properties.get("creation_method") @@ -119,7 +119,7 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): if exceptions: raise ReportScheduleInvalidError(exceptions=exceptions) - def _validate_report_extra(self, exceptions: List[ValidationError]) -> None: + def _validate_report_extra(self, exceptions: list[ValidationError]) -> None: extra: Optional[ReportScheduleExtra] = self._properties.get("extra") dashboard = self._properties.get("dashboard") diff --git a/superset/reports/commands/exceptions.py b/superset/reports/commands/exceptions.py index 22aff0727d..cba12e0786 100644 --- a/superset/reports/commands/exceptions.py +++ b/superset/reports/commands/exceptions.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List from flask_babel import lazy_gettext as _ @@ -263,13 +262,13 @@ class ReportScheduleStateNotFoundError(CommandException): class ReportScheduleSystemErrorsException(CommandException, SupersetErrorsException): - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] message = _("Report schedule system error") class ReportScheduleClientErrorsException(CommandException, SupersetErrorsException): status = 400 - errors: List[SupersetError] = [] + errors: list[SupersetError] = [] message = _("Report schedule client error") diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index f5f7bf4130..608b2564a2 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime, timedelta -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union from uuid import UUID import pandas as pd @@ -80,7 +80,7 @@ logger = logging.getLogger(__name__) class BaseReportState: - current_states: List[ReportState] = [] + current_states: list[ReportState] = [] initial: bool = False def __init__( @@ -195,7 +195,7 @@ class BaseReportState: **kwargs, ) - def _get_screenshots(self) -> List[bytes]: + def _get_screenshots(self) -> list[bytes]: """ Get chart or dashboard screenshots :raises: ReportScheduleScreenshotFailedError @@ -394,14 +394,14 @@ class BaseReportState: def _send( self, notification_content: NotificationContent, - recipients: List[ReportRecipients], + recipients: list[ReportRecipients], ) -> None: """ Sends a notification to all recipients :raises: CommandException """ - notification_errors: List[SupersetError] = [] + notification_errors: list[SupersetError] = [] for recipient in recipients: notification = create_notification(recipient, notification_content) try: diff --git a/superset/reports/commands/update.py b/superset/reports/commands/update.py index 0c4f18f1b8..5ca3ac849a 100644 --- a/superset/reports/commands/update.py +++ b/superset/reports/commands/update.py @@ -16,7 +16,7 @@ # under the License. import json import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError @@ -42,7 +42,7 @@ logger = logging.getLogger(__name__) class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._model: Optional[ReportSchedule] = None @@ -57,8 +57,8 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): return report_schedule def validate(self) -> None: - exceptions: List[ValidationError] = [] - owner_ids: Optional[List[int]] = self._properties.get("owners") + exceptions: list[ValidationError] = [] + owner_ids: Optional[list[int]] = self._properties.get("owners") report_type = self._properties.get("type", ReportScheduleType.ALERT) name = self._properties.get("name", "") diff --git a/superset/reports/dao.py b/superset/reports/dao.py index be5ee8053c..64777e959a 100644 --- a/superset/reports/dao.py +++ b/superset/reports/dao.py @@ -17,7 +17,7 @@ import json import logging from datetime import datetime -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_appbuilder import Model from sqlalchemy.exc import SQLAlchemyError @@ -47,7 +47,7 @@ class ReportScheduleDAO(BaseDAO): base_filter = ReportScheduleFilter @staticmethod - def find_by_chart_id(chart_id: int) -> List[ReportSchedule]: + def find_by_chart_id(chart_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.chart_id == chart_id) @@ -55,7 +55,7 @@ class ReportScheduleDAO(BaseDAO): ) @staticmethod - def find_by_chart_ids(chart_ids: List[int]) -> List[ReportSchedule]: + def find_by_chart_ids(chart_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.chart_id.in_(chart_ids)) @@ -63,7 +63,7 @@ class ReportScheduleDAO(BaseDAO): ) @staticmethod - def find_by_dashboard_id(dashboard_id: int) -> List[ReportSchedule]: + def find_by_dashboard_id(dashboard_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.dashboard_id == dashboard_id) @@ -71,7 +71,7 @@ class ReportScheduleDAO(BaseDAO): ) @staticmethod - def find_by_dashboard_ids(dashboard_ids: List[int]) -> List[ReportSchedule]: + def find_by_dashboard_ids(dashboard_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.dashboard_id.in_(dashboard_ids)) @@ -79,7 +79,7 @@ class ReportScheduleDAO(BaseDAO): ) @staticmethod - def find_by_database_id(database_id: int) -> List[ReportSchedule]: + def find_by_database_id(database_id: int) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.database_id == database_id) @@ -87,7 +87,7 @@ class ReportScheduleDAO(BaseDAO): ) @staticmethod - def find_by_database_ids(database_ids: List[int]) -> List[ReportSchedule]: + def find_by_database_ids(database_ids: list[int]) -> list[ReportSchedule]: return ( db.session.query(ReportSchedule) .filter(ReportSchedule.database_id.in_(database_ids)) @@ -96,7 +96,7 @@ class ReportScheduleDAO(BaseDAO): @staticmethod def bulk_delete( - models: Optional[List[ReportSchedule]], commit: bool = True + models: Optional[list[ReportSchedule]], commit: bool = True ) -> None: item_ids = [model.id for model in models] if models else [] try: @@ -156,7 +156,7 @@ class ReportScheduleDAO(BaseDAO): return found_id is None or found_id == expect_id @classmethod - def create(cls, properties: Dict[str, Any], commit: bool = True) -> ReportSchedule: + def create(cls, properties: dict[str, Any], commit: bool = True) -> ReportSchedule: """ create a report schedule and nested recipients :raises: DAOCreateFailedError @@ -187,7 +187,7 @@ class ReportScheduleDAO(BaseDAO): @classmethod def update( - cls, model: Model, properties: Dict[str, Any], commit: bool = True + cls, model: Model, properties: dict[str, Any], commit: bool = True ) -> ReportSchedule: """ create a report schedule and nested recipients @@ -219,7 +219,7 @@ class ReportScheduleDAO(BaseDAO): raise DAOCreateFailedError(str(ex)) from ex @staticmethod - def find_active(session: Optional[Session] = None) -> List[ReportSchedule]: + def find_active(session: Optional[Session] = None) -> list[ReportSchedule]: """ Find all active reports. If session is passed it will be used instead of the default `db.session`, this is useful when on a celery worker session context diff --git a/superset/reports/filters.py b/superset/reports/filters.py index 5fb87e0563..a03238b640 100644 --- a/superset/reports/filters.py +++ b/superset/reports/filters.py @@ -52,6 +52,6 @@ class ReportScheduleAllTextFilter(BaseFilter): # pylint: disable=too-few-public or_( ReportSchedule.name.ilike(ilike_value), ReportSchedule.description.ilike(ilike_value), - ReportSchedule.sql.ilike((ilike_value)), + ReportSchedule.sql.ilike(ilike_value), ) ) diff --git a/superset/reports/logs/api.py b/superset/reports/logs/api.py index 8ad8455cc7..f0c272caee 100644 --- a/superset/reports/logs/api.py +++ b/superset/reports/logs/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from flask import Response from flask_appbuilder.api import expose, permission_name, protect, rison, safe @@ -83,7 +83,7 @@ class ReportExecutionLogRestApi(BaseSupersetModelRestApi): @staticmethod def _apply_layered_relation_to_rison( # pylint: disable=invalid-name - layer_id: int, rison_parameters: Dict[str, Any] + layer_id: int, rison_parameters: dict[str, Any] ) -> None: if "filters" not in rison_parameters: rison_parameters["filters"] = [] diff --git a/superset/reports/notifications/__init__.py b/superset/reports/notifications/__init__.py index c466f59abd..f2ac40bb46 100644 --- a/superset/reports/notifications/__init__.py +++ b/superset/reports/notifications/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/superset/reports/notifications/base.py b/superset/reports/notifications/base.py index 6eb2405d0f..640b326fc5 100644 --- a/superset/reports/notifications/base.py +++ b/superset/reports/notifications/base.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from dataclasses import dataclass -from typing import Any, List, Optional, Type +from typing import Any, Optional import pandas as pd @@ -29,7 +28,7 @@ class NotificationContent: name: str header_data: HeaderDataType # this is optional to account for error states csv: Optional[bytes] = None # bytes for csv file - screenshots: Optional[List[bytes]] = None # bytes for a list of screenshots + screenshots: Optional[list[bytes]] = None # bytes for a list of screenshots text: Optional[str] = None description: Optional[str] = "" url: Optional[str] = None # url to chart/dashboard for this screenshot @@ -44,7 +43,7 @@ class BaseNotification: # pylint: disable=too-few-public-methods notification type """ - plugins: List[Type["BaseNotification"]] = [] + plugins: list[type["BaseNotification"]] = [] type: Optional[ReportRecipientType] = None """ Child classes set their notification type ex: `type = "email"` this string will be diff --git a/superset/reports/notifications/email.py b/superset/reports/notifications/email.py index 10a76e7573..1b9e4ade72 100644 --- a/superset/reports/notifications/email.py +++ b/superset/reports/notifications/email.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -20,7 +19,7 @@ import logging import textwrap from dataclasses import dataclass from email.utils import make_msgid, parseaddr -from typing import Any, Dict, Optional +from typing import Any, Optional import nh3 from flask_babel import gettext as __ @@ -69,8 +68,8 @@ ALLOWED_ATTRIBUTES = { class EmailContent: body: str header_data: Optional[HeaderDataType] = None - data: Optional[Dict[str, Any]] = None - images: Optional[Dict[str, bytes]] = None + data: Optional[dict[str, Any]] = None + images: Optional[dict[str, bytes]] = None class EmailNotification(BaseNotification): # pylint: disable=too-few-public-methods diff --git a/superset/reports/notifications/slack.py b/superset/reports/notifications/slack.py index b89a700ef9..4c3f2ee419 100644 --- a/superset/reports/notifications/slack.py +++ b/superset/reports/notifications/slack.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,8 +16,9 @@ # under the License. import json import logging +from collections.abc import Sequence from io import IOBase -from typing import Sequence, Union +from typing import Union import backoff from flask_babel import gettext as __ diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py index a45ee4cc38..83dea02f8f 100644 --- a/superset/reports/schemas.py +++ b/superset/reports/schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Union +from typing import Any, Union from croniter import croniter from flask_babel import gettext as _ @@ -212,7 +212,7 @@ class ReportSchedulePostSchema(Schema): @validates_schema def validate_report_references( # pylint: disable=unused-argument,no-self-use - self, data: Dict[str, Any], **kwargs: Any + self, data: dict[str, Any], **kwargs: Any ) -> None: if data["type"] == ReportScheduleType.REPORT: if "database" in data: diff --git a/superset/result_set.py b/superset/result_set.py index 9aa06bba09..f707b91dce 100644 --- a/superset/result_set.py +++ b/superset/result_set.py @@ -19,7 +19,7 @@ import datetime import json import logging -from typing import Any, Dict, List, Optional, Tuple, Type +from typing import Any, Optional import numpy as np import pandas as pd @@ -33,7 +33,7 @@ from superset.utils import core as utils logger = logging.getLogger(__name__) -def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List[str]: +def dedup(l: list[str], suffix: str = "__", case_sensitive: bool = True) -> list[str]: """De-duplicates a list of string by suffixing a counter Always returns the same number of entries as provided, and always returns @@ -46,8 +46,8 @@ def dedup(l: List[str], suffix: str = "__", case_sensitive: bool = True) -> List ) foo,bar,bar__1,bar__2,Bar__3 """ - new_l: List[str] = [] - seen: Dict[str, int] = {} + new_l: list[str] = [] + seen: dict[str, int] = {} for item in l: s_fixed_case = item if case_sensitive else item.lower() if s_fixed_case in seen: @@ -104,14 +104,14 @@ class SupersetResultSet: self, data: DbapiResult, cursor_description: DbapiDescription, - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], ): self.db_engine_spec = db_engine_spec data = data or [] - column_names: List[str] = [] - pa_data: List[pa.Array] = [] - deduped_cursor_desc: List[Tuple[Any, ...]] = [] - numpy_dtype: List[Tuple[str, ...]] = [] + column_names: list[str] = [] + pa_data: list[pa.Array] = [] + deduped_cursor_desc: list[tuple[Any, ...]] = [] + numpy_dtype: list[tuple[str, ...]] = [] stringified_arr: NDArray[Any] if cursor_description: @@ -181,7 +181,7 @@ class SupersetResultSet: column_names = [] self.table = pa.Table.from_arrays(pa_data, names=column_names) - self._type_dict: Dict[str, Any] = {} + self._type_dict: dict[str, Any] = {} try: # The driver may not be passing a cursor.description self._type_dict = { @@ -245,7 +245,7 @@ class SupersetResultSet: return self.table.num_rows @property - def columns(self) -> List[ResultSetColumnType]: + def columns(self) -> list[ResultSetColumnType]: if not self.table.column_names: return [] diff --git a/superset/row_level_security/commands/bulk_delete.py b/superset/row_level_security/commands/bulk_delete.py index a6d4625a91..a3703346cc 100644 --- a/superset/row_level_security/commands/bulk_delete.py +++ b/superset/row_level_security/commands/bulk_delete.py @@ -16,7 +16,6 @@ # under the License. import logging -from typing import List from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -31,9 +30,9 @@ logger = logging.getLogger(__name__) class BulkDeleteRLSRuleCommand(BaseCommand): - def __init__(self, model_ids: List[int]): + def __init__(self, model_ids: list[int]): self._model_ids = model_ids - self._models: List[ReportSchedule] = [] + self._models: list[ReportSchedule] = [] def run(self) -> None: self.validate() diff --git a/superset/row_level_security/commands/create.py b/superset/row_level_security/commands/create.py index 0c348e10c0..5552feeda0 100644 --- a/superset/row_level_security/commands/create.py +++ b/superset/row_level_security/commands/create.py @@ -17,7 +17,7 @@ import logging -from typing import Any, Dict +from typing import Any from superset.commands.base import BaseCommand from superset.commands.exceptions import DatasourceNotFoundValidationError @@ -31,7 +31,7 @@ logger = logging.getLogger(__name__) class CreateRLSRuleCommand(BaseCommand): - def __init__(self, data: Dict[str, Any]): + def __init__(self, data: dict[str, Any]): self._properties = data.copy() self._tables = self._properties.get("tables", []) self._roles = self._properties.get("roles", []) diff --git a/superset/row_level_security/commands/update.py b/superset/row_level_security/commands/update.py index 8c276ee2c4..a206fc3a39 100644 --- a/superset/row_level_security/commands/update.py +++ b/superset/row_level_security/commands/update.py @@ -17,7 +17,7 @@ import logging -from typing import Any, Dict, Optional +from typing import Any, Optional from superset.commands.base import BaseCommand from superset.commands.exceptions import DatasourceNotFoundValidationError @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) class UpdateRLSRuleCommand(BaseCommand): - def __init__(self, model_id: int, data: Dict[str, Any]): + def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._properties = data.copy() self._tables = self._properties.get("tables", []) diff --git a/superset/security/api.py b/superset/security/api.py index 7aac6ae22b..aff536519d 100644 --- a/superset/security/api.py +++ b/superset/security/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from flask import request, Response from flask_appbuilder import expose @@ -56,8 +56,8 @@ class ResourceSchema(PermissiveSchema): @post_load def convert_enum_to_value( # pylint: disable=no-self-use - self, data: Dict[str, Any], **kwargs: Any # pylint: disable=unused-argument - ) -> Dict[str, Any]: + self, data: dict[str, Any], **kwargs: Any # pylint: disable=unused-argument + ) -> dict[str, Any]: # we don't care about the enum, we want the value inside data["type"] = data["type"].value return data diff --git a/superset/security/guest_token.py b/superset/security/guest_token.py index 44b59c1dbb..a8dc2e3393 100644 --- a/superset/security/guest_token.py +++ b/superset/security/guest_token.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from enum import Enum -from typing import List, Optional, TypedDict, Union +from typing import Optional, TypedDict, Union from flask_appbuilder.security.sqla.models import Role from flask_login import AnonymousUserMixin @@ -36,7 +36,7 @@ class GuestTokenResource(TypedDict): id: Union[str, int] -GuestTokenResources = List[GuestTokenResource] +GuestTokenResources = list[GuestTokenResource] class GuestTokenRlsRule(TypedDict): @@ -49,7 +49,7 @@ class GuestToken(TypedDict): exp: float user: GuestTokenUser resources: GuestTokenResources - rls_rules: List[GuestTokenRlsRule] + rls_rules: list[GuestTokenRlsRule] class GuestUser(AnonymousUserMixin): @@ -76,7 +76,7 @@ class GuestUser(AnonymousUserMixin): """ return False - def __init__(self, token: GuestToken, roles: List[Role]): + def __init__(self, token: GuestToken, roles: list[Role]): user = token["user"] self.guest_token = token self.username = user.get("username", "guest_user") diff --git a/superset/security/manager.py b/superset/security/manager.py index db6e631d91..94a731a3ff 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -20,18 +20,7 @@ import logging import re import time from collections import defaultdict -from typing import ( - Any, - Callable, - cast, - Dict, - List, - NamedTuple, - Optional, - Set, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING, Union from flask import current_app, Flask, g, Request from flask_appbuilder import Model @@ -479,7 +468,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) def get_table_access_error_msg( # pylint: disable=no-self-use - self, tables: Set["Table"] + self, tables: set["Table"] ) -> str: """ Return the error message for the denied SQL tables. @@ -492,7 +481,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return f"""You need access to the following tables: {", ".join(quoted_tables)}, `all_database_access` or `all_datasource_access` permission""" - def get_table_access_error_object(self, tables: Set["Table"]) -> SupersetError: + def get_table_access_error_object(self, tables: set["Table"]) -> SupersetError: """ Return the error object for the denied SQL tables. @@ -510,7 +499,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) def get_table_access_link( # pylint: disable=unused-argument,no-self-use - self, tables: Set["Table"] + self, tables: set["Table"] ) -> Optional[str]: """ Return the access link for the denied SQL tables. @@ -521,7 +510,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return current_app.config.get("PERMISSION_INSTRUCTIONS_LINK") - def get_user_datasources(self) -> List["BaseDatasource"]: + def get_user_datasources(self) -> list["BaseDatasource"]: """ Collect datasources which the user has explicit permissions to. @@ -542,7 +531,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # group all datasources by database session = self.get_session all_datasources = SqlaTable.get_all_datasources(session) - datasources_by_database: Dict["Database", Set["SqlaTable"]] = defaultdict(set) + datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set) for datasource in all_datasources: datasources_by_database[datasource.database].add(datasource) @@ -569,7 +558,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return True - def user_view_menu_names(self, permission_name: str) -> Set[str]: + def user_view_menu_names(self, permission_name: str) -> set[str]: base_query = ( self.get_session.query(self.viewmenu_model.name) .join(self.permissionview_model) @@ -599,7 +588,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods return {s.name for s in view_menu_names} return set() - def get_accessible_databases(self) -> List[int]: + def get_accessible_databases(self) -> list[int]: """ Return the list of databases accessible by the user. @@ -613,8 +602,8 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ] def get_schemas_accessible_by_user( - self, database: "Database", schemas: List[str], hierarchical: bool = True - ) -> List[str]: + self, database: "Database", schemas: list[str], hierarchical: bool = True + ) -> list[str]: """ Return the list of SQL schemas accessible by the user. @@ -654,9 +643,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods def get_datasources_accessible_by_user( # pylint: disable=invalid-name self, database: "Database", - datasource_names: List[DatasourceName], + datasource_names: list[DatasourceName], schema: Optional[str] = None, - ) -> List[DatasourceName]: + ) -> list[DatasourceName]: """ Return the list of SQL tables accessible by the user. @@ -802,7 +791,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self.get_session.commit() self.clean_perms() - def _get_pvms_from_builtin_role(self, role_name: str) -> List[PermissionView]: + def _get_pvms_from_builtin_role(self, role_name: str) -> list[PermissionView]: """ Gets a list of model PermissionView permissions inferred from a builtin role definition @@ -821,7 +810,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods role_from_permissions.append(pvm) return role_from_permissions - def find_roles_by_id(self, role_ids: List[int]) -> List[Role]: + def find_roles_by_id(self, role_ids: list[int]) -> list[Role]: """ Find a List of models by a list of ids, if defined applies `base_filter` """ @@ -1179,7 +1168,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods connection: Connection, old_database_name: str, target: "Database", - ) -> List[ViewMenu]: + ) -> list[ViewMenu]: """ Helper method that Updates all datasource access permission when a database name changes. @@ -1205,7 +1194,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods .filter(SqlaTable.database_id == target.id) .all() ) - updated_view_menus: List[ViewMenu] = [] + updated_view_menus: list[ViewMenu] = [] for dataset in datasets: old_dataset_vm_name = self.get_dataset_perm( dataset.id, dataset.table_name, old_database_name @@ -1768,7 +1757,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods """ @staticmethod - def get_exclude_users_from_lists() -> List[str]: + def get_exclude_users_from_lists() -> list[str]: """ Override to dynamically identify a list of usernames to exclude from all UI dropdown lists, owners, created_by filters etc... @@ -1896,7 +1885,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods def get_anonymous_user(self) -> User: # pylint: disable=no-self-use return AnonymousUserMixin() - def get_user_roles(self, user: Optional[User] = None) -> List[Role]: + def get_user_roles(self, user: Optional[User] = None) -> list[Role]: if not user: user = g.user if user.is_anonymous: @@ -1906,7 +1895,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods def get_guest_rls_filters( self, dataset: "BaseDatasource" - ) -> List[GuestTokenRlsRule]: + ) -> list[GuestTokenRlsRule]: """ Retrieves the row level security filters for the current user and the dataset, if the user is authenticated with a guest token. @@ -1922,7 +1911,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ] return [] - def get_rls_filters(self, table: "BaseDatasource") -> List[SqlaQuery]: + def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]: """ Retrieves the appropriate row level security filters for the current user and the passed table. @@ -1990,7 +1979,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) return query.all() - def get_rls_ids(self, table: "BaseDatasource") -> List[int]: + def get_rls_ids(self, table: "BaseDatasource") -> list[int]: """ Retrieves the appropriate row level security filters IDs for the current user and the passed table. @@ -2002,10 +1991,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ids.sort() # Combinations rather than permutations return ids - def get_guest_rls_filters_str(self, table: "BaseDatasource") -> List[str]: + def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]: return [f.get("clause", "") for f in self.get_guest_rls_filters(table)] - def get_rls_cache_key(self, datasource: "BaseDatasource") -> List[str]: + def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]: rls_ids = [] if datasource.is_rls_supported: rls_ids = self.get_rls_ids(datasource) @@ -2122,7 +2111,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods self, user: GuestTokenUser, resources: GuestTokenResources, - rls: List[GuestTokenRlsRule], + rls: list[GuestTokenRlsRule], ) -> bytes: secret = current_app.config["GUEST_TOKEN_JWT_SECRET"] algo = current_app.config["GUEST_TOKEN_JWT_ALGO"] @@ -2183,7 +2172,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods roles=[self.find_role(current_app.config["GUEST_ROLE_NAME"])], ) - def parse_jwt_guest_token(self, raw_token: str) -> Dict[str, Any]: + def parse_jwt_guest_token(self, raw_token: str) -> dict[str, Any]: """ Parses a guest token. Raises an error if the jwt fails standard claims checks. :param raw_token: the token gotten from the request diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 9ea881fadf..678da79fa7 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -20,7 +20,7 @@ import uuid from contextlib import closing from datetime import datetime from sys import getsizeof -from typing import Any, cast, Dict, List, Optional, Tuple, Union +from typing import Any, cast, Optional, Union import backoff import msgpack @@ -88,9 +88,9 @@ def handle_query_error( ex: Exception, query: Query, session: Session, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, prefix_message: str = "", -) -> Dict[str, Any]: +) -> dict[str, Any]: """Local method handling error while processing the SQL""" payload = payload or {} msg = f"{prefix_message} {str(ex)}".strip() @@ -122,7 +122,7 @@ def handle_query_error( return payload -def get_query_backoff_handler(details: Dict[Any, Any]) -> None: +def get_query_backoff_handler(details: dict[Any, Any]) -> None: query_id = details["kwargs"]["query_id"] logger.error( "Query with id `%s` could not be retrieved", str(query_id), exc_info=True @@ -168,8 +168,8 @@ def get_sql_results( # pylint: disable=too-many-arguments username: Optional[str] = None, start_time: Optional[float] = None, expand_data: bool = False, - log_params: Optional[Dict[str, Any]] = None, -) -> Optional[Dict[str, Any]]: + log_params: Optional[dict[str, Any]] = None, +) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" with session_scope(not ctask.request.called_directly) as session: with override_user(security_manager.find_user(username)): @@ -196,7 +196,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-statem query: Query, session: Session, cursor: Any, - log_params: Optional[Dict[str, Any]], + log_params: Optional[dict[str, Any]], apply_ctas: bool = False, ) -> SupersetResultSet: """Executes a single SQL statement""" @@ -332,7 +332,7 @@ def apply_limit_if_exists( def _serialize_payload( - payload: Dict[Any, Any], use_msgpack: Optional[bool] = False + payload: dict[Any, Any], use_msgpack: Optional[bool] = False ) -> Union[bytes, str]: logger.debug("Serializing to msgpack: %r", use_msgpack) if use_msgpack: @@ -346,10 +346,10 @@ def _serialize_and_expand_data( db_engine_spec: BaseEngineSpec, use_msgpack: Optional[bool] = False, expand_data: bool = False, -) -> Tuple[Union[bytes, str], List[Any], List[Any], List[Any]]: +) -> tuple[Union[bytes, str], list[Any], list[Any], list[Any]]: selected_columns = result_set.columns - all_columns: List[Any] - expanded_columns: List[Any] + all_columns: list[Any] + expanded_columns: list[Any] if use_msgpack: with stats_timing( @@ -383,15 +383,15 @@ def execute_sql_statements( session: Session, start_time: Optional[float], expand_data: bool, - log_params: Optional[Dict[str, Any]], -) -> Optional[Dict[str, Any]]: + log_params: Optional[dict[str, Any]], +) -> Optional[dict[str, Any]]: """Executes the sql query returns the results.""" if store_results and start_time: # only asynchronous queries stats_logger.timing("sqllab.query.time_pending", now_as_float() - start_time) query = get_query(query_id, session) - payload: Dict[str, Any] = dict(query_id=query_id) + payload: dict[str, Any] = dict(query_id=query_id) database = query.database db_engine_spec = database.db_engine_spec db_engine_spec.patch() diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 034daeb7af..974d7eacd4 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -16,9 +16,10 @@ # under the License. import logging import re +from collections.abc import Iterator from dataclasses import dataclass from enum import Enum -from typing import Any, cast, Iterator, List, Optional, Set, Tuple +from typing import Any, cast, Optional from urllib import parse import sqlparse @@ -97,7 +98,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]: def extract_top_from_query( - statement: TokenList, top_keywords: Set[str] + statement: TokenList, top_keywords: set[str] ) -> Optional[int]: """ Extract top clause value from SQL statement. @@ -122,7 +123,7 @@ def extract_top_from_query( return top -def get_cte_remainder_query(sql: str) -> Tuple[Optional[str], str]: +def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]: """ parse the SQL and return the CTE and rest of the block to the caller @@ -192,8 +193,8 @@ class ParsedQuery: sql_statement = sqlparse.format(sql_statement, strip_comments=True) self.sql: str = sql_statement - self._tables: Set[Table] = set() - self._alias_names: Set[str] = set() + self._tables: set[Table] = set() + self._alias_names: set[str] = set() self._limit: Optional[int] = None logger.debug("Parsing with sqlparse statement: %s", self.sql) @@ -202,7 +203,7 @@ class ParsedQuery: self._limit = _extract_limit_from_query(statement) @property - def tables(self) -> Set[Table]: + def tables(self) -> set[Table]: if not self._tables: for statement in self._parsed: self._extract_from_token(statement) @@ -282,7 +283,7 @@ class ParsedQuery: def strip_comments(self) -> str: return sqlparse.format(self.stripped(), strip_comments=True) - def get_statements(self) -> List[str]: + def get_statements(self) -> list[str]: """Returns a list of SQL statements as strings, stripped""" statements = [] for statement in self._parsed: @@ -737,7 +738,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}") def extract_table_references( sql_text: str, sqla_dialect: str, show_warning: bool = True -) -> Set["Table"]: +) -> set["Table"]: """ Return all the dependencies from a SQL sql_text. """ diff --git a/superset/sql_validators/__init__.py b/superset/sql_validators/__init__.py index c448f696a1..ad048a86a5 100644 --- a/superset/sql_validators/__init__.py +++ b/superset/sql_validators/__init__.py @@ -14,13 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, Type +from typing import Optional from . import base, postgres, presto_db from .base import SQLValidationAnnotation -def get_validator_by_name(name: str) -> Optional[Type[base.BaseSQLValidator]]: +def get_validator_by_name(name: str) -> Optional[type[base.BaseSQLValidator]]: return { "PrestoDBSQLValidator": presto_db.PrestoDBSQLValidator, "PostgreSQLValidator": postgres.PostgreSQLValidator, diff --git a/superset/sql_validators/base.py b/superset/sql_validators/base.py index de29a96e8e..8344fc9264 100644 --- a/superset/sql_validators/base.py +++ b/superset/sql_validators/base.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset.models.core import Database @@ -34,7 +34,7 @@ class SQLValidationAnnotation: # pylint: disable=too-few-public-methods self.start_column = start_column self.end_column = end_column - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """Return a dictionary representation of this annotation""" return { "line_number": self.line_number, @@ -53,6 +53,6 @@ class BaseSQLValidator: # pylint: disable=too-few-public-methods @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: + ) -> list[SQLValidationAnnotation]: """Check that the given SQL querystring is valid for the given engine""" raise NotImplementedError diff --git a/superset/sql_validators/postgres.py b/superset/sql_validators/postgres.py index f62be39f03..60c15ca034 100644 --- a/superset/sql_validators/postgres.py +++ b/superset/sql_validators/postgres.py @@ -16,7 +16,7 @@ # under the License. import re -from typing import List, Optional +from typing import Optional from pgsanity.pgsanity import check_string @@ -32,8 +32,8 @@ class PostgreSQLValidator(BaseSQLValidator): # pylint: disable=too-few-public-m @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: - annotations: List[SQLValidationAnnotation] = [] + ) -> list[SQLValidationAnnotation]: + annotations: list[SQLValidationAnnotation] = [] valid, error = check_string(sql, add_semicolon=True) if valid: return annotations diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index c5ecf4c96e..9d3e7641a6 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -18,7 +18,7 @@ import logging import time from contextlib import closing -from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset import app, security_manager from superset.models.core import Database @@ -109,7 +109,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): raise PrestoSQLValidationError( "The pyhive presto client returned an unhandled " "database error." ) from db_error - error_args: Dict[str, Any] = db_error.args[0] + error_args: dict[str, Any] = db_error.args[0] # Confirm the two fields we need to be able to present an annotation # are present in the error response -- a message, and a location. @@ -148,7 +148,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): @classmethod def validate( cls, sql: str, schema: Optional[str], database: Database - ) -> List[SQLValidationAnnotation]: + ) -> list[SQLValidationAnnotation]: """ Presto supports query-validation queries by running them with a prepended explain. @@ -167,7 +167,7 @@ class PrestoDBSQLValidator(BaseSQLValidator): ) as engine: # Sharing a single connection and cursor across the # execution of all statements (if many) - annotations: List[SQLValidationAnnotation] = [] + annotations: list[SQLValidationAnnotation] = [] with closing(engine.raw_connection()) as conn: cursor = conn.cursor() for statement in parsed_query.get_statements(): diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index 3c24bf1c26..35d110d8fc 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast, Optional from urllib import parse import simplejson as json @@ -326,7 +326,7 @@ class SqlLabRestApi(BaseSupersetApi): @staticmethod def _create_sql_json_command( - execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]] + execution_context: SqlJsonExecutionContext, log_params: Optional[dict[str, Any]] ) -> ExecuteSqlCommand: query_dao = QueryDAO() sql_json_executor = SqlLabRestApi._create_sql_json_executor( diff --git a/superset/sqllab/commands/estimate.py b/superset/sqllab/commands/estimate.py index 2b8c5814b9..bf1d6c4fa5 100644 --- a/superset/sqllab/commands/estimate.py +++ b/superset/sqllab/commands/estimate.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, List +from typing import Any from flask_babel import gettext as __ @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class QueryEstimationCommand(BaseCommand): _database_id: int _sql: str - _template_params: Dict[str, Any] + _template_params: dict[str, Any] _schema: str _database: Database @@ -64,7 +64,7 @@ class QueryEstimationCommand(BaseCommand): def run( self, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: self.validate() sql = self._sql @@ -96,7 +96,7 @@ class QueryEstimationCommand(BaseCommand): ) from ex spec = self._database.db_engine_spec - query_cost_formatters: Dict[str, Any] = app.config[ + query_cost_formatters: dict[str, Any] = app.config[ "QUERY_COST_FORMATTERS_BY_ENGINE" ] query_cost_formatter = query_cost_formatters.get( diff --git a/superset/sqllab/commands/execute.py b/superset/sqllab/commands/execute.py index 97c8514d5d..09b0769ce2 100644 --- a/superset/sqllab/commands/execute.py +++ b/superset/sqllab/commands/execute.py @@ -19,7 +19,7 @@ from __future__ import annotations import copy import logging -from typing import Any, Dict, Optional, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask_babel import gettext as __ @@ -51,7 +51,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -CommandResult = Dict[str, Any] +CommandResult = dict[str, Any] class ExecuteSqlCommand(BaseCommand): @@ -63,7 +63,7 @@ class ExecuteSqlCommand(BaseCommand): _sql_json_executor: SqlJsonExecutor _execution_context_convertor: ExecutionContextConvertor _sqllab_ctas_no_limit: bool - _log_params: Optional[Dict[str, Any]] = None + _log_params: dict[str, Any] | None = None def __init__( self, @@ -75,7 +75,7 @@ class ExecuteSqlCommand(BaseCommand): sql_json_executor: SqlJsonExecutor, execution_context_convertor: ExecutionContextConvertor, sqllab_ctas_no_limit_flag: bool, - log_params: Optional[Dict[str, Any]] = None, + log_params: dict[str, Any] | None = None, ) -> None: self._execution_context = execution_context self._query_dao = query_dao @@ -122,7 +122,7 @@ class ExecuteSqlCommand(BaseCommand): except Exception as ex: raise SqlLabException(self._execution_context, exception=ex) from ex - def _try_get_existing_query(self) -> Optional[Query]: + def _try_get_existing_query(self) -> Query | None: return self._query_dao.find_one_or_none( client_id=self._execution_context.client_id, user_id=self._execution_context.user_id, @@ -130,7 +130,7 @@ class ExecuteSqlCommand(BaseCommand): ) @classmethod - def is_query_handled(cls, query: Optional[Query]) -> bool: + def is_query_handled(cls, query: Query | None) -> bool: return query is not None and query.status in [ QueryStatus.RUNNING, QueryStatus.PENDING, @@ -166,7 +166,7 @@ class ExecuteSqlCommand(BaseCommand): return mydb @classmethod - def _validate_query_db(cls, database: Optional[Database]) -> None: + def _validate_query_db(cls, database: Database | None) -> None: if not database: raise SupersetGenericErrorException( __( diff --git a/superset/sqllab/commands/export.py b/superset/sqllab/commands/export.py index e9559be3b9..1b9b0e0344 100644 --- a/superset/sqllab/commands/export.py +++ b/superset/sqllab/commands/export.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, cast, List, TypedDict +from typing import Any, cast, TypedDict import pandas as pd from flask_babel import gettext as __ @@ -40,7 +40,7 @@ logger = logging.getLogger(__name__) class SqlExportResult(TypedDict): query: Query count: int - data: List[Any] + data: list[Any] class SqlResultExportCommand(BaseCommand): diff --git a/superset/sqllab/commands/results.py b/superset/sqllab/commands/results.py index d6c415a09f..83c8aa8f6a 100644 --- a/superset/sqllab/commands/results.py +++ b/superset/sqllab/commands/results.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, cast, Dict, Optional +from typing import Any, cast from flask_babel import gettext as __ @@ -40,14 +40,14 @@ logger = logging.getLogger(__name__) class SqlExecutionResultsCommand(BaseCommand): _key: str - _rows: Optional[int] + _rows: int | None _blob: Any _query: Query def __init__( self, key: str, - rows: Optional[int] = None, + rows: int | None = None, ) -> None: self._key = key self._rows = rows @@ -100,7 +100,7 @@ class SqlExecutionResultsCommand(BaseCommand): def run( self, - ) -> Dict[str, Any]: + ) -> dict[str, Any]: """Runs arbitrary sql and returns data as json""" self.validate() payload = utils.zlib_decompress( diff --git a/superset/sqllab/exceptions.py b/superset/sqllab/exceptions.py index 70e4fa9752..f06cc8dd2e 100644 --- a/superset/sqllab/exceptions.py +++ b/superset/sqllab/exceptions.py @@ -17,7 +17,7 @@ from __future__ import annotations import os -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask_babel import lazy_gettext as _ @@ -31,15 +31,15 @@ if TYPE_CHECKING: class SqlLabException(SupersetException): sql_json_execution_context: SqlJsonExecutionContext failed_reason_msg: str - suggestion_help_msg: Optional[str] + suggestion_help_msg: str | None def __init__( # pylint: disable=too-many-arguments self, sql_json_execution_context: SqlJsonExecutionContext, - error_type: Optional[SupersetErrorType] = None, - reason_message: Optional[str] = None, - exception: Optional[Exception] = None, - suggestion_help_msg: Optional[str] = None, + error_type: SupersetErrorType | None = None, + reason_message: str | None = None, + exception: Exception | None = None, + suggestion_help_msg: str | None = None, ) -> None: self.sql_json_execution_context = sql_json_execution_context self.failed_reason_msg = self._get_reason(reason_message, exception) @@ -68,21 +68,21 @@ class SqlLabException(SupersetException): if self.failed_reason_msg: msg = msg + self.failed_reason_msg if self.suggestion_help_msg is not None: - msg = "{} {} {}".format(msg, os.linesep, self.suggestion_help_msg) + msg = f"{msg} {os.linesep} {self.suggestion_help_msg}" return msg @classmethod def _get_reason( - cls, reason_message: Optional[str] = None, exception: Optional[Exception] = None + cls, reason_message: str | None = None, exception: Exception | None = None ) -> str: if reason_message is not None: - return ": {}".format(reason_message) + return f": {reason_message}" if exception is not None: if hasattr(exception, "get_message"): - return ": {}".format(exception.get_message()) + return f": {exception.get_message()}" if hasattr(exception, "message"): - return ": {}".format(exception.message) - return ": {}".format(str(exception)) + return f": {exception.message}" + return f": {str(exception)}" return "" @@ -93,7 +93,7 @@ class QueryIsForbiddenToAccessException(SqlLabException): def __init__( self, sql_json_execution_context: SqlJsonExecutionContext, - exception: Optional[Exception] = None, + exception: Exception | None = None, ) -> None: super().__init__( sql_json_execution_context, diff --git a/superset/sqllab/execution_context_convertor.py b/superset/sqllab/execution_context_convertor.py index f49fbd9a31..430db0d52f 100644 --- a/superset/sqllab/execution_context_convertor.py +++ b/superset/sqllab/execution_context_convertor.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import simplejson as json @@ -36,7 +36,7 @@ if TYPE_CHECKING: class ExecutionContextConvertor: _max_row_in_display_configuration: int # pylint: disable=invalid-name _exc_status: SqlJsonExecutionStatus - payload: Dict[str, Any] + payload: dict[str, Any] def set_max_row_in_display(self, value: int) -> None: self._max_row_in_display_configuration = value # pylint: disable=invalid-name diff --git a/superset/sqllab/query_render.py b/superset/sqllab/query_render.py index 1369e78db1..db1adf43ba 100644 --- a/superset/sqllab/query_render.py +++ b/superset/sqllab/query_render.py @@ -17,7 +17,7 @@ # pylint: disable=invalid-name, no-self-use, too-few-public-methods, too-many-arguments from __future__ import annotations -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import gettext as __, ngettext from jinja2 import TemplateError @@ -124,16 +124,16 @@ class SqlQueryRenderImpl(SqlQueryRender): class SqlQueryRenderException(SqlLabException): - _extra: Optional[Dict[str, Any]] + _extra: dict[str, Any] | None def __init__( self, sql_json_execution_context: SqlJsonExecutionContext, error_type: SupersetErrorType, - reason_message: Optional[str] = None, - exception: Optional[Exception] = None, - suggestion_help_msg: Optional[str] = None, - extra: Optional[Dict[str, Any]] = None, + reason_message: str | None = None, + exception: Exception | None = None, + suggestion_help_msg: str | None = None, + extra: dict[str, Any] | None = None, ) -> None: super().__init__( sql_json_execution_context, @@ -145,10 +145,10 @@ class SqlQueryRenderException(SqlLabException): self._extra = extra @property - def extra(self) -> Optional[Dict[str, Any]]: + def extra(self) -> dict[str, Any] | None: return self._extra - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: rv = super().to_dict() if self._extra: rv["extra"] = self._extra diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py index e4e6b60654..124f477e96 100644 --- a/superset/sqllab/sql_json_executer.py +++ b/superset/sqllab/sql_json_executer.py @@ -20,7 +20,7 @@ from __future__ import annotations import dataclasses import logging from abc import ABC -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask_babel import gettext as __ @@ -43,7 +43,7 @@ if TYPE_CHECKING: QueryStatus = utils.QueryStatus logger = logging.getLogger(__name__) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] GetSqlResultsTask = Callable[..., SqlResults] @@ -53,7 +53,7 @@ class SqlJsonExecutor: self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: raise NotImplementedError() @@ -88,7 +88,7 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase): self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: query_id = execution_context.query.id try: @@ -120,8 +120,8 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase): self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], - ) -> Optional[SqlResults]: + log_params: dict[str, Any] | None, + ) -> SqlResults | None: with utils.timeout( seconds=self._timeout_duration_in_seconds, error_message=self._get_timeout_error_msg(), @@ -132,8 +132,8 @@ class SynchronousSqlJsonExecutor(SqlJsonExecutorBase): self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], - ) -> Optional[SqlResults]: + log_params: dict[str, Any] | None, + ) -> SqlResults | None: return self._get_sql_results_task( execution_context.query.id, rendered_query, @@ -161,7 +161,7 @@ class ASynchronousSqlJsonExecutor(SqlJsonExecutorBase): self, execution_context: SqlJsonExecutionContext, rendered_query: str, - log_params: Optional[Dict[str, Any]], + log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: query_id = execution_context.query.id logger.info("Query %i: Running query on a Celery worker", query_id) diff --git a/superset/sqllab/sqllab_execution_context.py b/superset/sqllab/sqllab_execution_context.py index 644c978b32..22277804ee 100644 --- a/superset/sqllab/sqllab_execution_context.py +++ b/superset/sqllab/sqllab_execution_context.py @@ -19,7 +19,7 @@ from __future__ import annotations import json import logging from dataclasses import dataclass -from typing import Any, cast, Dict, Optional, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from flask import g from sqlalchemy.orm.exc import DetachedInstanceError @@ -37,7 +37,7 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] @dataclass @@ -45,7 +45,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes database_id: int schema: str sql: str - template_params: Dict[str, Any] + template_params: dict[str, Any] async_flag: bool limit: int status: str @@ -53,14 +53,14 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes client_id_or_short_id: str sql_editor_id: str tab_name: str - user_id: Optional[int] + user_id: int | None expand_data: bool - create_table_as_select: Optional[CreateTableAsSelect] - database: Optional[Database] + create_table_as_select: CreateTableAsSelect | None + database: Database | None query: Query - _sql_result: Optional[SqlResults] + _sql_result: SqlResults | None - def __init__(self, query_params: Dict[str, Any]): + def __init__(self, query_params: dict[str, Any]): self.create_table_as_select = None self.database = None self._init_from_query_params(query_params) @@ -70,7 +70,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes def set_query(self, query: Query) -> None: self.query = query - def _init_from_query_params(self, query_params: Dict[str, Any]) -> None: + def _init_from_query_params(self, query_params: dict[str, Any]) -> None: self.database_id = cast(int, query_params.get("database_id")) self.schema = cast(str, query_params.get("schema")) self.sql = cast(str, query_params.get("sql")) @@ -90,7 +90,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes ) @staticmethod - def _get_template_params(query_params: Dict[str, Any]) -> Dict[str, Any]: + def _get_template_params(query_params: dict[str, Any]) -> dict[str, Any]: try: template_params = json.loads(query_params.get("templateParams") or "{}") except json.JSONDecodeError: @@ -102,7 +102,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes return template_params @staticmethod - def _get_limit_param(query_params: Dict[str, Any]) -> int: + def _get_limit_param(query_params: dict[str, Any]) -> int: limit = apply_max_row_limit(query_params.get("queryLimit") or 0) if limit < 0: logger.warning( @@ -125,7 +125,7 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes schema_name = self._get_ctas_target_schema_name(database) self.create_table_as_select.target_schema_name = schema_name # type: ignore - def _get_ctas_target_schema_name(self, database: Database) -> Optional[str]: + def _get_ctas_target_schema_name(self, database: Database) -> str | None: if database.force_ctas_schema: return database.force_ctas_schema return get_cta_schema_name(database, g.user, self.schema, self.sql) @@ -134,10 +134,10 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes # TODO validate db.id is equal to self.database_id pass - def get_execution_result(self) -> Optional[SqlResults]: + def get_execution_result(self) -> SqlResults | None: return self._sql_result - def set_execution_result(self, sql_result: Optional[SqlResults]) -> None: + def set_execution_result(self, sql_result: SqlResults | None) -> None: self._sql_result = sql_result def create_query(self) -> Query: @@ -178,15 +178,15 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes try: if hasattr(self, "query"): if self.query.id: - return "query '{}' - '{}'".format(self.query.id, self.query.sql) + return f"query '{self.query.id}' - '{self.query.sql}'" except DetachedInstanceError: pass - return "query '{}'".format(self.sql) + return f"query '{self.sql}'" class CreateTableAsSelect: # pylint: disable=too-few-public-methods ctas_method: CtasMethod - target_schema_name: Optional[str] + target_schema_name: str | None target_table_name: str def __init__( @@ -197,7 +197,7 @@ class CreateTableAsSelect: # pylint: disable=too-few-public-methods self.target_table_name = target_table_name @staticmethod - def create_from(query_params: Dict[str, Any]) -> CreateTableAsSelect: + def create_from(query_params: dict[str, Any]) -> CreateTableAsSelect: ctas_method = query_params.get("ctas_method", CtasMethod.TABLE) schema = cast(str, query_params.get("schema")) tmp_table_name = cast(str, query_params.get("tmp_table_name")) diff --git a/superset/sqllab/utils.py b/superset/sqllab/utils.py index 3bcd7308a1..abceaaf136 100644 --- a/superset/sqllab/utils.py +++ b/superset/sqllab/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any import pyarrow as pa @@ -22,8 +22,8 @@ from superset.common.db_query_status import QueryStatus def apply_display_max_row_configuration_if_require( # pylint: disable=invalid-name - sql_results: Dict[str, Any], max_rows_in_result: int -) -> Dict[str, Any]: + sql_results: dict[str, Any], max_rows_in_result: int +) -> dict[str, Any]: """ Given a `sql_results` nested structure, applies a limit to the number of rows diff --git a/superset/stats_logger.py b/superset/stats_logger.py index 4b869042a9..fc223f7529 100644 --- a/superset/stats_logger.py +++ b/superset/stats_logger.py @@ -54,22 +54,20 @@ class DummyStatsLogger(BaseStatsLogger): logger.debug(Fore.CYAN + "[stats_logger] (incr) " + key + Style.RESET_ALL) def decr(self, key: str) -> None: - logger.debug((Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL)) + logger.debug(Fore.CYAN + "[stats_logger] (decr) " + key + Style.RESET_ALL) def timing(self, key: str, value: float) -> None: logger.debug( - (Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL) + Fore.CYAN + f"[stats_logger] (timing) {key} | {value} " + Style.RESET_ALL ) def gauge(self, key: str, value: float) -> None: logger.debug( - ( - Fore.CYAN - + "[stats_logger] (gauge) " - + f"{key}" - + f"{value}" - + Style.RESET_ALL - ) + Fore.CYAN + + "[stats_logger] (gauge) " + + f"{key}" + + f"{value}" + + Style.RESET_ALL ) diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 8eaea54176..7c21df6a88 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -14,10 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from collections.abc import Sequence from datetime import datetime -from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Literal, Optional, TYPE_CHECKING, Union -from typing_extensions import Literal, TypedDict +from typing_extensions import TypedDict from werkzeug.wrappers import Response if TYPE_CHECKING: @@ -69,8 +70,8 @@ class ResultSetColumnType(TypedDict): is_dttm: bool -CacheConfig = Dict[str, Any] -DbapiDescriptionRow = Tuple[ +CacheConfig = dict[str, Any] +DbapiDescriptionRow = tuple[ Union[str, bytes], str, Optional[str], @@ -79,27 +80,27 @@ DbapiDescriptionRow = Tuple[ Optional[int], bool, ] -DbapiDescription = Union[List[DbapiDescriptionRow], Tuple[DbapiDescriptionRow, ...]] -DbapiResult = Sequence[Union[List[Any], Tuple[Any, ...]]] +DbapiDescription = Union[list[DbapiDescriptionRow], tuple[DbapiDescriptionRow, ...]] +DbapiResult = Sequence[Union[list[Any], tuple[Any, ...]]] FilterValue = Union[bool, datetime, float, int, str] -FilterValues = Union[FilterValue, List[FilterValue], Tuple[FilterValue]] -FormData = Dict[str, Any] -Granularity = Union[str, Dict[str, Union[str, float]]] +FilterValues = Union[FilterValue, list[FilterValue], tuple[FilterValue]] +FormData = dict[str, Any] +Granularity = Union[str, dict[str, Union[str, float]]] Column = Union[AdhocColumn, str] Metric = Union[AdhocMetric, str] -OrderBy = Tuple[Metric, bool] -QueryObjectDict = Dict[str, Any] -VizData = Optional[Union[List[Any], Dict[Any, Any]]] -VizPayload = Dict[str, Any] +OrderBy = tuple[Metric, bool] +QueryObjectDict = dict[str, Any] +VizData = Optional[Union[list[Any], dict[Any, Any]]] +VizPayload = dict[str, Any] # Flask response. Base = Union[bytes, str] Status = Union[int, str] -Headers = Dict[str, Any] +Headers = dict[str, Any] FlaskResponse = Union[ Response, Base, - Tuple[Base, Status], - Tuple[Base, Status, Headers], - Tuple[Response, Status], + tuple[Base, Status], + tuple[Base, Status, Headers], + tuple[Response, Status], ] diff --git a/superset/tables/models.py b/superset/tables/models.py index 9a0c07fdcf..a24035fb97 100644 --- a/superset/tables/models.py +++ b/superset/tables/models.py @@ -24,7 +24,8 @@ addition to a table, new models for columns, metrics, and datasets were also int These models are not fully implemented, and shouldn't be used yet. """ -from typing import Any, Dict, Iterable, List, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import sqlalchemy as sa from flask_appbuilder import Model @@ -87,7 +88,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # The relationship between datasets and columns is 1:n, but we use a # many-to-many association table to avoid adding two mutually exclusive # columns(dataset_id and table_id) to Column - columns: List[Column] = relationship( + columns: list[Column] = relationship( "Column", secondary=table_column_association_table, cascade="all, delete-orphan", @@ -96,7 +97,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): # is loaded. backref="tables", ) - datasets: List["Dataset"] # will be populated by Dataset.tables backref + datasets: list["Dataset"] # will be populated by Dataset.tables backref # We use ``sa.Text`` for these attributes because (1) in modern databases the # performance is the same as ``VARCHAR``[1] and (2) because some table names can be @@ -130,7 +131,7 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): existing_columns = {column.name: column for column in self.columns} quote_identifier = self.database.quote_identifier - def update_or_create_column(column_meta: Dict[str, Any]) -> Column: + def update_or_create_column(column_meta: dict[str, Any]) -> Column: column_name: str = column_meta["name"] if column_name in existing_columns: column = existing_columns[column_name] @@ -153,8 +154,8 @@ class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): table_names: Iterable[TableName], default_schema: Optional[str] = None, sync_columns: Optional[bool] = False, - default_props: Optional[Dict[str, Any]] = None, - ) -> List["Table"]: + default_props: Optional[dict[str, Any]] = None, + ) -> list["Table"]: """ Load or create multiple Table instances. """ diff --git a/superset/tags/commands/create.py b/superset/tags/commands/create.py index 1e886e2af6..20327b54f0 100644 --- a/superset/tags/commands/create.py +++ b/superset/tags/commands/create.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List from superset.commands.base import BaseCommand, CreateMixin from superset.dao.exceptions import DAOCreateFailedError @@ -28,7 +27,7 @@ logger = logging.getLogger(__name__) class CreateCustomTagCommand(CreateMixin, BaseCommand): - def __init__(self, object_type: ObjectTypes, object_id: int, tags: List[str]): + def __init__(self, object_type: ObjectTypes, object_id: int, tags: list[str]): self._object_type = object_type self._object_id = object_id self._tags = tags diff --git a/superset/tags/commands/delete.py b/superset/tags/commands/delete.py index acec016619..08189b5ac5 100644 --- a/superset/tags/commands/delete.py +++ b/superset/tags/commands/delete.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List from superset.commands.base import BaseCommand from superset.dao.exceptions import DAODeleteFailedError @@ -90,7 +89,7 @@ class DeleteTaggedObjectCommand(DeleteMixin, BaseCommand): class DeleteTagsCommand(DeleteMixin, BaseCommand): - def __init__(self, tags: List[str]): + def __init__(self, tags: list[str]): self._tags = tags def run(self) -> None: diff --git a/superset/tags/dao.py b/superset/tags/dao.py index c676b4ab3c..9ea61f5c90 100644 --- a/superset/tags/dao.py +++ b/superset/tags/dao.py @@ -16,7 +16,7 @@ # under the License. import logging from operator import and_ -from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError @@ -45,7 +45,7 @@ class TagDAO(BaseDAO): @staticmethod def create_custom_tagged_objects( - object_type: ObjectTypes, object_id: int, tag_names: List[str] + object_type: ObjectTypes, object_id: int, tag_names: list[str] ) -> None: tagged_objects = [] for name in tag_names: @@ -95,7 +95,7 @@ class TagDAO(BaseDAO): raise DAODeleteFailedError(exception=ex) from ex @staticmethod - def delete_tags(tag_names: List[str]) -> None: + def delete_tags(tag_names: list[str]) -> None: """ deletes tags from a list of tag names """ @@ -158,8 +158,8 @@ class TagDAO(BaseDAO): @staticmethod def get_tagged_objects_for_tags( - tags: Optional[List[str]] = None, obj_types: Optional[List[str]] = None - ) -> List[Dict[str, Any]]: + tags: Optional[list[str]] = None, obj_types: Optional[list[str]] = None + ) -> list[dict[str, Any]]: """ returns a list of tagged objects filtered by tag names and object types if no filters applied returns all tagged objects @@ -174,7 +174,7 @@ class TagDAO(BaseDAO): # filter types - results: List[Dict[str, Any]] = [] + results: list[dict[str, Any]] = [] # dashboards if (not obj_types) or ("dashboard" in obj_types): diff --git a/superset/tags/models.py b/superset/tags/models.py index 797308c306..bb845303ff 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -14,16 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import ( - absolute_import, - annotations, - division, - print_function, - unicode_literals, -) +from __future__ import annotations import enum -from typing import List, Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask_appbuilder import Model from sqlalchemy import Column, Enum, ForeignKey, Integer, String @@ -122,26 +116,26 @@ def get_object_type(class_name: str) -> ObjectTypes: try: return mapping[class_name.lower()] except KeyError as ex: - raise Exception("No mapping found for {0}".format(class_name)) from ex + raise Exception(f"No mapping found for {class_name}") from ex class ObjectUpdater: - object_type: Optional[str] = None + object_type: str | None = None @classmethod def get_owners_ids( - cls, target: Union[Dashboard, FavStar, Slice, Query, SqlaTable] - ) -> List[int]: + cls, target: Dashboard | FavStar | Slice | Query | SqlaTable + ) -> list[int]: raise NotImplementedError("Subclass should implement `get_owners_ids`") @classmethod def _add_owners( cls, session: Session, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): - name = "owner:{0}".format(owner_id) + name = f"owner:{owner_id}" tag = get_tag(name, session, TagTypes.owner) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type @@ -153,7 +147,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -161,7 +155,7 @@ class ObjectUpdater: cls._add_owners(session, target) # add `type:` tags - tag = get_tag("type:{0}".format(cls.object_type), session, TagTypes.type) + tag = get_tag(f"type:{cls.object_type}", session, TagTypes.type) tagged_object = TaggedObject( tag_id=tag.id, object_id=target.id, object_type=cls.object_type ) @@ -174,7 +168,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -203,7 +197,7 @@ class ObjectUpdater: cls, _mapper: Mapper, connection: Connection, - target: Union[Dashboard, FavStar, Slice, Query, SqlaTable], + target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: session = Session(bind=connection) @@ -220,7 +214,7 @@ class ChartUpdater(ObjectUpdater): object_type = "chart" @classmethod - def get_owners_ids(cls, target: Slice) -> List[int]: + def get_owners_ids(cls, target: Slice) -> list[int]: return [owner.id for owner in target.owners] @@ -228,7 +222,7 @@ class DashboardUpdater(ObjectUpdater): object_type = "dashboard" @classmethod - def get_owners_ids(cls, target: Dashboard) -> List[int]: + def get_owners_ids(cls, target: Dashboard) -> list[int]: return [owner.id for owner in target.owners] @@ -236,7 +230,7 @@ class QueryUpdater(ObjectUpdater): object_type = "query" @classmethod - def get_owners_ids(cls, target: Query) -> List[int]: + def get_owners_ids(cls, target: Query) -> list[int]: return [target.user_id] @@ -244,7 +238,7 @@ class DatasetUpdater(ObjectUpdater): object_type = "dataset" @classmethod - def get_owners_ids(cls, target: SqlaTable) -> List[int]: + def get_owners_ids(cls, target: SqlaTable) -> list[int]: return [owner.id for owner in target.owners] @@ -254,7 +248,7 @@ class FavStarUpdater: cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: session = Session(bind=connection) - name = "favorited_by:{0}".format(target.user_id) + name = f"favorited_by:{target.user_id}" tag = get_tag(name, session, TagTypes.favorited_by) tagged_object = TaggedObject( tag_id=tag.id, diff --git a/superset/tasks/__init__.py b/superset/tasks/__init__.py index fd9417fe5c..13a83393a9 100644 --- a/superset/tasks/__init__.py +++ b/superset/tasks/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/superset/tasks/async_queries.py b/superset/tasks/async_queries.py index ffd92c2627..cfcb3e31c6 100644 --- a/superset/tasks/async_queries.py +++ b/superset/tasks/async_queries.py @@ -18,7 +18,7 @@ from __future__ import annotations import copy import logging -from typing import Any, cast, Dict, Optional, TYPE_CHECKING +from typing import Any, cast, TYPE_CHECKING from celery.exceptions import SoftTimeLimitExceeded from flask import current_app, g @@ -45,12 +45,12 @@ query_timeout = current_app.config[ ] # TODO: new config key -def set_form_data(form_data: Dict[str, Any]) -> None: +def set_form_data(form_data: dict[str, Any]) -> None: # pylint: disable=assigning-non-slot g.form_data = form_data -def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: +def _create_query_context_from_form(form_data: dict[str, Any]) -> QueryContext: try: return ChartDataQueryContextSchema().load(form_data) except KeyError as ex: @@ -61,8 +61,8 @@ def _create_query_context_from_form(form_data: Dict[str, Any]) -> QueryContext: @celery_app.task(name="load_chart_data_into_cache", soft_time_limit=query_timeout) def load_chart_data_into_cache( - job_metadata: Dict[str, Any], - form_data: Dict[str, Any], + job_metadata: dict[str, Any], + form_data: dict[str, Any], ) -> None: # pylint: disable=import-outside-toplevel from superset.charts.data.commands.get_data_command import ChartDataCommand @@ -104,9 +104,9 @@ def load_chart_data_into_cache( @celery_app.task(name="load_explore_json_into_cache", soft_time_limit=query_timeout) def load_explore_json_into_cache( # pylint: disable=too-many-locals - job_metadata: Dict[str, Any], - form_data: Dict[str, Any], - response_type: Optional[str] = None, + job_metadata: dict[str, Any], + form_data: dict[str, Any], + response_type: str | None = None, force: bool = False, ) -> None: cache_key_prefix = "ejr-" # ejr: explore_json request diff --git a/superset/tasks/cache.py b/superset/tasks/cache.py index bdbf8add7e..448271269a 100644 --- a/superset/tasks/cache.py +++ b/superset/tasks/cache.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict, List, Optional, Union +from typing import Any, Optional, Union from urllib import request from urllib.error import URLError @@ -72,7 +72,7 @@ class Strategy: # pylint: disable=too-few-public-methods def __init__(self) -> None: pass - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: raise NotImplementedError("Subclasses must implement get_urls!") @@ -94,7 +94,7 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dummy" - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: session = db.create_scoped_session() charts = session.query(Slice).all() @@ -126,7 +126,7 @@ class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-method self.top_n = top_n self.since = parse_human_datetime(since) if since else None - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: urls = [] session = db.create_scoped_session() @@ -165,11 +165,11 @@ class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods name = "dashboard_tags" - def __init__(self, tags: Optional[List[str]] = None) -> None: + def __init__(self, tags: Optional[list[str]] = None) -> None: super().__init__() self.tags = tags or [] - def get_urls(self) -> List[str]: + def get_urls(self) -> list[str]: urls = [] session = db.create_scoped_session() @@ -216,7 +216,7 @@ strategies = [DummyStrategy, TopNDashboardsStrategy, DashboardTagsStrategy] @celery_app.task(name="fetch_url") -def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]: +def fetch_url(url: str, headers: dict[str, str]) -> dict[str, str]: """ Celery job to fetch url """ @@ -242,7 +242,7 @@ def fetch_url(url: str, headers: Dict[str, str]) -> Dict[str, str]: @celery_app.task(name="cache-warmup") def cache_warmup( strategy_name: str, *args: Any, **kwargs: Any -) -> Union[Dict[str, List[str]], str]: +) -> Union[dict[str, list[str]], str]: """ Warm up cache. @@ -272,7 +272,7 @@ def cache_warmup( cookies = MachineAuthProvider.get_auth_cookies(user) headers = {"Cookie": f"session={cookies.get('session', '')}"} - results: Dict[str, List[str]] = {"scheduled": [], "errors": []} + results: dict[str, list[str]] = {"scheduled": [], "errors": []} for url in strategy.get_urls(): try: logger.info("Scheduling %s", url) diff --git a/superset/tasks/cron_util.py b/superset/tasks/cron_util.py index 9c275addf6..19d342ebdc 100644 --- a/superset/tasks/cron_util.py +++ b/superset/tasks/cron_util.py @@ -16,8 +16,8 @@ # under the License. import logging +from collections.abc import Iterator from datetime import datetime, timedelta, timezone as dt_timezone -from typing import Iterator from croniter import croniter from pytz import timezone as pytz_timezone, UnknownTimeZoneError diff --git a/superset/tasks/utils.py b/superset/tasks/utils.py index 9c1dab8220..5012330bbd 100644 --- a/superset/tasks/utils.py +++ b/superset/tasks/utils.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import List, Optional, Tuple, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask import current_app, g @@ -32,10 +32,10 @@ if TYPE_CHECKING: # pylint: disable=too-many-branches def get_executor( - executor_types: List[ExecutorType], - model: Union[Dashboard, ReportSchedule, Slice], - current_user: Optional[str] = None, -) -> Tuple[ExecutorType, str]: + executor_types: list[ExecutorType], + model: Dashboard | ReportSchedule | Slice, + current_user: str | None = None, +) -> tuple[ExecutorType, str]: """ Extract the user that should be used to execute a scheduled task. Certain executor types extract the user from the underlying object (e.g. CREATOR), the constant @@ -86,7 +86,7 @@ def get_executor( raise ExecutorNotFoundError() -def get_current_user() -> Optional[str]: +def get_current_user() -> str | None: user = g.user if hasattr(g, "user") and g.user else None if user and not user.is_anonymous: return user.username diff --git a/superset/translations/utils.py b/superset/translations/utils.py index 79d01539a1..23eca1dd8c 100644 --- a/superset/translations/utils.py +++ b/superset/translations/utils.py @@ -16,15 +16,15 @@ # under the License. import json import os -from typing import Any, Dict, Optional +from typing import Any, Optional # Global caching for JSON language packs -ALL_LANGUAGE_PACKS: Dict[str, Dict[str, Any]] = {"en": {}} +ALL_LANGUAGE_PACKS: dict[str, dict[str, Any]] = {"en": {}} DIR = os.path.dirname(os.path.abspath(__file__)) -def get_language_pack(locale: str) -> Optional[Dict[str, Any]]: +def get_language_pack(locale: str) -> Optional[dict[str, Any]]: """Get/cache a language pack Returns the language pack from cache if it exists, caches otherwise @@ -34,7 +34,7 @@ def get_language_pack(locale: str) -> Optional[Dict[str, Any]]: """ pack = ALL_LANGUAGE_PACKS.get(locale) if not pack: - filename = DIR + "/{}/LC_MESSAGES/messages.json".format(locale) + filename = DIR + f"/{locale}/LC_MESSAGES/messages.json" try: with open(filename, encoding="utf8") as f: pack = json.load(f) diff --git a/superset/utils/async_query_manager.py b/superset/utils/async_query_manager.py index 71559aaa3d..1913fc1dec 100644 --- a/superset/utils/async_query_manager.py +++ b/superset/utils/async_query_manager.py @@ -17,7 +17,7 @@ import json import logging import uuid -from typing import Any, Dict, List, Literal, Optional, Tuple +from typing import Any, Literal, Optional import jwt import redis @@ -38,7 +38,7 @@ class AsyncQueryJobException(Exception): def build_job_metadata( channel_id: str, job_id: str, user_id: Optional[int], **kwargs: Any -) -> Dict[str, Any]: +) -> dict[str, Any]: return { "channel_id": channel_id, "job_id": job_id, @@ -49,7 +49,7 @@ def build_job_metadata( } -def parse_event(event_data: Tuple[str, Dict[str, Any]]) -> Dict[str, Any]: +def parse_event(event_data: tuple[str, dict[str, Any]]) -> dict[str, Any]: event_id = event_data[0] event_payload = event_data[1]["data"] return {"id": event_id, **json.loads(event_payload)} @@ -149,7 +149,7 @@ class AsyncQueryManager: return response - def parse_jwt_from_request(self, req: Request) -> Dict[str, Any]: + def parse_jwt_from_request(self, req: Request) -> dict[str, Any]: token = req.cookies.get(self._jwt_cookie_name) if not token: raise AsyncQueryTokenException("Token not preset") @@ -160,7 +160,7 @@ class AsyncQueryManager: logger.warning("Parse jwt failed", exc_info=True) raise AsyncQueryTokenException("Failed to parse token") from ex - def init_job(self, channel_id: str, user_id: Optional[int]) -> Dict[str, Any]: + def init_job(self, channel_id: str, user_id: Optional[int]) -> dict[str, Any]: job_id = str(uuid.uuid4()) return build_job_metadata( channel_id, job_id, user_id, status=self.STATUS_PENDING @@ -168,14 +168,14 @@ class AsyncQueryManager: def read_events( self, channel: str, last_id: Optional[str] - ) -> List[Optional[Dict[str, Any]]]: + ) -> list[Optional[dict[str, Any]]]: stream_name = f"{self._stream_prefix}{channel}" start_id = increment_id(last_id) if last_id else "-" results = self._redis.xrange(stream_name, start_id, "+", self.MAX_EVENT_COUNT) return [] if not results else list(map(parse_event, results)) def update_job( - self, job_metadata: Dict[str, Any], status: str, **kwargs: Any + self, job_metadata: dict[str, Any], status: str, **kwargs: Any ) -> None: if "channel_id" not in job_metadata: raise AsyncQueryJobException("No channel ID specified") diff --git a/superset/utils/cache.py b/superset/utils/cache.py index a632b04b37..693f3a73bc 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -20,7 +20,7 @@ import inspect import logging from datetime import datetime, timedelta from functools import wraps -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING from flask import current_app as app, request from flask_caching import Cache @@ -41,7 +41,7 @@ stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) -def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str: +def generate_cache_key(values_dict: dict[str, Any], key_prefix: str = "") -> str: hash_str = md5_sha_from_dict(values_dict, default=json_int_dttm_ser) return f"{key_prefix}{hash_str}" @@ -49,9 +49,9 @@ def generate_cache_key(values_dict: Dict[str, Any], key_prefix: str = "") -> str def set_and_log_cache( cache_instance: Cache, cache_key: str, - cache_value: Dict[str, Any], - cache_timeout: Optional[int] = None, - datasource_uid: Optional[str] = None, + cache_value: dict[str, Any], + cache_timeout: int | None = None, + datasource_uid: str | None = None, ) -> None: if isinstance(cache_instance.cache, NullCache): return @@ -91,11 +91,11 @@ logger = logging.getLogger(__name__) def view_cache_key(*args: Any, **kwargs: Any) -> str: # pylint: disable=unused-argument args_hash = hash(frozenset(request.args.items())) - return "view/{}/{}".format(request.path, args_hash) + return f"view/{request.path}/{args_hash}" def memoized_func( - key: Optional[str] = None, cache: Cache = cache_manager.cache + key: str | None = None, cache: Cache = cache_manager.cache ) -> Callable[..., Any]: """ Decorator with configurable key and cache backend. @@ -152,10 +152,10 @@ def memoized_func( def etag_cache( cache: Cache = cache_manager.cache, - get_last_modified: Optional[Callable[..., datetime]] = None, - max_age: Optional[Union[int, float]] = None, - raise_for_access: Optional[Callable[..., Any]] = None, - skip: Optional[Callable[..., bool]] = None, + get_last_modified: Callable[..., datetime] | None = None, + max_age: int | float | None = None, + raise_for_access: Callable[..., Any] | None = None, + skip: Callable[..., bool] | None = None, ) -> Callable[..., Any]: """ A decorator for caching views and handling etag conditional requests. diff --git a/superset/utils/celery.py b/superset/utils/celery.py index 474fc98d94..3577179145 100644 --- a/superset/utils/celery.py +++ b/superset/utils/celery.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import logging +from collections.abc import Iterator from contextlib import contextmanager -from typing import Iterator from sqlalchemy import create_engine from sqlalchemy.exc import SQLAlchemyError diff --git a/superset/utils/core.py b/superset/utils/core.py index c537abf459..24e539b2b6 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -35,6 +35,7 @@ import threading import traceback import uuid import zlib +from collections.abc import Iterable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass from datetime import date, datetime, time, timedelta @@ -47,24 +48,7 @@ from enum import Enum, IntEnum from io import BytesIO from timeit import default_timer from types import TracebackType -from typing import ( - Any, - Callable, - cast, - Dict, - Iterable, - Iterator, - List, - NamedTuple, - Optional, - Sequence, - Set, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, cast, NamedTuple, TYPE_CHECKING, TypeVar from urllib.parse import unquote_plus from zipfile import ZipFile @@ -197,11 +181,11 @@ class LoggerLevel(str, Enum): class HeaderDataType(TypedDict): notification_format: str - owners: List[int] + owners: list[int] notification_type: str - notification_source: Optional[str] - chart_id: Optional[int] - dashboard_id: Optional[int] + notification_source: str | None + chart_id: int | None + dashboard_id: int | None class DatasourceDict(TypedDict): @@ -212,20 +196,20 @@ class DatasourceDict(TypedDict): class AdhocFilterClause(TypedDict, total=False): clause: str expressionType: str - filterOptionName: Optional[str] - comparator: Optional[FilterValues] + filterOptionName: str | None + comparator: FilterValues | None operator: str subject: str - isExtra: Optional[bool] - sqlExpression: Optional[str] + isExtra: bool | None + sqlExpression: str | None class QueryObjectFilterClause(TypedDict, total=False): col: Column op: str # pylint: disable=invalid-name - val: Optional[FilterValues] - grain: Optional[str] - isExtra: Optional[bool] + val: FilterValues | None + grain: str | None + isExtra: bool | None class ExtraFiltersTimeColumnType(str, Enum): @@ -351,9 +335,9 @@ class ReservedUrlParameters(str, Enum): EDIT_MODE = "edit" @staticmethod - def is_standalone_mode() -> Optional[bool]: + def is_standalone_mode() -> bool | None: standalone_param = request.args.get(ReservedUrlParameters.STANDALONE.value) - standalone: Optional[bool] = bool( + standalone: bool | None = bool( standalone_param and standalone_param != "false" and standalone_param != "0" ) return standalone @@ -370,10 +354,10 @@ class ColumnTypeSource(Enum): class ColumnSpec(NamedTuple): - sqla_type: Union[TypeEngine, str] + sqla_type: TypeEngine | str generic_type: GenericDataType is_dttm: bool - python_date_format: Optional[str] = None + python_date_format: str | None = None try: @@ -407,8 +391,8 @@ def flasher(msg: str, severity: str = "message") -> None: def parse_js_uri_path_item( - item: Optional[str], unquote: bool = True, eval_undefined: bool = False -) -> Optional[str]: + item: str | None, unquote: bool = True, eval_undefined: bool = False +) -> str | None: """Parse a uri path item made with js. :param item: a uri path component @@ -421,7 +405,7 @@ def parse_js_uri_path_item( return unquote_plus(item) if unquote and item else item -def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float, int]]: +def cast_to_num(value: float | int | str | None) -> float | int | None: """Casts a value to an int/float >>> cast_to_num('1 ') @@ -457,7 +441,7 @@ def cast_to_num(value: Optional[Union[float, int, str]]) -> Optional[Union[float return None -def cast_to_boolean(value: Any) -> Optional[bool]: +def cast_to_boolean(value: Any) -> bool | None: """Casts a value to an int/float >>> cast_to_boolean(1) @@ -487,7 +471,7 @@ def cast_to_boolean(value: Any) -> Optional[bool]: return False -def list_minus(l: List[Any], minus: List[Any]) -> List[Any]: +def list_minus(l: list[Any], minus: list[Any]) -> list[Any]: """Returns l without what is in minus >>> list_minus([1, 2, 3], [2]) @@ -501,12 +485,12 @@ class DashboardEncoder(json.JSONEncoder): super().__init__(*args, **kwargs) self.sort_keys = True - def default(self, o: Any) -> Union[Dict[Any, Any], str]: + def default(self, o: Any) -> dict[Any, Any] | str: if isinstance(o, uuid.UUID): return str(o) try: vals = {k: v for k, v in o.__dict__.items() if k != "_sa_instance_state"} - return {"__{}__".format(o.__class__.__name__): vals} + return {f"__{o.__class__.__name__}__": vals} except Exception: # pylint: disable=broad-except if isinstance(o, datetime): return {"__datetime__": o.replace(microsecond=0).isoformat()} @@ -519,13 +503,13 @@ class JSONEncodedDict(TypeDecorator): # pylint: disable=abstract-method impl = TEXT def process_bind_param( - self, value: Optional[Dict[Any, Any]], dialect: str - ) -> Optional[str]: + self, value: dict[Any, Any] | None, dialect: str + ) -> str | None: return json.dumps(value) if value is not None else None def process_result_value( - self, value: Optional[str], dialect: str - ) -> Optional[Dict[Any, Any]]: + self, value: str | None, dialect: str + ) -> dict[Any, Any] | None: return json.loads(value) if value is not None else None @@ -634,7 +618,7 @@ def json_int_dttm_ser(obj: Any) -> Any: return base_json_conv(obj) -def json_dumps_w_dates(payload: Dict[Any, Any], sort_keys: bool = False) -> str: +def json_dumps_w_dates(payload: dict[Any, Any], sort_keys: bool = False) -> str: """Dumps payload to JSON with Datetime objects properly converted""" return json.dumps(payload, default=json_int_dttm_ser, sort_keys=sort_keys) @@ -662,7 +646,7 @@ def error_msg_from_exception(ex: Exception) -> str: return msg or str(ex) -def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str: +def markdown(raw: str, markup_wrap: bool | None = False) -> str: safe_markdown_tags = { "h1", "h2", @@ -709,15 +693,15 @@ def markdown(raw: str, markup_wrap: Optional[bool] = False) -> str: return safe -def readfile(file_path: str) -> Optional[str]: +def readfile(file_path: str) -> str | None: with open(file_path) as f: content = f.read() return content def generic_find_constraint_name( - table: str, columns: Set[str], referenced: str, database: SQLA -) -> Optional[str]: + table: str, columns: set[str], referenced: str, database: SQLA +) -> str | None: """Utility to find a constraint name in alembic migrations""" tbl = sa.Table( table, database.metadata, autoload=True, autoload_with=database.engine @@ -731,8 +715,8 @@ def generic_find_constraint_name( def generic_find_fk_constraint_name( - table: str, columns: Set[str], referenced: str, insp: Inspector -) -> Optional[str]: + table: str, columns: set[str], referenced: str, insp: Inspector +) -> str | None: """Utility to find a foreign-key constraint name in alembic migrations""" for fk in insp.get_foreign_keys(table): if ( @@ -745,8 +729,8 @@ def generic_find_fk_constraint_name( def generic_find_fk_constraint_names( # pylint: disable=invalid-name - table: str, columns: Set[str], referenced: str, insp: Inspector -) -> Set[str]: + table: str, columns: set[str], referenced: str, insp: Inspector +) -> set[str]: """Utility to find foreign-key constraint names in alembic migrations""" names = set() @@ -761,8 +745,8 @@ def generic_find_fk_constraint_names( # pylint: disable=invalid-name def generic_find_uq_constraint_name( - table: str, columns: Set[str], insp: Inspector -) -> Optional[str]: + table: str, columns: set[str], insp: Inspector +) -> str | None: """Utility to find a unique constraint name in alembic migrations""" for uq in insp.get_unique_constraints(table): @@ -773,14 +757,14 @@ def generic_find_uq_constraint_name( def get_datasource_full_name( - database_name: str, datasource_name: str, schema: Optional[str] = None + database_name: str, datasource_name: str, schema: str | None = None ) -> str: if not schema: - return "[{}].[{}]".format(database_name, datasource_name) - return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) + return f"[{database_name}].[{datasource_name}]" + return f"[{database_name}].[{schema}].[{datasource_name}]" -def validate_json(obj: Union[bytes, bytearray, str]) -> None: +def validate_json(obj: bytes | bytearray | str) -> None: if obj: try: json.loads(obj) @@ -851,7 +835,7 @@ class TimerTimeout: # Windows has no support for SIGALRM, so we use the timer based timeout -timeout: Union[Type[TimerTimeout], Type[SigalrmTimeout]] = ( +timeout: type[TimerTimeout] | type[SigalrmTimeout] = ( TimerTimeout if platform.system() == "Windows" else SigalrmTimeout ) @@ -897,9 +881,9 @@ def notify_user_about_perm_udate( # pylint: disable=too-many-arguments granter: User, user: User, role: Role, - datasource: "BaseDatasource", + datasource: BaseDatasource, tpl_name: str, - config: Dict[str, Any], + config: dict[str, Any], ) -> None: msg = render_template( tpl_name, granter=granter, user=user, role=role, datasource=datasource @@ -923,15 +907,15 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many to: str, subject: str, html_content: str, - config: Dict[str, Any], - files: Optional[List[str]] = None, - data: Optional[Dict[str, str]] = None, - images: Optional[Dict[str, bytes]] = None, + config: dict[str, Any], + files: list[str] | None = None, + data: dict[str, str] | None = None, + images: dict[str, bytes] | None = None, dryrun: bool = False, - cc: Optional[str] = None, - bcc: Optional[str] = None, + cc: str | None = None, + bcc: str | None = None, mime_subtype: str = "mixed", - header_data: Optional[HeaderDataType] = None, + header_data: HeaderDataType | None = None, ) -> None: """ Send an email with html content, eg: @@ -1000,9 +984,9 @@ def send_email_smtp( # pylint: disable=invalid-name,too-many-arguments,too-many def send_mime_email( e_from: str, - e_to: List[str], + e_to: list[str], mime_msg: MIMEMultipart, - config: Dict[str, Any], + config: dict[str, Any], dryrun: bool = False, ) -> None: smtp_host = config["SMTP_HOST"] @@ -1035,8 +1019,8 @@ def send_mime_email( smtp.quit() -def get_email_address_list(address_string: str) -> List[str]: - address_string_list: List[str] = [] +def get_email_address_list(address_string: str) -> list[str]: + address_string_list: list[str] = [] if isinstance(address_string, str): address_string_list = re.split(r",|\s|;", address_string) return [x.strip() for x in address_string_list if x.strip()] @@ -1049,12 +1033,12 @@ def get_email_address_str(address_string: str) -> str: return address_list_str -def choicify(values: Iterable[Any]) -> List[Tuple[Any, Any]]: +def choicify(values: Iterable[Any]) -> list[tuple[Any, Any]]: """Takes an iterable and makes an iterable of tuples with it""" return [(v, v) for v in values] -def zlib_compress(data: Union[bytes, str]) -> bytes: +def zlib_compress(data: bytes | str) -> bytes: """ Compress things in a py2/3 safe fashion >>> json_str = '{"test": 1}' @@ -1065,7 +1049,7 @@ def zlib_compress(data: Union[bytes, str]) -> bytes: return zlib.compress(data) -def zlib_decompress(blob: bytes, decode: Optional[bool] = True) -> Union[bytes, str]: +def zlib_decompress(blob: bytes, decode: bool | None = True) -> bytes | str: """ Decompress things to a string in a py2/3 safe fashion >>> json_str = '{"test": 1}' @@ -1094,12 +1078,12 @@ def simple_filter_to_adhoc( } if filter_clause.get("isExtra"): result["isExtra"] = True - result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result)) return result -def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterClause: +def form_data_to_adhoc(form_data: dict[str, Any], clause: str) -> AdhocFilterClause: if clause not in ("where", "having"): raise ValueError(__("Unsupported clause type: %(clause)s", clause=clause)) result: AdhocFilterClause = { @@ -1107,19 +1091,19 @@ def form_data_to_adhoc(form_data: Dict[str, Any], clause: str) -> AdhocFilterCla "expressionType": "SQL", "sqlExpression": form_data.get(clause), } - result["filterOptionName"] = md5_sha_from_dict(cast(Dict[Any, Any], result)) + result["filterOptionName"] = md5_sha_from_dict(cast(dict[Any, Any], result)) return result -def merge_extra_form_data(form_data: Dict[str, Any]) -> None: +def merge_extra_form_data(form_data: dict[str, Any]) -> None: """ Merge extra form data (appends and overrides) into the main payload and add applied time extras to the payload. """ filter_keys = ["filters", "adhoc_filters"] extra_form_data = form_data.pop("extra_form_data", {}) - append_filters: List[QueryObjectFilterClause] = extra_form_data.get("filters", None) + append_filters: list[QueryObjectFilterClause] = extra_form_data.get("filters", None) # merge append extras for key in [key for key in EXTRA_FORM_DATA_APPEND_KEYS if key not in filter_keys]: @@ -1144,9 +1128,9 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: if extras: form_data["extras"] = extras - adhoc_filters: List[AdhocFilterClause] = form_data.get("adhoc_filters", []) + adhoc_filters: list[AdhocFilterClause] = form_data.get("adhoc_filters", []) form_data["adhoc_filters"] = adhoc_filters - append_adhoc_filters: List[AdhocFilterClause] = extra_form_data.get( + append_adhoc_filters: list[AdhocFilterClause] = extra_form_data.get( "adhoc_filters", [] ) adhoc_filters.extend( @@ -1170,7 +1154,7 @@ def merge_extra_form_data(form_data: Dict[str, Any]) -> None: adhoc_filter["comparator"] = form_data["time_range"] -def merge_extra_filters(form_data: Dict[str, Any]) -> None: +def merge_extra_filters(form_data: dict[str, Any]) -> None: # extra_filters are temporary/contextual filters (using the legacy constructs) # that are external to the slice definition. We use those for dynamic # interactive filters like the ones emitted by the "Filter Box" visualization. @@ -1193,7 +1177,7 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None: # Grab list of existing filters 'keyed' on the column and operator - def get_filter_key(f: Dict[str, Any]) -> str: + def get_filter_key(f: dict[str, Any]) -> str: if "expressionType" in f: return "{}__{}".format(f["subject"], f["operator"]) @@ -1244,7 +1228,7 @@ def merge_extra_filters(form_data: Dict[str, Any]) -> None: del form_data["extra_filters"] -def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> None: +def merge_request_params(form_data: dict[str, Any], params: dict[str, Any]) -> None: """ Merge request parameters to the key `url_params` in form_data. Only updates or appends parameters to `form_data` that are defined in `params; pre-existing @@ -1261,7 +1245,7 @@ def merge_request_params(form_data: Dict[str, Any], params: Dict[str, Any]) -> N form_data["url_params"] = url_params -def user_label(user: User) -> Optional[str]: +def user_label(user: User) -> str | None: """Given a user ORM FAB object, returns a label""" if user: if user.first_name and user.last_name: @@ -1272,7 +1256,7 @@ def user_label(user: User) -> Optional[str]: return None -def get_example_default_schema() -> Optional[str]: +def get_example_default_schema() -> str | None: """ Return the default schema of the examples database, if any. """ @@ -1295,7 +1279,7 @@ def is_adhoc_column(column: Column) -> TypeGuard[AdhocColumn]: ) -def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]: +def get_base_axis_labels(columns: list[Column] | None) -> tuple[str, ...]: axis_cols = [ col for col in columns or [] @@ -1304,14 +1288,12 @@ def get_base_axis_labels(columns: Optional[List[Column]]) -> Tuple[str, ...]: return tuple(get_column_name(col) for col in axis_cols) -def get_xaxis_label(columns: Optional[List[Column]]) -> Optional[str]: +def get_xaxis_label(columns: list[Column] | None) -> str | None: labels = get_base_axis_labels(columns) return labels[0] if labels else None -def get_column_name( - column: Column, verbose_map: Optional[Dict[str, Any]] = None -) -> str: +def get_column_name(column: Column, verbose_map: dict[str, Any] | None = None) -> str: """ Extract label from column @@ -1336,9 +1318,7 @@ def get_column_name( raise ValueError("Missing label") -def get_metric_name( - metric: Metric, verbose_map: Optional[Dict[str, Any]] = None -) -> str: +def get_metric_name(metric: Metric, verbose_map: dict[str, Any] | None = None) -> str: """ Extract label from metric @@ -1374,9 +1354,9 @@ def get_metric_name( def get_column_names( - columns: Optional[Sequence[Column]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> List[str]: + columns: Sequence[Column] | None, + verbose_map: dict[str, Any] | None = None, +) -> list[str]: return [ column for column in [get_column_name(column, verbose_map) for column in columns or []] @@ -1385,9 +1365,9 @@ def get_column_names( def get_metric_names( - metrics: Optional[Sequence[Metric]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> List[str]: + metrics: Sequence[Metric] | None, + verbose_map: dict[str, Any] | None = None, +) -> list[str]: return [ metric for metric in [get_metric_name(metric, verbose_map) for metric in metrics or []] @@ -1396,9 +1376,9 @@ def get_metric_names( def get_first_metric_name( - metrics: Optional[Sequence[Metric]], - verbose_map: Optional[Dict[str, Any]] = None, -) -> Optional[str]: + metrics: Sequence[Metric] | None, + verbose_map: dict[str, Any] | None = None, +) -> str | None: metric_labels = get_metric_names(metrics, verbose_map) return metric_labels[0] if metric_labels else None @@ -1417,7 +1397,7 @@ def convert_legacy_filters_into_adhoc( # pylint: disable=invalid-name mapping = {"having": "having_filters", "where": "filters"} if not form_data.get("adhoc_filters"): - adhoc_filters: List[AdhocFilterClause] = [] + adhoc_filters: list[AdhocFilterClause] = [] form_data["adhoc_filters"] = adhoc_filters for clause, filters in mapping.items(): @@ -1475,17 +1455,13 @@ def split_adhoc_filters_into_base_filters( # pylint: disable=invalid-name sql_where_filters.append(sql_expression) elif clause == "HAVING": sql_having_filters.append(sql_expression) - form_data["where"] = " AND ".join( - ["({})".format(sql) for sql in sql_where_filters] - ) - form_data["having"] = " AND ".join( - ["({})".format(sql) for sql in sql_having_filters] - ) + form_data["where"] = " AND ".join([f"({sql})" for sql in sql_where_filters]) + form_data["having"] = " AND ".join([f"({sql})" for sql in sql_having_filters]) form_data["having_filters"] = simple_having_filters form_data["filters"] = simple_where_filters -def get_username() -> Optional[str]: +def get_username() -> str | None: """ Get username (if defined) associated with the current user. @@ -1498,7 +1474,7 @@ def get_username() -> Optional[str]: return None -def get_user_id() -> Optional[int]: +def get_user_id() -> int | None: """ Get the user identifier (if defined) associated with the current user. @@ -1517,7 +1493,7 @@ def get_user_id() -> Optional[int]: @contextmanager -def override_user(user: Optional[User], force: bool = True) -> Iterator[Any]: +def override_user(user: User | None, force: bool = True) -> Iterator[Any]: """ Temporarily override the current user per `flask.g` with the specified user. @@ -1583,7 +1559,7 @@ def create_ssl_cert_file(certificate: str) -> str: def time_function( func: Callable[..., FlaskResponse], *args: Any, **kwargs: Any -) -> Tuple[float, Any]: +) -> tuple[float, Any]: """ Measures the amount of time a function takes to execute in ms @@ -1603,7 +1579,7 @@ def MediumText() -> Variant: # pylint:disable=invalid-name def shortid() -> str: - return "{}".format(uuid.uuid4())[-12:] + return f"{uuid.uuid4()}"[-12:] class DatasourceName(NamedTuple): @@ -1611,7 +1587,7 @@ class DatasourceName(NamedTuple): schema: str -def get_stacktrace() -> Optional[str]: +def get_stacktrace() -> str | None: if current_app.config["SHOW_STACKTRACE"]: return traceback.format_exc() return None @@ -1649,7 +1625,7 @@ def split( yield string[i:] -def get_iterable(x: Any) -> List[Any]: +def get_iterable(x: Any) -> list[Any]: """ Get an iterable (list) representation of the object. @@ -1659,7 +1635,7 @@ def get_iterable(x: Any) -> List[Any]: return x if isinstance(x, list) else [x] -def get_form_data_token(form_data: Dict[str, Any]) -> str: +def get_form_data_token(form_data: dict[str, Any]) -> str: """ Return the token contained within form data or generate a new one. @@ -1669,7 +1645,7 @@ def get_form_data_token(form_data: Dict[str, Any]) -> str: return form_data.get("token") or "token_" + uuid.uuid4().hex[:8] -def get_column_name_from_column(column: Column) -> Optional[str]: +def get_column_name_from_column(column: Column) -> str | None: """ Extract the physical column that a column is referencing. If the column is an adhoc column, always returns `None`. @@ -1682,7 +1658,7 @@ def get_column_name_from_column(column: Column) -> Optional[str]: return column # type: ignore -def get_column_names_from_columns(columns: List[Column]) -> List[str]: +def get_column_names_from_columns(columns: list[Column]) -> list[str]: """ Extract the physical columns that a list of columns are referencing. Ignore adhoc columns @@ -1693,7 +1669,7 @@ def get_column_names_from_columns(columns: List[Column]) -> List[str]: return [col for col in map(get_column_name_from_column, columns) if col] -def get_column_name_from_metric(metric: Metric) -> Optional[str]: +def get_column_name_from_metric(metric: Metric) -> str | None: """ Extract the column that a metric is referencing. If the metric isn't a simple metric, always returns `None`. @@ -1704,11 +1680,11 @@ def get_column_name_from_metric(metric: Metric) -> Optional[str]: if is_adhoc_metric(metric): metric = cast(AdhocMetric, metric) if metric["expressionType"] == AdhocMetricExpressionType.SIMPLE: - return cast(Dict[str, Any], metric["column"])["column_name"] + return cast(dict[str, Any], metric["column"])["column_name"] return None -def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: +def get_column_names_from_metrics(metrics: list[Metric]) -> list[str]: """ Extract the columns that a list of metrics are referencing. Excludes all SQL metrics. @@ -1721,12 +1697,12 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: def extract_dataframe_dtypes( df: pd.DataFrame, - datasource: Optional[Union[BaseDatasource, Query]] = None, -) -> List[GenericDataType]: + datasource: BaseDatasource | Query | None = None, +) -> list[GenericDataType]: """Serialize pandas/numpy dtypes to generic types""" # omitting string types as those will be the default type - inferred_type_map: Dict[str, GenericDataType] = { + inferred_type_map: dict[str, GenericDataType] = { "floating": GenericDataType.NUMERIC, "integer": GenericDataType.NUMERIC, "mixed-integer-float": GenericDataType.NUMERIC, @@ -1737,7 +1713,7 @@ def extract_dataframe_dtypes( "date": GenericDataType.TEMPORAL, } - columns_by_name: Dict[str, Any] = {} + columns_by_name: dict[str, Any] = {} if datasource: for column in datasource.columns: if isinstance(column, dict): @@ -1745,7 +1721,7 @@ def extract_dataframe_dtypes( else: columns_by_name[column.column_name] = column - generic_types: List[GenericDataType] = [] + generic_types: list[GenericDataType] = [] for column in df.columns: column_object = columns_by_name.get(column) series = df[column] @@ -1767,7 +1743,7 @@ def extract_dataframe_dtypes( return generic_types -def extract_column_dtype(col: "BaseColumn") -> GenericDataType: +def extract_column_dtype(col: BaseColumn) -> GenericDataType: if col.is_temporal: return GenericDataType.TEMPORAL if col.is_numeric: @@ -1776,11 +1752,9 @@ def extract_column_dtype(col: "BaseColumn") -> GenericDataType: return GenericDataType.STRING -def indexed( - items: List[Any], key: Union[str, Callable[[Any], Any]] -) -> Dict[Any, List[Any]]: +def indexed(items: list[Any], key: str | Callable[[Any], Any]) -> dict[Any, list[Any]]: """Build an index for a list of objects""" - idx: Dict[Any, Any] = {} + idx: dict[Any, Any] = {} for item in items: key_ = getattr(item, key) if isinstance(key, str) else key(item) idx.setdefault(key_, []).append(item) @@ -1792,14 +1766,14 @@ def is_test() -> bool: def get_time_filter_status( - datasource: "BaseDatasource", - applied_time_extras: Dict[str, str], -) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]: - temporal_columns: Set[Any] = { + datasource: BaseDatasource, + applied_time_extras: dict[str, str], +) -> tuple[list[dict[str, str]], list[dict[str, str]]]: + temporal_columns: set[Any] = { col.column_name for col in datasource.columns if col.is_dttm } - applied: List[Dict[str, str]] = [] - rejected: List[Dict[str, str]] = [] + applied: list[dict[str, str]] = [] + rejected: list[dict[str, str]] = [] if time_column := applied_time_extras.get(ExtraFiltersTimeColumnType.TIME_COL): if time_column in temporal_columns: applied.append({"column": ExtraFiltersTimeColumnType.TIME_COL}) @@ -1844,14 +1818,14 @@ def format_list(items: Sequence[str], sep: str = ", ", quote: str = '"') -> str: return sep.join(f"{quote}{x.replace(quote, quote_escaped)}{quote}" for x in items) -def find_duplicates(items: Iterable[InputType]) -> List[InputType]: +def find_duplicates(items: Iterable[InputType]) -> list[InputType]: """Find duplicate items in an iterable.""" return [item for item, count in collections.Counter(items).items() if count > 1] def remove_duplicates( - items: Iterable[InputType], key: Optional[Callable[[InputType], Any]] = None -) -> List[InputType]: + items: Iterable[InputType], key: Callable[[InputType], Any] | None = None +) -> list[InputType]: """Remove duplicate items in an iterable.""" if not key: return list(dict.fromkeys(items).keys()) @@ -1868,9 +1842,9 @@ def remove_duplicates( @dataclass class DateColumn: col_label: str - timestamp_format: Optional[str] = None - offset: Optional[int] = None - time_shift: Optional[str] = None + timestamp_format: str | None = None + offset: int | None = None + time_shift: str | None = None def __hash__(self) -> int: return hash(self.col_label) @@ -1881,9 +1855,9 @@ class DateColumn: @classmethod def get_legacy_time_column( cls, - timestamp_format: Optional[str], - offset: Optional[int], - time_shift: Optional[str], + timestamp_format: str | None, + offset: int | None, + time_shift: str | None, ) -> DateColumn: return cls( timestamp_format=timestamp_format, @@ -1895,7 +1869,7 @@ class DateColumn: def normalize_dttm_col( df: pd.DataFrame, - dttm_cols: Tuple[DateColumn, ...] = tuple(), + dttm_cols: tuple[DateColumn, ...] = tuple(), ) -> None: for _col in dttm_cols: if _col.col_label not in df.columns: @@ -1925,7 +1899,7 @@ def normalize_dttm_col( df[_col.col_label] += parse_human_timedelta(_col.time_shift) -def parse_boolean_string(bool_str: Optional[str]) -> bool: +def parse_boolean_string(bool_str: str | None) -> bool: """ Convert a string representation of a true/false value into a boolean @@ -1956,7 +1930,7 @@ def parse_boolean_string(bool_str: Optional[str]) -> bool: def apply_max_row_limit( limit: int, - max_limit: Optional[int] = None, + max_limit: int | None = None, ) -> int: """ Override row limit if max global limit is defined @@ -1979,7 +1953,7 @@ def apply_max_row_limit( return max_limit -def create_zip(files: Dict[str, Any]) -> BytesIO: +def create_zip(files: dict[str, Any]) -> BytesIO: buf = BytesIO() with ZipFile(buf, "w") as bundle: for filename, contents in files.items(): @@ -1989,7 +1963,7 @@ def create_zip(files: Dict[str, Any]) -> BytesIO: return buf -def remove_extra_adhoc_filters(form_data: Dict[str, Any]) -> None: +def remove_extra_adhoc_filters(form_data: dict[str, Any]) -> None: """ Remove filters from slice data that originate from a filter box or native filter """ diff --git a/superset/utils/csv.py b/superset/utils/csv.py index a6c834b834..bab14058f2 100644 --- a/superset/utils/csv.py +++ b/superset/utils/csv.py @@ -17,7 +17,7 @@ import logging import re import urllib.request -from typing import Any, Dict, Optional +from typing import Any, Optional from urllib.error import URLError import numpy as np @@ -81,7 +81,7 @@ def df_to_escaped_csv(df: pd.DataFrame, **kwargs: Any) -> Any: def get_chart_csv_data( - chart_url: str, auth_cookies: Optional[Dict[str, str]] = None + chart_url: str, auth_cookies: Optional[dict[str, str]] = None ) -> Optional[bytes]: content = None if auth_cookies: @@ -98,7 +98,7 @@ def get_chart_csv_data( def get_chart_dataframe( - chart_url: str, auth_cookies: Optional[Dict[str, str]] = None + chart_url: str, auth_cookies: Optional[dict[str, str]] = None ) -> Optional[pd.DataFrame]: # Disable all the unnecessary-lambda violations in this function # pylint: disable=unnecessary-lambda diff --git a/superset/utils/dashboard_filter_scopes_converter.py b/superset/utils/dashboard_filter_scopes_converter.py index c0ee64370d..ce89b2a255 100644 --- a/superset/utils/dashboard_filter_scopes_converter.py +++ b/superset/utils/dashboard_filter_scopes_converter.py @@ -17,7 +17,7 @@ import json import logging from collections import defaultdict -from typing import Any, Dict, List +from typing import Any from shortid import ShortId @@ -27,11 +27,11 @@ logger = logging.getLogger(__name__) def convert_filter_scopes( - json_metadata: Dict[Any, Any], filter_boxes: List[Slice] -) -> Dict[int, Dict[str, Dict[str, Any]]]: + json_metadata: dict[Any, Any], filter_boxes: list[Slice] +) -> dict[int, dict[str, dict[str, Any]]]: filter_scopes = {} - immuned_by_id: List[int] = json_metadata.get("filter_immune_slices") or [] - immuned_by_column: Dict[str, List[int]] = defaultdict(list) + immuned_by_id: list[int] = json_metadata.get("filter_immune_slices") or [] + immuned_by_column: dict[str, list[int]] = defaultdict(list) for slice_id, columns in json_metadata.get( "filter_immune_slice_fields", {} ).items(): @@ -39,7 +39,7 @@ def convert_filter_scopes( immuned_by_column[column].append(int(slice_id)) def add_filter_scope( - filter_fields: Dict[str, Dict[str, Any]], filter_field: str, filter_id: int + filter_fields: dict[str, dict[str, Any]], filter_field: str, filter_id: int ) -> None: # in case filter field is invalid if isinstance(filter_field, str): @@ -54,7 +54,7 @@ def convert_filter_scopes( logging.info("slice [%i] has invalid field: %s", filter_id, filter_field) for filter_box in filter_boxes: - filter_fields: Dict[str, Dict[str, Any]] = {} + filter_fields: dict[str, dict[str, Any]] = {} filter_id = filter_box.id slice_params = json.loads(filter_box.params or "{}") configs = slice_params.get("filter_configs") or [] @@ -75,10 +75,10 @@ def convert_filter_scopes( def copy_filter_scopes( - old_to_new_slc_id_dict: Dict[int, int], - old_filter_scopes: Dict[int, Dict[str, Dict[str, Any]]], -) -> Dict[str, Dict[Any, Any]]: - new_filter_scopes: Dict[str, Dict[Any, Any]] = {} + old_to_new_slc_id_dict: dict[int, int], + old_filter_scopes: dict[int, dict[str, dict[str, Any]]], +) -> dict[str, dict[Any, Any]]: + new_filter_scopes: dict[str, dict[Any, Any]] = {} for filter_id, scopes in old_filter_scopes.items(): new_filter_key = old_to_new_slc_id_dict.get(int(filter_id)) if new_filter_key: @@ -93,10 +93,10 @@ def copy_filter_scopes( def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too-many-branches,too-many-locals,too-many-nested-blocks,too-many-statements - json_metadata: Dict[str, Any], - position_json: Dict[str, Any], - filter_boxes: List[Slice], -) -> List[Dict[str, Any]]: + json_metadata: dict[str, Any], + position_json: dict[str, Any], + filter_boxes: list[Slice], +) -> list[dict[str, Any]]: """ Convert the legacy filter scopes et al. to the native filter configuration. @@ -121,11 +121,11 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too filter_scopes = json_metadata.get("filter_scopes", {}) filter_box_ids = {filter_box.id for filter_box in filter_boxes} - filter_scope_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict( + filter_scope_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict( dict ) - filter_by_key_and_field: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict) + filter_by_key_and_field: dict[str, dict[str, dict[str, Any]]] = defaultdict(dict) # Dense representation of filter scopes, falling back to chart level filter configs # if the respective filter scope is not defined at the dashboard level. @@ -150,7 +150,7 @@ def convert_filter_scopes_to_native_filters( # pylint: disable=invalid-name,too for field, filter_scope in filter_scope_by_key_and_field[key].items(): default = default_filters.get(key, {}).get(field) - fltr: Dict[str, Any] = { + fltr: dict[str, Any] = { "cascadeParentIds": [], "id": f"NATIVE_FILTER-{shortid.generate()}", "scope": { diff --git a/superset/utils/database.py b/superset/utils/database.py index 750d873d1c..70730554f3 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from flask import current_app @@ -32,7 +32,7 @@ logger = logging.getLogger(__name__) # TODO: duplicate code with DatabaseDao, below function should be moved or use dao def get_or_create_db( - database_name: str, sqlalchemy_uri: str, always_create: Optional[bool] = True + database_name: str, sqlalchemy_uri: str, always_create: bool | None = True ) -> Database: # pylint: disable=import-outside-toplevel from superset import db diff --git a/superset/utils/date_parser.py b/superset/utils/date_parser.py index 7cdc23784a..438e379a96 100644 --- a/superset/utils/date_parser.py +++ b/superset/utils/date_parser.py @@ -20,7 +20,7 @@ import re from datetime import datetime, timedelta from functools import lru_cache from time import struct_time -from typing import Dict, List, Optional, Tuple +from typing import Optional import pandas as pd import parsedatetime @@ -75,7 +75,7 @@ def parse_human_datetime(human_readable: str) -> datetime: return dttm -def normalize_time_delta(human_readable: str) -> Dict[str, int]: +def normalize_time_delta(human_readable: str) -> dict[str, int]: x_unit = r"^\s*([0-9]+)\s+(second|minute|hour|day|week|month|quarter|year)s?\s+(ago|later)*$" # pylint: disable=line-too-long,useless-suppression matched = re.match(x_unit, human_readable, re.IGNORECASE) if not matched: @@ -149,7 +149,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m time_shift: Optional[str] = None, relative_start: Optional[str] = None, relative_end: Optional[str] = None, -) -> Tuple[Optional[datetime], Optional[datetime]]: +) -> tuple[Optional[datetime], Optional[datetime]]: """Return `since` and `until` date time tuple from string representations of time_range, since, until and time_shift. @@ -227,7 +227,7 @@ def get_since_until( # pylint: disable=too-many-arguments,too-many-locals,too-m ] since_and_until_partition = [_.strip() for _ in time_range.split(separator, 1)] - since_and_until: List[Optional[str]] = [] + since_and_until: list[Optional[str]] = [] for part in since_and_until_partition: if not part: # if since or until is "", set as None diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index e77a559905..4ecd2eca98 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -17,9 +17,10 @@ from __future__ import annotations import time +from collections.abc import Iterator from contextlib import contextmanager from functools import wraps -from typing import Any, Callable, Dict, Iterator, Optional, TYPE_CHECKING, Union +from typing import Any, Callable, TYPE_CHECKING from flask import current_app, Response @@ -32,7 +33,7 @@ if TYPE_CHECKING: from superset.stats_logger import BaseStatsLogger -def statsd_gauge(metric_prefix: Optional[str] = None) -> Callable[..., Any]: +def statsd_gauge(metric_prefix: str | None = None) -> Callable[..., Any]: def decorate(f: Callable[..., Any]) -> Callable[..., Any]: """ Handle sending statsd gauge metric from any method or function @@ -83,13 +84,13 @@ def arghash(args: Any, kwargs: Any) -> int: return hash(sorted_args) -def debounce(duration: Union[float, int] = 0.1) -> Callable[..., Any]: +def debounce(duration: float | int = 0.1) -> Callable[..., Any]: """Ensure a function called with the same arguments executes only once per `duration` (default: 100ms). """ def decorate(f: Callable[..., Any]) -> Callable[..., Any]: - last: Dict[str, Any] = {"t": None, "input": None, "output": None} + last: dict[str, Any] = {"t": None, "input": None, "output": None} def wrapped(*args: Any, **kwargs: Any) -> Any: now = time.time() diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index 93070732e7..f3fb1bbd6c 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import Any, Dict +from typing import Any from sqlalchemy.orm import Session @@ -26,7 +26,7 @@ DATABASES_KEY = "databases" logger = logging.getLogger(__name__) -def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: +def export_schema_to_dict(back_references: bool) -> dict[str, Any]: """Exports the supported import/export schema to a dictionary""" databases = [ Database.export_schema(recursive=True, include_parent_ref=back_references) @@ -39,7 +39,7 @@ def export_schema_to_dict(back_references: bool) -> Dict[str, Any]: def export_to_dict( session: Session, recursive: bool, back_references: bool, include_defaults: bool -) -> Dict[str, Any]: +) -> dict[str, Any]: """Exports databases and druid clusters to a dictionary""" logger.info("Starting export") dbs = session.query(Database) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index 52b784bb23..c812581ac4 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -16,7 +16,7 @@ # under the License. import logging from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask import Flask from flask_babel import lazy_gettext as _ @@ -31,9 +31,9 @@ class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-meth @abstractmethod def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]], + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: pass @@ -43,9 +43,9 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods ): def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]], + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]], ) -> TypeDecorator: if app_config: return EncryptedType(*args, app_config["SECRET_KEY"], **kwargs) @@ -56,7 +56,7 @@ class SQLAlchemyUtilsAdapter( # pylint: disable=too-few-public-methods class EncryptedFieldFactory: def __init__(self) -> None: self._concrete_type_adapter: Optional[AbstractEncryptedFieldAdapter] = None - self._config: Optional[Dict[str, Any]] = None + self._config: Optional[dict[str, Any]] = None def init_app(self, app: Flask) -> None: self._config = app.config @@ -65,7 +65,7 @@ class EncryptedFieldFactory: ]() def create( - self, *args: List[Any], **kwargs: Optional[Dict[str, Any]] + self, *args: list[Any], **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if self._concrete_type_adapter: return self._concrete_type_adapter.create(self._config, *args, **kwargs) @@ -81,14 +81,14 @@ class SecretsMigrator: self._previous_secret_key = previous_secret_key self._dialect: Dialect = db.engine.url.get_dialect() - def discover_encrypted_fields(self) -> Dict[str, Dict[str, EncryptedType]]: + def discover_encrypted_fields(self) -> dict[str, dict[str, EncryptedType]]: """ Iterates over SqlAlchemy's metadata, looking for EncryptedType columns along the way. Builds up a dict of table_name -> dict of col_name: enc type instance :return: """ - meta_info: Dict[str, Any] = {} + meta_info: dict[str, Any] = {} for table_name, table in self._db.metadata.tables.items(): for col_name, col in table.columns.items(): @@ -120,7 +120,7 @@ class SecretsMigrator: @staticmethod def _select_columns_from_table( - conn: Connection, column_names: List[str], table_name: str + conn: Connection, column_names: list[str], table_name: str ) -> Row: return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}") @@ -129,7 +129,7 @@ class SecretsMigrator: conn: Connection, row: Row, table_name: str, - columns: Dict[str, EncryptedType], + columns: dict[str, EncryptedType], ) -> None: """ Re encrypts all columns in a Row diff --git a/superset/utils/feature_flag_manager.py b/superset/utils/feature_flag_manager.py index 9874656722..ea295c776c 100644 --- a/superset/utils/feature_flag_manager.py +++ b/superset/utils/feature_flag_manager.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. from copy import deepcopy -from typing import Dict from flask import Flask @@ -25,7 +24,7 @@ class FeatureFlagManager: super().__init__() self._get_feature_flags_func = None self._is_feature_enabled_func = None - self._feature_flags: Dict[str, bool] = {} + self._feature_flags: dict[str, bool] = {} def init_app(self, app: Flask) -> None: self._get_feature_flags_func = app.config["GET_FEATURE_FLAGS_FUNC"] @@ -33,7 +32,7 @@ class FeatureFlagManager: self._feature_flags = app.config["DEFAULT_FEATURE_FLAGS"] self._feature_flags.update(app.config["FEATURE_FLAGS"]) - def get_feature_flags(self) -> Dict[str, bool]: + def get_feature_flags(self) -> dict[str, bool]: if self._get_feature_flags_func: return self._get_feature_flags_func(deepcopy(self._feature_flags)) if callable(self._is_feature_enabled_func): diff --git a/superset/utils/filters.py b/superset/utils/filters.py index 4772f49ba0..88154a40b3 100644 --- a/superset/utils/filters.py +++ b/superset/utils/filters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Type +from typing import Any from flask_appbuilder import Model from sqlalchemy import or_ @@ -22,7 +22,7 @@ from sqlalchemy.sql.elements import BooleanClauseList def get_dataset_access_filters( - base_model: Type[Model], + base_model: type[Model], *args: Any, ) -> BooleanClauseList: # pylint: disable=import-outside-toplevel diff --git a/superset/utils/hashing.py b/superset/utils/hashing.py index 66983582ca..fff654263e 100644 --- a/superset/utils/hashing.py +++ b/superset/utils/hashing.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import hashlib -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Optional import simplejson as json @@ -25,7 +25,7 @@ def md5_sha_from_str(val: str) -> str: def md5_sha_from_dict( - obj: Dict[Any, Any], + obj: dict[Any, Any], ignore_nan: bool = False, default: Optional[Callable[[Any], Any]] = None, ) -> str: diff --git a/superset/utils/log.py b/superset/utils/log.py index f2379fe11c..5430accb43 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -22,25 +22,14 @@ import json import logging import textwrap from abc import ABC, abstractmethod +from collections.abc import Iterator from contextlib import contextmanager from datetime import datetime, timedelta -from typing import ( - Any, - Callable, - cast, - Dict, - Iterator, - Optional, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Literal, TYPE_CHECKING from flask import current_app, g, request from flask_appbuilder.const import API_URI_RIS_KEY from sqlalchemy.exc import SQLAlchemyError -from typing_extensions import Literal from superset.extensions import stats_logger_manager from superset.utils.core import get_user_id, LoggerLevel @@ -51,12 +40,12 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -def collect_request_payload() -> Dict[str, Any]: +def collect_request_payload() -> dict[str, Any]: """Collect log payload identifiable from request context""" if not request: return {} - payload: Dict[str, Any] = { + payload: dict[str, Any] = { "path": request.path, **request.form.to_dict(), # url search params can overwrite POST body @@ -81,7 +70,7 @@ def collect_request_payload() -> Dict[str, Any]: def get_logger_from_status( status: int, -) -> Tuple[Callable[..., None], str]: +) -> tuple[Callable[..., None], str]: """ Return logger method by status of exception. Maps logger level to status code level @@ -101,10 +90,10 @@ class AbstractEventLogger(ABC): def __call__( self, action: str, - object_ref: Optional[str] = None, + object_ref: str | None = None, log_to_statsd: bool = True, - duration: Optional[timedelta] = None, - **payload_override: Dict[str, Any], + duration: timedelta | None = None, + **payload_override: dict[str, Any], ) -> object: # pylint: disable=W0201 self.action = action @@ -130,12 +119,12 @@ class AbstractEventLogger(ABC): @abstractmethod def log( # pylint: disable=too-many-arguments self, - user_id: Optional[int], + user_id: int | None, action: str, - dashboard_id: Optional[int], - duration_ms: Optional[int], - slice_id: Optional[int], - referrer: Optional[str], + dashboard_id: int | None, + duration_ms: int | None, + slice_id: int | None, + referrer: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -144,10 +133,10 @@ class AbstractEventLogger(ABC): def log_with_context( # pylint: disable=too-many-locals self, action: str, - duration: Optional[timedelta] = None, - object_ref: Optional[str] = None, + duration: timedelta | None = None, + object_ref: str | None = None, log_to_statsd: bool = True, - **payload_override: Optional[Dict[str, Any]], + **payload_override: dict[str, Any] | None, ) -> None: # pylint: disable=import-outside-toplevel from superset.views.core import get_form_data @@ -176,7 +165,7 @@ class AbstractEventLogger(ABC): if payload_override: payload.update(payload_override) - dashboard_id: Optional[int] = None + dashboard_id: int | None = None try: dashboard_id = int(payload.get("dashboard_id")) # type: ignore except (TypeError, ValueError): @@ -218,7 +207,7 @@ class AbstractEventLogger(ABC): def log_context( self, action: str, - object_ref: Optional[str] = None, + object_ref: str | None = None, log_to_statsd: bool = True, ) -> Iterator[Callable[..., None]]: """ @@ -242,9 +231,9 @@ class AbstractEventLogger(ABC): def _wrapper( self, f: Callable[..., Any], - action: Optional[Union[str, Callable[..., str]]] = None, - object_ref: Optional[Union[str, Callable[..., str], Literal[False]]] = None, - allow_extra_payload: Optional[bool] = False, + action: str | Callable[..., str] | None = None, + object_ref: str | Callable[..., str] | Literal[False] | None = None, + allow_extra_payload: bool | None = False, **wrapper_kwargs: Any, ) -> Callable[..., Any]: @functools.wraps(f) @@ -314,7 +303,7 @@ def get_event_logger_from_cfg_value(cfg_value: Any) -> AbstractEventLogger: ) ) - event_logger_type = cast(Type[Any], cfg_value) + event_logger_type = cast(type[Any], cfg_value) result = event_logger_type() # Verify that we have a valid logger impl @@ -333,12 +322,12 @@ class DBEventLogger(AbstractEventLogger): def log( # pylint: disable=too-many-arguments,too-many-locals self, - user_id: Optional[int], + user_id: int | None, action: str, - dashboard_id: Optional[int], - duration_ms: Optional[int], - slice_id: Optional[int], - referrer: Optional[str], + dashboard_id: int | None, + duration_ms: int | None, + slice_id: int | None, + referrer: str | None, *args: Any, **kwargs: Any, ) -> None: @@ -348,7 +337,7 @@ class DBEventLogger(AbstractEventLogger): records = kwargs.get("records", []) logs = [] for record in records: - json_string: Optional[str] + json_string: str | None try: json_string = json.dumps(record) except Exception: # pylint: disable=broad-except diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py index 02c04abe6a..7e45fc0f31 100644 --- a/superset/utils/machine_auth.py +++ b/superset/utils/machine_auth.py @@ -19,7 +19,7 @@ from __future__ import annotations import importlib import logging -from typing import Callable, Dict, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING from flask import current_app, Flask, request, Response, session from flask_login import login_user @@ -71,7 +71,7 @@ class MachineAuthProvider: return driver @staticmethod - def get_auth_cookies(user: User) -> Dict[str, str]: + def get_auth_cookies(user: User) -> dict[str, str]: # Login with the user specified to get the reports with current_app.test_request_context("/login"): login_user(user) diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py index 4b156cc10c..7462d14540 100644 --- a/superset/utils/mock_data.py +++ b/superset/utils/mock_data.py @@ -21,8 +21,9 @@ import os import random import string import sys +from collections.abc import Iterator from datetime import date, datetime, time, timedelta -from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Type +from typing import Any, Callable, cast, Optional from uuid import uuid4 import sqlalchemy.sql.sqltypes @@ -39,17 +40,14 @@ from superset import db logger = logging.getLogger(__name__) -ColumnInfo = TypedDict( - "ColumnInfo", - { - "name": str, - "type": VisitableType, - "nullable": bool, - "default": Optional[Any], - "autoincrement": str, - "primary_key": int, - }, -) + +class ColumnInfo(TypedDict): + name: str + type: VisitableType + nullable: bool + default: Optional[Any] + autoincrement: str + primary_key: int example_column = { @@ -167,7 +165,7 @@ def get_type_generator( # pylint: disable=too-many-return-statements,too-many-b def add_data( - columns: Optional[List[ColumnInfo]], + columns: Optional[list[ColumnInfo]], num_rows: int, table_name: str, append: bool = True, @@ -212,16 +210,16 @@ def add_data( engine.execute(table.insert(), data) -def get_column_objects(columns: List[ColumnInfo]) -> List[Column]: +def get_column_objects(columns: list[ColumnInfo]) -> list[Column]: out = [] for column in columns: - kwargs = cast(Dict[str, Any], column.copy()) + kwargs = cast(dict[str, Any], column.copy()) kwargs["type_"] = kwargs.pop("type") out.append(Column(**kwargs)) return out -def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, Any]]: +def generate_data(columns: list[ColumnInfo], num_rows: int) -> list[dict[str, Any]]: keys = [column["name"] for column in columns] return [ dict(zip(keys, row)) @@ -229,13 +227,13 @@ def generate_data(columns: List[ColumnInfo], num_rows: int) -> List[Dict[str, An ] -def generate_column_data(column: ColumnInfo, num_rows: int) -> List[Any]: +def generate_column_data(column: ColumnInfo, num_rows: int) -> list[Any]: gen = get_type_generator(column["type"]) return [gen() for _ in range(num_rows)] def add_sample_rows( - session: Session, model: Type[Model], count: int + session: Session, model: type[Model], count: int ) -> Iterator[Model]: """ Add entities of a given model. diff --git a/superset/utils/network.py b/superset/utils/network.py index 7a1aea5a71..fea3cfc6b2 100644 --- a/superset/utils/network.py +++ b/superset/utils/network.py @@ -32,10 +32,10 @@ def is_port_open(host: str, port: int) -> bool: s = socket.socket(af, socket.SOCK_STREAM) try: s.settimeout(PORT_TIMEOUT) - s.connect((sockaddr)) + s.connect(sockaddr) s.shutdown(socket.SHUT_RDWR) return True - except socket.error as _: + except OSError as _: continue finally: s.close() diff --git a/superset/utils/pandas_postprocessing/aggregate.py b/superset/utils/pandas_postprocessing/aggregate.py index a863d260c5..1116e4ec70 100644 --- a/superset/utils/pandas_postprocessing/aggregate.py +++ b/superset/utils/pandas_postprocessing/aggregate.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import Any from pandas import DataFrame @@ -26,7 +26,7 @@ from superset.utils.pandas_postprocessing.utils import ( @validate_column_args("groupby") def aggregate( - df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]] + df: DataFrame, groupby: list[str], aggregates: dict[str, dict[str, Any]] ) -> DataFrame: """ Apply aggregations to a DataFrame. diff --git a/superset/utils/pandas_postprocessing/boxplot.py b/superset/utils/pandas_postprocessing/boxplot.py index 399cf569fb..f9fed40e59 100644 --- a/superset/utils/pandas_postprocessing/boxplot.py +++ b/superset/utils/pandas_postprocessing/boxplot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Optional, Union import numpy as np from flask_babel import gettext as _ @@ -27,11 +27,11 @@ from superset.utils.pandas_postprocessing.aggregate import aggregate def boxplot( df: DataFrame, - groupby: List[str], - metrics: List[str], + groupby: list[str], + metrics: list[str], whisker_type: PostProcessingBoxplotWhiskerType, percentiles: Optional[ - Union[List[Union[int, float]], Tuple[Union[int, float], Union[int, float]]] + Union[list[Union[int, float]], tuple[Union[int, float], Union[int, float]]] ] = None, ) -> DataFrame: """ @@ -102,12 +102,12 @@ def boxplot( whisker_high = np.max whisker_low = np.min - def outliers(series: Series) -> Set[float]: + def outliers(series: Series) -> set[float]: above = series[series > whisker_high(series)] below = series[series < whisker_low(series)] return above.tolist() + below.tolist() - operators: Dict[str, Callable[[Any], Any]] = { + operators: dict[str, Callable[[Any], Any]] = { "mean": np.mean, "median": np.median, "max": whisker_high, @@ -117,7 +117,7 @@ def boxplot( "count": np.ma.count, "outliers": outliers, } - aggregates: Dict[str, Dict[str, Union[str, Callable[..., Any]]]] = { + aggregates: dict[str, dict[str, Union[str, Callable[..., Any]]]] = { f"{metric}__{operator_name}": {"column": metric, "operator": operator} for operator_name, operator in operators.items() for metric in metrics diff --git a/superset/utils/pandas_postprocessing/compare.py b/superset/utils/pandas_postprocessing/compare.py index f7c8365508..b20682027f 100644 --- a/superset/utils/pandas_postprocessing/compare.py +++ b/superset/utils/pandas_postprocessing/compare.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional +from typing import Optional import pandas as pd from flask_babel import gettext as _ @@ -29,8 +29,8 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args @validate_column_args("source_columns", "compare_columns") def compare( # pylint: disable=too-many-arguments df: DataFrame, - source_columns: List[str], - compare_columns: List[str], + source_columns: list[str], + compare_columns: list[str], compare_type: PandasPostprocessingCompare, drop_original_columns: Optional[bool] = False, precision: Optional[int] = 4, diff --git a/superset/utils/pandas_postprocessing/contribution.py b/superset/utils/pandas_postprocessing/contribution.py index f8519f39a9..d383312f75 100644 --- a/superset/utils/pandas_postprocessing/contribution.py +++ b/superset/utils/pandas_postprocessing/contribution.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from decimal import Decimal -from typing import List, Optional +from typing import Optional from flask_babel import gettext as _ from pandas import DataFrame @@ -31,8 +31,8 @@ def contribution( orientation: Optional[ PostProcessingContributionOrientation ] = PostProcessingContributionOrientation.COLUMN, - columns: Optional[List[str]] = None, - rename_columns: Optional[List[str]] = None, + columns: Optional[list[str]] = None, + rename_columns: Optional[list[str]] = None, ) -> DataFrame: """ Calculate cell contribution to row/column total for numeric columns. diff --git a/superset/utils/pandas_postprocessing/cum.py b/superset/utils/pandas_postprocessing/cum.py index b94f048e5c..128fa970f5 100644 --- a/superset/utils/pandas_postprocessing/cum.py +++ b/superset/utils/pandas_postprocessing/cum.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict from flask_babel import gettext as _ from pandas import DataFrame @@ -31,7 +30,7 @@ from superset.utils.pandas_postprocessing.utils import ( def cum( df: DataFrame, operator: str, - columns: Dict[str, str], + columns: dict[str, str], ) -> DataFrame: """ Calculate cumulative sum/product/min/max for select columns. diff --git a/superset/utils/pandas_postprocessing/diff.py b/superset/utils/pandas_postprocessing/diff.py index 0cead2de8d..de68d39439 100644 --- a/superset/utils/pandas_postprocessing/diff.py +++ b/superset/utils/pandas_postprocessing/diff.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict from pandas import DataFrame @@ -28,7 +27,7 @@ from superset.utils.pandas_postprocessing.utils import ( @validate_column_args("columns") def diff( df: DataFrame, - columns: Dict[str, str], + columns: dict[str, str], periods: int = 1, axis: PandasAxis = PandasAxis.ROW, ) -> DataFrame: diff --git a/superset/utils/pandas_postprocessing/flatten.py b/superset/utils/pandas_postprocessing/flatten.py index da9954ef11..40db86db06 100644 --- a/superset/utils/pandas_postprocessing/flatten.py +++ b/superset/utils/pandas_postprocessing/flatten.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. -from collections.abc import Iterable -from typing import Any, Sequence, Union +from collections.abc import Iterable, Sequence +from typing import Any, Union import pandas as pd diff --git a/superset/utils/pandas_postprocessing/geography.py b/superset/utils/pandas_postprocessing/geography.py index 33a27c2df4..79046cb71a 100644 --- a/superset/utils/pandas_postprocessing/geography.py +++ b/superset/utils/pandas_postprocessing/geography.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Optional, Tuple +from typing import Optional import geohash as geohash_lib from flask_babel import gettext as _ @@ -95,7 +95,7 @@ def geodetic_parse( :return: DataFrame with decoded longitudes and latitudes """ - def _parse_location(location: str) -> Tuple[float, float, float]: + def _parse_location(location: str) -> tuple[float, float, float]: """ Parse a string containing a geodetic point and return latitude, longitude and altitude diff --git a/superset/utils/pandas_postprocessing/pivot.py b/superset/utils/pandas_postprocessing/pivot.py index df5fa7e37c..28e8ff380f 100644 --- a/superset/utils/pandas_postprocessing/pivot.py +++ b/superset/utils/pandas_postprocessing/pivot.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from flask_babel import gettext as _ from pandas import DataFrame @@ -30,9 +30,9 @@ from superset.utils.pandas_postprocessing.utils import ( @validate_column_args("index", "columns") def pivot( # pylint: disable=too-many-arguments df: DataFrame, - index: List[str], - aggregates: Dict[str, Dict[str, Any]], - columns: Optional[List[str]] = None, + index: list[str], + aggregates: dict[str, dict[str, Any]], + columns: Optional[list[str]] = None, metric_fill_value: Optional[Any] = None, column_fill_value: Optional[str] = NULL_STRING, drop_missing_columns: Optional[bool] = True, diff --git a/superset/utils/pandas_postprocessing/rename.py b/superset/utils/pandas_postprocessing/rename.py index 0e35a651a8..4bcd19782c 100644 --- a/superset/utils/pandas_postprocessing/rename.py +++ b/superset/utils/pandas_postprocessing/rename.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, Optional, Union +from typing import Optional, Union import pandas as pd from flask_babel import gettext as _ @@ -27,7 +27,7 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args @validate_column_args("columns") def rename( df: pd.DataFrame, - columns: Dict[str, Union[str, None]], + columns: dict[str, Union[str, None]], inplace: bool = False, level: Optional[Level] = None, ) -> pd.DataFrame: diff --git a/superset/utils/pandas_postprocessing/rolling.py b/superset/utils/pandas_postprocessing/rolling.py index 885032eb17..f93a047be9 100644 --- a/superset/utils/pandas_postprocessing/rolling.py +++ b/superset/utils/pandas_postprocessing/rolling.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Optional, Union from flask_babel import gettext as _ from pandas import DataFrame @@ -31,9 +31,9 @@ from superset.utils.pandas_postprocessing.utils import ( def rolling( # pylint: disable=too-many-arguments df: DataFrame, rolling_type: str, - columns: Dict[str, str], + columns: dict[str, str], window: Optional[int] = None, - rolling_type_options: Optional[Dict[str, Any]] = None, + rolling_type_options: Optional[dict[str, Any]] = None, center: bool = False, win_type: Optional[str] = None, min_periods: Optional[int] = None, @@ -62,7 +62,7 @@ def rolling( # pylint: disable=too-many-arguments rolling_type_options = rolling_type_options or {} df_rolling = df.loc[:, columns.keys()] - kwargs: Dict[str, Union[str, int]] = {} + kwargs: dict[str, Union[str, int]] = {} if window is None: raise InvalidPostProcessingError(_("Undefined window for rolling operation")) if window == 0: diff --git a/superset/utils/pandas_postprocessing/select.py b/superset/utils/pandas_postprocessing/select.py index 59fe886d4d..c4e02508df 100644 --- a/superset/utils/pandas_postprocessing/select.py +++ b/superset/utils/pandas_postprocessing/select.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Dict, List, Optional +from typing import Optional from pandas import DataFrame @@ -24,9 +24,9 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args @validate_column_args("columns", "drop", "rename") def select( df: DataFrame, - columns: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - rename: Optional[Dict[str, str]] = None, + columns: Optional[list[str]] = None, + exclude: Optional[list[str]] = None, + rename: Optional[dict[str, str]] = None, ) -> DataFrame: """ Only select a subset of columns in the original dataset. Can be useful for diff --git a/superset/utils/pandas_postprocessing/sort.py b/superset/utils/pandas_postprocessing/sort.py index 66041a7166..b6470c3546 100644 --- a/superset/utils/pandas_postprocessing/sort.py +++ b/superset/utils/pandas_postprocessing/sort.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional, Union +from typing import Optional, Union from pandas import DataFrame @@ -26,8 +26,8 @@ from superset.utils.pandas_postprocessing.utils import validate_column_args def sort( df: DataFrame, is_sort_index: bool = False, - by: Optional[Union[List[str], str]] = None, - ascending: Union[List[bool], bool] = True, + by: Optional[Union[list[str], str]] = None, + ascending: Union[list[bool], bool] = True, ) -> DataFrame: """ Sort a DataFrame. diff --git a/superset/utils/pandas_postprocessing/utils.py b/superset/utils/pandas_postprocessing/utils.py index 2b754fbbef..37d53697cb 100644 --- a/superset/utils/pandas_postprocessing/utils.py +++ b/superset/utils/pandas_postprocessing/utils.py @@ -14,8 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from collections.abc import Sequence from functools import partial -from typing import Any, Callable, Dict, Sequence +from typing import Any, Callable import numpy as np import pandas as pd @@ -24,7 +25,7 @@ from pandas import DataFrame, NamedAgg from superset.exceptions import InvalidPostProcessingError -NUMPY_FUNCTIONS: Dict[str, Callable[..., Any]] = { +NUMPY_FUNCTIONS: dict[str, Callable[..., Any]] = { "average": np.average, "argmin": np.argmin, "argmax": np.argmax, @@ -133,8 +134,8 @@ def validate_column_args(*argnames: str) -> Callable[..., Any]: def _get_aggregate_funcs( df: DataFrame, - aggregates: Dict[str, Dict[str, Any]], -) -> Dict[str, NamedAgg]: + aggregates: dict[str, dict[str, Any]], +) -> dict[str, NamedAgg]: """ Converts a set of aggregate config objects into functions that pandas can use as aggregators. Currently only numpy aggregators are supported. @@ -143,7 +144,7 @@ def _get_aggregate_funcs( :param aggregates: Mapping from column name to aggregate config. :return: Mapping from metric name to function that takes a single input argument. """ - agg_funcs: Dict[str, NamedAgg] = {} + agg_funcs: dict[str, NamedAgg] = {} for name, agg_obj in aggregates.items(): column = agg_obj.get("column", name) if column not in df: @@ -180,7 +181,7 @@ def _get_aggregate_funcs( def _append_columns( - base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str] + base_df: DataFrame, append_df: DataFrame, columns: dict[str, str] ) -> DataFrame: """ Function for adding columns from one DataFrame to another DataFrame. Calls the diff --git a/superset/utils/retries.py b/superset/utils/retries.py index d1c2947146..8a1e6b95ea 100644 --- a/superset/utils/retries.py +++ b/superset/utils/retries.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict, Generator, List, Optional, Type +from collections.abc import Generator +from typing import Any, Callable, Optional import backoff @@ -24,9 +25,9 @@ def retry_call( func: Callable[..., Any], *args: Any, strategy: Callable[..., Generator[int, None, None]] = backoff.constant, - exception: Type[Exception] = Exception, - fargs: Optional[List[Any]] = None, - fkwargs: Optional[Dict[str, Any]] = None, + exception: type[Exception] = Exception, + fargs: Optional[list[Any]] = None, + fkwargs: Optional[dict[str, Any]] = None, **kwargs: Any ) -> Any: """ diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 88b97901b2..5c699e9e19 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging from io import BytesIO -from typing import Optional, TYPE_CHECKING, Union +from typing import TYPE_CHECKING from flask import current_app @@ -53,16 +53,16 @@ class BaseScreenshot: def __init__(self, url: str, digest: str): self.digest: str = digest self.url = url - self.screenshot: Optional[bytes] = None + self.screenshot: bytes | None = None - def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy: + def driver(self, window_size: WindowSize | None = None) -> WebDriverProxy: window_size = window_size or self.window_size return WebDriverProxy(self.driver_type, window_size) def cache_key( self, - window_size: Optional[Union[bool, WindowSize]] = None, - thumb_size: Optional[Union[bool, WindowSize]] = None, + window_size: bool | WindowSize | None = None, + thumb_size: bool | WindowSize | None = None, ) -> str: window_size = window_size or self.window_size thumb_size = thumb_size or self.thumb_size @@ -76,8 +76,8 @@ class BaseScreenshot: return md5_sha_from_dict(args) def get_screenshot( - self, user: User, window_size: Optional[WindowSize] = None - ) -> Optional[bytes]: + self, user: User, window_size: WindowSize | None = None + ) -> bytes | None: driver = self.driver(window_size) self.screenshot = driver.get_screenshot(self.url, self.element, user) return self.screenshot @@ -86,8 +86,8 @@ class BaseScreenshot: self, user: User = None, cache: Cache = None, - thumb_size: Optional[WindowSize] = None, - ) -> Optional[BytesIO]: + thumb_size: WindowSize | None = None, + ) -> BytesIO | None: """ Get thumbnail screenshot has BytesIO from cache or fetch @@ -95,7 +95,7 @@ class BaseScreenshot: :param cache: The cache to use :param thumb_size: Override thumbnail site """ - payload: Optional[bytes] = None + payload: bytes | None = None cache_key = self.cache_key(self.window_size, thumb_size) if cache: payload = cache.get(cache_key) @@ -112,14 +112,14 @@ class BaseScreenshot: def get_from_cache( self, cache: Cache, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, - ) -> Optional[BytesIO]: + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, + ) -> BytesIO | None: cache_key = self.cache_key(window_size, thumb_size) return self.get_from_cache_key(cache, cache_key) @staticmethod - def get_from_cache_key(cache: Cache, cache_key: str) -> Optional[BytesIO]: + def get_from_cache_key(cache: Cache, cache_key: str) -> BytesIO | None: logger.info("Attempting to get from cache: %s", cache_key) if payload := cache.get(cache_key): return BytesIO(payload) @@ -129,11 +129,11 @@ class BaseScreenshot: def compute_and_cache( # pylint: disable=too-many-arguments self, user: User = None, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, cache: Cache = None, force: bool = True, - ) -> Optional[bytes]: + ) -> bytes | None: """ Fetches the screenshot, computes the thumbnail and caches the result @@ -178,7 +178,7 @@ class BaseScreenshot: cls, img_bytes: bytes, output: str = "png", - thumb_size: Optional[WindowSize] = None, + thumb_size: WindowSize | None = None, crop: bool = True, ) -> bytes: thumb_size = thumb_size or cls.thumb_size @@ -207,8 +207,8 @@ class ChartScreenshot(BaseScreenshot): self, url: str, digest: str, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, ): # Chart reports are in standalone="true" mode url = modify_url_query( @@ -228,8 +228,8 @@ class DashboardScreenshot(BaseScreenshot): self, url: str, digest: str, - window_size: Optional[WindowSize] = None, - thumb_size: Optional[WindowSize] = None, + window_size: WindowSize | None = None, + thumb_size: WindowSize | None = None, ): # per the element above, dashboard screenshots # should always capture in standalone diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py index 48ada98dcc..8421350f8c 100644 --- a/superset/utils/ssh_tunnel.py +++ b/superset/utils/ssh_tunnel.py @@ -15,13 +15,13 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from superset.constants import PASSWORD_MASK from superset.databases.ssh_tunnel.models import SSHTunnel -def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: +def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]: if ssh_tunnel.pop("password", None) is not None: ssh_tunnel["password"] = PASSWORD_MASK if ssh_tunnel.pop("private_key", None) is not None: @@ -32,8 +32,8 @@ def mask_password_info(ssh_tunnel: Dict[str, Any]) -> Dict[str, Any]: def unmask_password_info( - ssh_tunnel: Dict[str, Any], model: SSHTunnel -) -> Dict[str, Any]: + ssh_tunnel: dict[str, Any], model: SSHTunnel +) -> dict[str, Any]: if ssh_tunnel.get("password") == PASSWORD_MASK: ssh_tunnel["password"] = model.password if ssh_tunnel.get("private_key") == PASSWORD_MASK: diff --git a/superset/utils/url_map_converters.py b/superset/utils/url_map_converters.py index fbd9c800b0..11e40267b3 100644 --- a/superset/utils/url_map_converters.py +++ b/superset/utils/url_map_converters.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, List +from typing import Any from werkzeug.routing import BaseConverter, Map @@ -22,7 +22,7 @@ from superset.tags.models import ObjectTypes class RegexConverter(BaseConverter): - def __init__(self, url_map: Map, *items: List[str]) -> None: + def __init__(self, url_map: Map, *items: list[str]) -> None: super().__init__(url_map) self.regex = items[0] diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py index 05dbee6740..c302ab8921 100644 --- a/superset/utils/webdriver.py +++ b/superset/utils/webdriver.py @@ -20,7 +20,7 @@ from __future__ import annotations import logging from enum import Enum from time import sleep -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from flask import current_app from selenium.common.exceptions import ( @@ -37,7 +37,7 @@ from selenium.webdriver.support.ui import WebDriverWait from superset.extensions import machine_auth_provider_factory from superset.utils.retries import retry_call -WindowSize = Tuple[int, int] +WindowSize = tuple[int, int] logger = logging.getLogger(__name__) @@ -56,7 +56,7 @@ class ChartStandaloneMode(Enum): SHOW_NAV = 0 -def find_unexpected_errors(driver: WebDriver) -> List[str]: +def find_unexpected_errors(driver: WebDriver) -> list[str]: error_messages = [] try: @@ -111,7 +111,7 @@ def find_unexpected_errors(driver: WebDriver) -> List[str]: class WebDriverProxy: - def __init__(self, driver_type: str, window: Optional[WindowSize] = None): + def __init__(self, driver_type: str, window: WindowSize | None = None): self._driver_type = driver_type self._window: WindowSize = window or (800, 600) self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"] @@ -124,7 +124,7 @@ class WebDriverProxy: options = firefox.options.Options() profile = FirefoxProfile() profile.set_preference("layout.css.devPixelsPerPx", str(pixel_density)) - kwargs: Dict[Any, Any] = dict(options=options, firefox_profile=profile) + kwargs: dict[Any, Any] = dict(options=options, firefox_profile=profile) elif self._driver_type == "chrome": driver_class = chrome.webdriver.WebDriver options = chrome.options.Options() @@ -164,13 +164,11 @@ class WebDriverProxy: except Exception: # pylint: disable=broad-except pass - def get_screenshot( - self, url: str, element_name: str, user: User - ) -> Optional[bytes]: + def get_screenshot(self, url: str, element_name: str, user: User) -> bytes | None: driver = self.auth(user) driver.set_window_size(*self._window) driver.get(url) - img: Optional[bytes] = None + img: bytes | None = None selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"] logger.debug("Sleeping for %i seconds", selenium_headstart) sleep(selenium_headstart) diff --git a/superset/views/__init__.py b/superset/views/__init__.py index 5247f215c1..b5a21c77f0 100644 --- a/superset/views/__init__.py +++ b/superset/views/__init__.py @@ -21,8 +21,6 @@ from . import ( base, core, css_templates, - dashboard, - datasource, dynamic_plugins, health, redirects, diff --git a/superset/views/all_entities.py b/superset/views/all_entities.py index 4031d81d21..3de53be461 100644 --- a/superset/views/all_entities.py +++ b/superset/views/all_entities.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import logging diff --git a/superset/views/base.py b/superset/views/base.py index 97d8da1d69..3a72096ac2 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -20,7 +20,7 @@ import logging import os import traceback from datetime import datetime -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Optional, Union import simplejson as json import yaml @@ -140,11 +140,11 @@ def get_error_msg() -> str: def json_error_response( msg: Optional[str] = None, status: int = 500, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, link: Optional[str] = None, ) -> FlaskResponse: if not payload: - payload = {"error": "{}".format(msg)} + payload = {"error": f"{msg}"} if link: payload["link"] = link @@ -156,9 +156,9 @@ def json_error_response( def json_errors_response( - errors: List[SupersetError], + errors: list[SupersetError], status: int = 500, - payload: Optional[Dict[str, Any]] = None, + payload: Optional[dict[str, Any]] = None, ) -> FlaskResponse: if not payload: payload = {} @@ -182,7 +182,7 @@ def data_payload_response(payload_json: str, has_error: bool = False) -> FlaskRe def generate_download_headers( extension: str, filename: Optional[str] = None -) -> Dict[str, Any]: +) -> dict[str, Any]: filename = filename if filename else datetime.now().strftime("%Y%m%d_%H%M%S") content_disp = f"attachment; filename={filename}.{extension}" headers = {"Content-Disposition": content_disp} @@ -332,7 +332,7 @@ class BaseSupersetView(BaseView): ) -def menu_data(user: User) -> Dict[str, Any]: +def menu_data(user: User) -> dict[str, Any]: menu = appbuilder.menu.get_data() languages = {} @@ -396,7 +396,7 @@ def menu_data(user: User) -> Dict[str, Any]: @cache_manager.cache.memoize(timeout=60) -def cached_common_bootstrap_data(user: User) -> Dict[str, Any]: +def cached_common_bootstrap_data(user: User) -> dict[str, Any]: """Common data always sent to the client The function is memoized as the return value only changes when user permissions @@ -439,7 +439,7 @@ def cached_common_bootstrap_data(user: User) -> Dict[str, Any]: return bootstrap_data -def common_bootstrap_payload(user: User) -> Dict[str, Any]: +def common_bootstrap_payload(user: User) -> dict[str, Any]: return { **(cached_common_bootstrap_data(user)), "flash_messages": get_flashed_messages(with_categories=True), @@ -548,7 +548,7 @@ def show_unexpected_exception(ex: Exception) -> FlaskResponse: @superset_app.context_processor -def get_common_bootstrap_data() -> Dict[str, Any]: +def get_common_bootstrap_data() -> dict[str, Any]: def serialize_bootstrap_data() -> str: return json.dumps( {"common": common_bootstrap_payload(g.user)}, @@ -606,7 +606,7 @@ class YamlExportMixin: # pylint: disable=too-few-public-methods @action("yaml_export", __("Export to YAML"), __("Export to YAML?"), "fa-download") def yaml_export( - self, items: Union[ImportExportMixin, List[ImportExportMixin]] + self, items: Union[ImportExportMixin, list[ImportExportMixin]] ) -> FlaskResponse: if not isinstance(items, list): items = [items] @@ -663,7 +663,7 @@ class DeleteMixin: # pylint: disable=too-few-public-methods @action( "muldelete", __("Delete"), __("Delete all Really?"), "fa-trash", single=False ) - def muldelete(self: BaseView, items: List[Model]) -> FlaskResponse: + def muldelete(self: BaseView, items: list[Model]) -> FlaskResponse: if not items: abort(404) for item in items: @@ -709,7 +709,7 @@ class XlsxResponse(Response): def bind_field( - _: Any, form: DynamicForm, unbound_field: UnboundField, options: Dict[Any, Any] + _: Any, form: DynamicForm, unbound_field: UnboundField, options: dict[Any, Any] ) -> Field: """ Customize how fields are bound by stripping all whitespace. diff --git a/superset/views/base_api.py b/superset/views/base_api.py index 30d25382f3..dca7a96b1d 100644 --- a/superset/views/base_api.py +++ b/superset/views/base_api.py @@ -18,7 +18,7 @@ from __future__ import annotations import functools import logging -from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, cast from flask import request, Response from flask_appbuilder import Model, ModelRestApi @@ -87,7 +87,7 @@ def requires_json(f: Callable[..., Any]) -> Callable[..., Any]: Require JSON-like formatted request to the REST API """ - def wraps(self: "BaseSupersetModelRestApi", *args: Any, **kwargs: Any) -> Response: + def wraps(self: BaseSupersetModelRestApi, *args: Any, **kwargs: Any) -> Response: if not request.is_json: raise InvalidPayloadFormatError(message="Request is not JSON") return f(self, *args, **kwargs) @@ -135,7 +135,7 @@ def statsd_metrics(f: Callable[..., Any]) -> Callable[..., Any]: class RelatedFieldFilter: # data class to specify what filter to use on a /related endpoint # pylint: disable=too-few-public-methods - def __init__(self, field_name: str, filter_class: Type[BaseFilter]): + def __init__(self, field_name: str, filter_class: type[BaseFilter]): self.field_name = field_name self.filter_class = filter_class @@ -150,7 +150,7 @@ class BaseFavoriteFilter(BaseFilter): # pylint: disable=too-few-public-methods arg_name = "" class_name = "" """ The FavStar class_name to user """ - model: Type[Union[Dashboard, Slice, SqllabQuery]] = Dashboard + model: type[Dashboard | Slice | SqllabQuery] = Dashboard """ The SQLAlchemy model """ def apply(self, query: Query, value: Any) -> Query: @@ -178,7 +178,7 @@ class BaseTagFilter(BaseFilter): # pylint: disable=too-few-public-methods arg_name = "" class_name = "" """ The Tag class_name to user """ - model: Type[Union[Dashboard, Slice, SqllabQuery, SqlaTable]] = Dashboard + model: type[Dashboard | Slice | SqllabQuery | SqlaTable] = Dashboard """ The SQLAlchemy model """ def apply(self, query: Query, value: Any) -> Query: @@ -229,7 +229,7 @@ class BaseSupersetApiMixin: ) def send_stats_metrics( - self, response: Response, key: str, time_delta: Optional[float] = None + self, response: Response, key: str, time_delta: float | None = None ) -> None: """ Helper function to handle sending statsd metrics @@ -280,7 +280,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): "viz_types": "list", } - order_rel_fields: Dict[str, Tuple[str, str]] = {} + order_rel_fields: dict[str, tuple[str, str]] = {} """ Impose ordering on related fields query:: @@ -290,7 +290,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - base_related_field_filters: Dict[str, BaseFilter] = {} + base_related_field_filters: dict[str, BaseFilter] = {} """ This is used to specify a base filter for related fields when they are accessed through the '/related/' endpoint. @@ -302,7 +302,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - related_field_filters: Dict[str, Union[RelatedFieldFilter, str]] = {} + related_field_filters: dict[str, RelatedFieldFilter | str] = {} """ Specify a filter for related fields when they are accessed through the '/related/' endpoint. @@ -313,10 +313,10 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): "": ) } """ - allowed_rel_fields: Set[str] = set() + allowed_rel_fields: set[str] = set() # Declare a set of allowed related fields that the `related` endpoint supports. - text_field_rel_fields: Dict[str, str] = {} + text_field_rel_fields: dict[str, str] = {} """ Declare an alternative for the human readable representation of the Model object:: @@ -325,7 +325,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - extra_fields_rel_fields: Dict[str, List[str]] = {"owners": ["email", "active"]} + extra_fields_rel_fields: dict[str, list[str]] = {"owners": ["email", "active"]} """ Declare extra fields for the representation of the Model object:: @@ -334,12 +334,12 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): } """ - allowed_distinct_fields: Set[str] = set() + allowed_distinct_fields: set[str] = set() - add_columns: List[str] - edit_columns: List[str] - list_columns: List[str] - show_columns: List[str] + add_columns: list[str] + edit_columns: list[str] + list_columns: list[str] + show_columns: list[str] def __init__(self) -> None: super().__init__() @@ -347,8 +347,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): if self.apispec_parameter_schemas is None: # type: ignore self.apispec_parameter_schemas = {} self.apispec_parameter_schemas["get_related_schema"] = get_related_schema - self.openapi_spec_component_schemas: Tuple[ - Type[Schema], ... + self.openapi_spec_component_schemas: tuple[ + type[Schema], ... ] = self.openapi_spec_component_schemas + ( RelatedResponseSchema, DistincResponseSchema, @@ -409,7 +409,7 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): def _get_extra_field_for_model( self, model: Model, column_name: str - ) -> Dict[str, str]: + ) -> dict[str, str]: ret = {} if column_name in self.extra_fields_rel_fields: model_column_names = self.extra_fields_rel_fields.get(column_name) @@ -419,8 +419,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): return ret def _get_result_from_rows( - self, datamodel: SQLAInterface, rows: List[Model], column_name: str - ) -> List[Dict[str, Any]]: + self, datamodel: SQLAInterface, rows: list[Model], column_name: str + ) -> list[dict[str, Any]]: return [ { "value": datamodel.get_pk_value(row), @@ -434,8 +434,8 @@ class BaseSupersetModelRestApi(ModelRestApi, BaseSupersetApiMixin): self, datamodel: SQLAInterface, column_name: str, - ids: List[int], - result: List[Dict[str, Any]], + ids: list[int], + result: list[dict[str, Any]], ) -> None: if ids: # Filter out already present values on the result diff --git a/superset/views/base_schemas.py b/superset/views/base_schemas.py index 8f4ed7735c..2107558dc7 100644 --- a/superset/views/base_schemas.py +++ b/superset/views/base_schemas.py @@ -14,7 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union +from collections.abc import Iterable, Mapping, Sequence +from typing import Any, Optional, Union from flask import current_app, g from flask_appbuilder import Model @@ -54,7 +55,7 @@ class BaseSupersetSchema(Schema): self, data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]], many: Optional[bool] = None, - partial: Union[bool, Sequence[str], Set[str], None] = None, + partial: Union[bool, Sequence[str], set[str], None] = None, instance: Optional[Model] = None, **kwargs: Any, ) -> Any: @@ -67,7 +68,7 @@ class BaseSupersetSchema(Schema): @post_load def make_object( - self, data: Dict[Any, Any], discard: Optional[List[str]] = None + self, data: dict[Any, Any], discard: Optional[list[str]] = None ) -> Model: """ Creates a Model object from POST or PUT requests. PUT will use self.instance @@ -95,7 +96,7 @@ class BaseOwnedSchema(BaseSupersetSchema): @post_load def make_object( - self, data: Dict[str, Any], discard: Optional[List[str]] = None + self, data: dict[str, Any], discard: Optional[list[str]] = None ) -> Model: discard = discard or [] discard.append(self.owners_field_name) @@ -107,13 +108,13 @@ class BaseOwnedSchema(BaseSupersetSchema): return instance @pre_load - def pre_load(self, data: Dict[Any, Any]) -> None: + def pre_load(self, data: dict[Any, Any]) -> None: # if PUT request don't set owners to empty list if not self.instance: data[self.owners_field_name] = data.get(self.owners_field_name, []) @staticmethod - def set_owners(instance: Model, owners: List[int]) -> None: + def set_owners(instance: Model, owners: list[int]) -> None: owner_objs = [] user_id = get_user_id() if user_id and user_id not in owners: diff --git a/superset/views/core.py b/superset/views/core.py index 24bc16c310..3b63eb74d8 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -21,7 +21,7 @@ import logging import re from contextlib import closing from datetime import datetime -from typing import Any, Callable, cast, Dict, List, Optional, Union +from typing import Any, Callable, cast, Optional from urllib import parse import backoff @@ -202,7 +202,7 @@ PARAMETER_MISSING_ERR = __( "your query again." ) -SqlResults = Dict[str, Any] +SqlResults = dict[str, Any] class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @@ -300,13 +300,11 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods datasources.add(datasource) has_access_ = all( - ( - datasource and security_manager.can_access_datasource(datasource) - for datasource in datasources - ) + datasource and security_manager.can_access_datasource(datasource) + for datasource in datasources ) if has_access_: - return redirect("/superset/dashboard/{}".format(dashboard_id)) + return redirect(f"/superset/dashboard/{dashboard_id}") if request.args.get("action") == "go": for datasource in datasources: @@ -483,7 +481,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return data_payload_response(*viz_obj.payload_json_and_has_error(payload)) def generate_json( - self, viz_obj: BaseViz, response_type: Optional[str] = None + self, viz_obj: BaseViz, response_type: str | None = None ) -> FlaskResponse: if response_type == ChartDataResultFormat.CSV: return CsvResponse( @@ -618,7 +616,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @check_resource_permissions(check_datasource_perms) @deprecated(eol_version="3.0") def explore_json( - self, datasource_type: Optional[str] = None, datasource_id: Optional[int] = None + self, datasource_type: str | None = None, datasource_id: int | None = None ) -> FlaskResponse: """Serves all request that GET or POST form_data @@ -631,7 +629,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods TODO: break into one endpoint for each return shape""" response_type = ChartDataResultFormat.JSON.value - responses: List[Union[ChartDataResultFormat, ChartDataResultType]] = list( + responses: list[ChartDataResultFormat | ChartDataResultType] = list( ChartDataResultFormat ) responses.extend(list(ChartDataResultType)) @@ -814,9 +812,9 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # pylint: disable=too-many-locals,too-many-branches,too-many-statements def explore( self, - datasource_type: Optional[str] = None, - datasource_id: Optional[int] = None, - key: Optional[str] = None, + datasource_type: str | None = None, + datasource_id: int | None = None, + key: str | None = None, ) -> FlaskResponse: if request.method == "GET": return redirect(Superset.get_redirect_url()) @@ -879,7 +877,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods # fallback unknown datasource to table type datasource_type = SqlaTable.type - datasource: Optional[BaseDatasource] = None + datasource: BaseDatasource | None = None if datasource_id is not None: try: datasource = DatasourceDAO.get_datasource( @@ -965,7 +963,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) standalone_mode = ReservedUrlParameters.is_standalone_mode() force = request.args.get("force") in {"force", "1", "true"} - dummy_datasource_data: Dict[str, Any] = { + dummy_datasource_data: dict[str, Any] = { "type": datasource_type, "name": datasource_name, "columns": [], @@ -1058,14 +1056,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @staticmethod def save_or_overwrite_slice( # pylint: disable=too-many-arguments,too-many-locals - slc: Optional[Slice], + slc: Slice | None, slice_add_perm: bool, slice_overwrite_perm: bool, slice_download_perm: bool, datasource_id: int, datasource_type: str, datasource_name: str, - query_context: Optional[str] = None, + query_context: str | None = None, ) -> FlaskResponse: """Save or overwrite a slice""" slice_name = request.args.get("slice_name") @@ -1100,7 +1098,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods flash(msg, "success") # Adding slice to a dashboard if requested - dash: Optional[Dashboard] = None + dash: Dashboard | None = None save_to_dashboard_id = request.args.get("save_to_dashboard_id") new_dashboard_name = request.args.get("new_dashboard_name") @@ -1293,7 +1291,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods dash.dashboard_title = data["dashboard_title"] dash.css = data.get("css") - old_to_new_slice_ids: Dict[int, int] = {} + old_to_new_slice_ids: dict[int, int] = {} if data["duplicate_slices"]: # Duplicating slices as well, mapping old ids to new ones for slc in original_dash.slices: @@ -1480,7 +1478,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) @staticmethod - def get_user_activity_access_error(user_id: int) -> Optional[FlaskResponse]: + def get_user_activity_access_error(user_id: int) -> FlaskResponse | None: try: security_manager.raise_for_user_activity_access(user_id) except SupersetSecurityException as ex: @@ -1567,7 +1565,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods if o.Dashboard.created_by: user = o.Dashboard.created_by dash["creator"] = str(user) - dash["creator_url"] = "/superset/profile/{}/".format(user.username) + dash["creator_url"] = f"/superset/profile/{user.username}/" payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1607,7 +1605,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/user_slices", methods=("GET",)) @expose("/user_slices//", methods=("GET",)) @deprecated(new_target="/api/v1/chart/") - def user_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def user_slices(self, user_id: int | None = None) -> FlaskResponse: """List of slices a user owns, created, modified or faved""" if not user_id: user_id = cast(int, get_user_id()) @@ -1660,7 +1658,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/created_slices", methods=("GET",)) @expose("/created_slices//", methods=("GET",)) @deprecated(new_target="api/v1/chart/") - def created_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def created_slices(self, user_id: int | None = None) -> FlaskResponse: """List of slices created by this user""" if not user_id: user_id = cast(int, get_user_id()) @@ -1691,7 +1689,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/fave_slices", methods=("GET",)) @expose("/fave_slices//", methods=("GET",)) @deprecated(new_target="api/v1/chart/") - def fave_slices(self, user_id: Optional[int] = None) -> FlaskResponse: + def fave_slices(self, user_id: int | None = None) -> FlaskResponse: """Favorite slices for a user""" if user_id is None: user_id = cast(int, get_user_id()) @@ -1721,7 +1719,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods if o.Slice.created_by: user = o.Slice.created_by dash["creator"] = str(user) - dash["creator_url"] = "/superset/profile/{}/".format(user.username) + dash["creator_url"] = f"/superset/profile/{user.username}/" payload.append(dash) return json_success(json.dumps(payload, default=utils.json_int_dttm_ser)) @@ -1745,7 +1743,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods table_name = request.args.get("table_name") db_name = request.args.get("db_name") extra_filters = request.args.get("extra_filters") - slices: List[Slice] = [] + slices: list[Slice] = [] if not slice_id and not (table_name and db_name): return json_error_response( @@ -1869,7 +1867,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods self, dashboard_id_or_slug: str, # pylint: disable=unused-argument add_extra_log_payload: Callable[..., None] = lambda **kwargs: None, - dashboard: Optional[Dashboard] = None, + dashboard: Dashboard | None = None, ) -> FlaskResponse: """ Server side rendering for a dashboard @@ -2112,7 +2110,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @event_logger.log_this @deprecated(new_target="api/v1/sqllab/estimate/") def estimate_query_cost( # pylint: disable=no-self-use - self, database_id: int, schema: Optional[str] = None + self, database_id: int, schema: str | None = None ) -> FlaskResponse: mydb = db.session.query(Database).get(database_id) @@ -2135,7 +2133,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return json_error_response(utils.error_msg_from_exception(ex)) spec = mydb.db_engine_spec - query_cost_formatters: Dict[str, Any] = app.config[ + query_cost_formatters: dict[str, Any] = app.config[ "QUERY_COST_FORMATTERS_BY_ENGINE" ] query_cost_formatter = query_cost_formatters.get( @@ -2334,14 +2332,14 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods mydb = session.query(Database).filter_by(id=database_id).one_or_none() if not mydb: return json_error_response( - "Database with id {} is missing.".format(database_id), status=400 + f"Database with id {database_id} is missing.", status=400 ) spec = mydb.db_engine_spec validators_by_engine = app.config["SQL_VALIDATORS_BY_ENGINE"] if not validators_by_engine or spec.engine not in validators_by_engine: return json_error_response( - "no SQL validator is configured for {}".format(spec.engine), status=400 + f"no SQL validator is configured for {spec.engine}", status=400 ) validator_name = validators_by_engine[spec.engine] validator = get_validator_by_name(validator_name) @@ -2403,7 +2401,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @staticmethod def _create_sql_json_command( - execution_context: SqlJsonExecutionContext, log_params: Optional[Dict[str, Any]] + execution_context: SqlJsonExecutionContext, log_params: dict[str, Any] | None ) -> ExecuteSqlCommand: query_dao = QueryDAO() sql_json_executor = Superset._create_sql_json_executor( @@ -2556,7 +2554,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods @expose("/queries/") @expose("/queries/") @deprecated(new_target="api/v1/query/updated_since") - def queries(self, last_updated_ms: Union[float, int]) -> FlaskResponse: + def queries(self, last_updated_ms: float | int) -> FlaskResponse: """ Get the updated queries. @@ -2566,7 +2564,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods return self.queries_exec(last_updated_ms) @staticmethod - def queries_exec(last_updated_ms: Union[float, int]) -> FlaskResponse: + def queries_exec(last_updated_ms: float | int) -> FlaskResponse: stats_logger.incr("queries") if not get_user_id(): return json_error_response( @@ -2714,7 +2712,7 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods ) @staticmethod - def _get_sqllab_tabs(user_id: Optional[int]) -> Dict[str, Any]: + def _get_sqllab_tabs(user_id: int | None) -> dict[str, Any]: # send list of tab state ids tabs_state = ( db.session.query(TabState.id, TabState.label) @@ -2730,13 +2728,13 @@ class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods .first() ) - databases: Dict[int, Any] = {} + databases: dict[int, Any] = {} for database in DatabaseDAO.find_all(): databases[database.id] = { k: v for k, v in database.to_json().items() if k in DATABASE_KEYS } databases[database.id]["backend"] = database.backend - queries: Dict[str, Any] = {} + queries: dict[str, Any] = {} # These are unnecessary if sqllab backend persistence is disabled if is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE"): diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py index 4f12206771..71ef212f6d 100644 --- a/superset/views/dashboard/views.py +++ b/superset/views/dashboard/views.py @@ -14,9 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import builtins import json import re -from typing import Callable, List, Union +from typing import Callable, Union from flask import g, redirect, request, Response from flask_appbuilder import expose @@ -64,12 +65,13 @@ class DashboardModelView( @action("mulexport", __("Export"), __("Export dashboards?"), "fa-database") def mulexport( # pylint: disable=no-self-use - self, items: Union["DashboardModelView", List["DashboardModelView"]] + self, + items: Union["DashboardModelView", builtins.list["DashboardModelView"]], ) -> FlaskResponse: if not isinstance(items, list): items = [items] - ids = "".join("&id={}".format(d.id) for d in items) - return redirect("/dashboard/export_dashboards_form?{}".format(ids[1:])) + ids = "".join(f"&id={d.id}" for d in items) + return redirect(f"/dashboard/export_dashboards_form?{ids[1:]}") @event_logger.log_this @has_access diff --git a/superset/views/database/forms.py b/superset/views/database/forms.py index 5e2347528a..b906e5e70b 100644 --- a/superset/views/database/forms.py +++ b/superset/views/database/forms.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Contains the logic to create cohesive forms on the explore view""" -from typing import List from flask_appbuilder.fields import QuerySelectField from flask_appbuilder.fieldwidgets import BS3TextFieldWidget @@ -44,7 +43,7 @@ config = app.config class UploadToDatabaseForm(DynamicForm): @staticmethod - def file_allowed_dbs() -> List[Database]: + def file_allowed_dbs() -> list[Database]: file_enabled_dbs = ( db.session.query(Database).filter_by(allow_file_upload=True).all() ) diff --git a/superset/views/database/mixins.py b/superset/views/database/mixins.py index efd0b6c6eb..deb1b88f1f 100644 --- a/superset/views/database/mixins.py +++ b/superset/views/database/mixins.py @@ -227,7 +227,7 @@ class DatabaseMixin: Markup( "Cannot delete a database that has tables attached. " "Here's the list of associated tables: " - + ", ".join("{}".format(table) for table in database.tables) + + ", ".join(f"{table}" for table in database.tables) ) ) diff --git a/superset/views/database/validators.py b/superset/views/database/validators.py index 29d80611a2..2ee49c8210 100644 --- a/superset/views/database/validators.py +++ b/superset/views/database/validators.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Optional, Type +from typing import Optional from flask_babel import lazy_gettext as _ from marshmallow import ValidationError @@ -27,7 +27,7 @@ from superset.models.core import Database def sqlalchemy_uri_validator( - uri: str, exception: Type[ValidationError] = ValidationError + uri: str, exception: type[ValidationError] = ValidationError ) -> None: """ Check if a user has submitted a valid SQLAlchemy URI diff --git a/superset/views/datasource/schemas.py b/superset/views/datasource/schemas.py index b71b3defa8..5b1700708a 100644 --- a/superset/views/datasource/schemas.py +++ b/superset/views/datasource/schemas.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from marshmallow import fields, post_load, pre_load, Schema, validate from typing_extensions import TypedDict @@ -76,7 +76,7 @@ class SamplesPayloadSchema(Schema): @pre_load # pylint: disable=no-self-use, unused-argument - def handle_none(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: + def handle_none(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]: if data is None: return {} return data diff --git a/superset/views/datasource/utils.py b/superset/views/datasource/utils.py index 42cddf4167..a4cf0c5e90 100644 --- a/superset/views/datasource/utils.py +++ b/superset/views/datasource/utils.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from superset import app, db from superset.common.chart_data import ChartDataResultType @@ -27,7 +27,7 @@ from superset.utils.core import QueryStatus from superset.views.datasource.schemas import SamplesPayloadSchema -def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> Dict[str, int]: +def get_limit_clause(page: Optional[int], per_page: Optional[int]) -> dict[str, int]: samples_row_limit = app.config.get("SAMPLES_ROW_LIMIT", 1000) limit = samples_row_limit offset = 0 @@ -50,7 +50,7 @@ def get_samples( # pylint: disable=too-many-arguments,too-many-locals page: int = 1, per_page: int = 1000, payload: Optional[SamplesPayloadSchema] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: datasource = DatasourceDAO.get_datasource( session=db.session, datasource_type=datasource_type, diff --git a/superset/views/log/dao.py b/superset/views/log/dao.py index 71d8a62348..87bc0817da 100644 --- a/superset/views/log/dao.py +++ b/superset/views/log/dao.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime, timedelta -from typing import Any, Dict, List +from typing import Any import humanize from sqlalchemy import and_, or_ @@ -34,8 +34,8 @@ class LogDAO(BaseDAO): @staticmethod def get_recent_activity( - user_id: int, actions: List[str], distinct: bool, page: int, page_size: int - ) -> List[Dict[str, Any]]: + user_id: int, actions: list[str], distinct: bool, page: int, page_size: int + ) -> list[dict[str, Any]]: has_subject_title = or_( and_( Dashboard.dashboard_title is not None, diff --git a/superset/views/tags.py b/superset/views/tags.py index bd4f43a0d9..4f9d55aed7 100644 --- a/superset/views/tags.py +++ b/superset/views/tags.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from __future__ import absolute_import, division, print_function, unicode_literals import logging diff --git a/superset/views/users/__init__.py b/superset/views/users/__init__.py index fd9417fe5c..13a83393a9 100644 --- a/superset/views/users/__init__.py +++ b/superset/views/users/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/superset/views/utils.py b/superset/views/utils.py index a366ac683c..9b515edc26 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -17,7 +17,7 @@ import logging from collections import defaultdict from functools import wraps -from typing import Any, Callable, DefaultDict, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, DefaultDict, Optional, Union from urllib import parse import msgpack @@ -55,12 +55,12 @@ from superset.viz import BaseViz logger = logging.getLogger(__name__) stats_logger = app.config["STATS_LOGGER"] -REJECTED_FORM_DATA_KEYS: List[str] = [] +REJECTED_FORM_DATA_KEYS: list[str] = [] if not feature_flag_manager.is_feature_enabled("ENABLE_JAVASCRIPT_CONTROLS"): REJECTED_FORM_DATA_KEYS = ["js_tooltip", "js_onclick_href", "js_data_mutator"] -def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]: +def sanitize_datasource_data(datasource_data: dict[str, Any]) -> dict[str, Any]: if datasource_data: datasource_database = datasource_data.get("database") if datasource_database: @@ -69,7 +69,7 @@ def sanitize_datasource_data(datasource_data: Dict[str, Any]) -> Dict[str, Any]: return datasource_data -def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, Any]: +def bootstrap_user_data(user: User, include_perms: bool = False) -> dict[str, Any]: if user.is_anonymous: payload = {} user.roles = (security_manager.find_role("Public"),) @@ -103,7 +103,7 @@ def bootstrap_user_data(user: User, include_perms: bool = False) -> Dict[str, An def get_permissions( user: User, -) -> Tuple[Dict[str, List[Tuple[str]]], DefaultDict[str, List[str]]]: +) -> tuple[dict[str, list[tuple[str]]], DefaultDict[str, list[str]]]: if not user.roles: raise AttributeError("User object does not have roles") @@ -138,7 +138,7 @@ def get_viz( return viz_obj -def loads_request_json(request_json_data: str) -> Dict[Any, Any]: +def loads_request_json(request_json_data: str) -> dict[Any, Any]: try: return json.loads(request_json_data) except (TypeError, json.JSONDecodeError): @@ -148,9 +148,9 @@ def loads_request_json(request_json_data: str) -> Dict[Any, Any]: def get_form_data( # pylint: disable=too-many-locals slice_id: Optional[int] = None, use_slice_data: bool = False, - initial_form_data: Optional[Dict[str, Any]] = None, -) -> Tuple[Dict[str, Any], Optional[Slice]]: - form_data: Dict[str, Any] = initial_form_data or {} + initial_form_data: Optional[dict[str, Any]] = None, +) -> tuple[dict[str, Any], Optional[Slice]]: + form_data: dict[str, Any] = initial_form_data or {} if has_request_context(): # chart data API requests are JSON @@ -222,7 +222,7 @@ def get_form_data( # pylint: disable=too-many-locals return form_data, slc -def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any: +def add_sqllab_custom_filters(form_data: dict[Any, Any]) -> Any: """ SQLLab can include a "filters" attribute in the templateParams. The filters attribute is a list of filters to include in the @@ -244,7 +244,7 @@ def add_sqllab_custom_filters(form_data: Dict[Any, Any]) -> Any: def get_datasource_info( datasource_id: Optional[int], datasource_type: Optional[str], form_data: FormData -) -> Tuple[int, Optional[str]]: +) -> tuple[int, Optional[str]]: """ Compatibility layer for handling of datasource info @@ -277,8 +277,8 @@ def get_datasource_info( def apply_display_max_row_limit( - sql_results: Dict[str, Any], rows: Optional[int] = None -) -> Dict[str, Any]: + sql_results: dict[str, Any], rows: Optional[int] = None +) -> dict[str, Any]: """ Given a `sql_results` nested structure, applies a limit to the number of rows @@ -311,7 +311,7 @@ CONTAINER_TYPES = ["COLUMN", "GRID", "TABS", "TAB", "ROW"] def get_dashboard_extra_filters( slice_id: int, dashboard_id: int -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: session = db.session() dashboard = session.query(Dashboard).filter_by(id=dashboard_id).one_or_none() @@ -348,11 +348,11 @@ def get_dashboard_extra_filters( def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-blocks - layout: Dict[str, Dict[str, Any]], - filter_scopes: Dict[str, Dict[str, Any]], - default_filters: Dict[str, Dict[str, List[Any]]], + layout: dict[str, dict[str, Any]], + filter_scopes: dict[str, dict[str, Any]], + default_filters: dict[str, dict[str, list[Any]]], slice_id: int, -) -> List[Dict[str, Any]]: +) -> list[dict[str, Any]]: extra_filters = [] # do not apply filters if chart is not in filter's scope or chart is immune to the @@ -360,7 +360,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc for filter_id, columns in default_filters.items(): filter_slice = db.session.query(Slice).filter_by(id=filter_id).one_or_none() - filter_configs: List[Dict[str, Any]] = [] + filter_configs: list[dict[str, Any]] = [] if filter_slice: filter_configs = ( json.loads(filter_slice.params or "{}").get("filter_configs") or [] @@ -403,7 +403,7 @@ def build_extra_filters( # pylint: disable=too-many-locals,too-many-nested-bloc def is_slice_in_container( - layout: Dict[str, Dict[str, Any]], container_id: str, slice_id: int + layout: dict[str, dict[str, Any]], container_id: str, slice_id: int ) -> bool: if container_id == "ROOT_ID": return True @@ -551,7 +551,7 @@ def check_slice_perms(_self: Any, slice_id: int) -> None: def _deserialize_results_payload( payload: Union[bytes, str], query: Query, use_msgpack: Optional[bool] = False -) -> Dict[str, Any]: +) -> dict[str, Any]: logger.debug("Deserializing from msgpack: %r", use_msgpack) if use_msgpack: with stats_timing( diff --git a/superset/viz.py b/superset/viz.py index a7b4a8952a..3bb6204524 100644 --- a/superset/viz.py +++ b/superset/viz.py @@ -30,19 +30,7 @@ import re from collections import defaultdict, OrderedDict from datetime import date, datetime, timedelta from itertools import product -from typing import ( - Any, - Callable, - cast, - Dict, - List, - Optional, - Set, - Tuple, - Type, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, cast, Optional, TYPE_CHECKING import geohash import numpy as np @@ -124,7 +112,7 @@ class BaseViz: # pylint: disable=too-many-public-methods """All visualizations derive this base class""" - viz_type: Optional[str] = None + viz_type: str | None = None verbose_name = "Base Viz" credits = "" is_timeseries = False @@ -134,8 +122,8 @@ class BaseViz: # pylint: disable=too-many-public-methods @deprecated(deprecated_in="3.0") def __init__( self, - datasource: "BaseDatasource", - form_data: Dict[str, Any], + datasource: BaseDatasource, + form_data: dict[str, Any], force: bool = False, force_cached: bool = False, ) -> None: @@ -150,25 +138,25 @@ class BaseViz: # pylint: disable=too-many-public-methods self.query = "" self.token = utils.get_form_data_token(form_data) - self.groupby: List[Column] = self.form_data.get("groupby") or [] + self.groupby: list[Column] = self.form_data.get("groupby") or [] self.time_shift = timedelta() - self.status: Optional[str] = None + self.status: str | None = None self.error_msg = "" - self.results: Optional[QueryResult] = None - self.applied_filter_columns: List[Column] = [] - self.rejected_filter_columns: List[Column] = [] - self.errors: List[Dict[str, Any]] = [] + self.results: QueryResult | None = None + self.applied_filter_columns: list[Column] = [] + self.rejected_filter_columns: list[Column] = [] + self.errors: list[dict[str, Any]] = [] self.force = force self._force_cached = force_cached - self.from_dttm: Optional[datetime] = None - self.to_dttm: Optional[datetime] = None - self._extra_chart_data: List[Tuple[str, pd.DataFrame]] = [] + self.from_dttm: datetime | None = None + self.to_dttm: datetime | None = None + self._extra_chart_data: list[tuple[str, pd.DataFrame]] = [] self.process_metrics() - self.applied_filters: List[Dict[str, str]] = [] - self.rejected_filters: List[Dict[str, str]] = [] + self.applied_filters: list[dict[str, str]] = [] + self.rejected_filters: list[dict[str, str]] = [] @property @deprecated(deprecated_in="3.0") @@ -196,8 +184,8 @@ class BaseViz: # pylint: disable=too-many-public-methods @staticmethod @deprecated(deprecated_in="3.0") def handle_js_int_overflow( - data: Dict[str, List[Dict[str, Any]]] - ) -> Dict[str, List[Dict[str, Any]]]: + data: dict[str, list[dict[str, Any]]] + ) -> dict[str, list[dict[str, Any]]]: for record in data.get("records", {}): for k, v in list(record.items()): if isinstance(v, int): @@ -259,7 +247,7 @@ class BaseViz: # pylint: disable=too-many-public-methods return df @deprecated(deprecated_in="3.0") - def get_samples(self) -> Dict[str, Any]: + def get_samples(self) -> dict[str, Any]: query_obj = self.query_obj() query_obj.update( { @@ -281,7 +269,7 @@ class BaseViz: # pylint: disable=too-many-public-methods } @deprecated(deprecated_in="3.0") - def get_df(self, query_obj: Optional[QueryObjectDict] = None) -> pd.DataFrame: + def get_df(self, query_obj: QueryObjectDict | None = None) -> pd.DataFrame: """Returns a pandas dataframe based on the query object""" if not query_obj: query_obj = self.query_obj() @@ -346,10 +334,10 @@ class BaseViz: # pylint: disable=too-many-public-methods @staticmethod @deprecated(deprecated_in="3.0") - def dedup_columns(*columns_args: Optional[List[Column]]) -> List[Column]: + def dedup_columns(*columns_args: list[Column] | None) -> list[Column]: # dedup groupby and columns while preserving order - labels: List[str] = [] - deduped_columns: List[Column] = [] + labels: list[str] = [] + deduped_columns: list[Column] = [] for columns in columns_args: for column in columns or []: label = get_column_name(column) @@ -492,7 +480,7 @@ class BaseViz: # pylint: disable=too-many-public-methods return md5_sha_from_str(json_data) @deprecated(deprecated_in="3.0") - def get_payload(self, query_obj: Optional[QueryObjectDict] = None) -> VizPayload: + def get_payload(self, query_obj: QueryObjectDict | None = None) -> VizPayload: """Returns a payload of metadata and data""" try: @@ -534,8 +522,8 @@ class BaseViz: # pylint: disable=too-many-public-methods @deprecated(deprecated_in="3.0") def get_df_payload( # pylint: disable=too-many-statements - self, query_obj: Optional[QueryObjectDict] = None, **kwargs: Any - ) -> Dict[str, Any]: + self, query_obj: QueryObjectDict | None = None, **kwargs: Any + ) -> dict[str, Any]: """Handles caching around the df payload retrieval""" if not query_obj: query_obj = self.query_obj() @@ -587,7 +575,7 @@ class BaseViz: # pylint: disable=too-many-public-methods ) + get_column_names_from_columns(query_obj.get("groupby") or []) + utils.get_column_names_from_metrics( - cast(List[Metric], query_obj.get("metrics") or []) + cast(list[Metric], query_obj.get("metrics") or []) ) if col not in self.datasource.column_names ] @@ -676,12 +664,12 @@ class BaseViz: # pylint: disable=too-many-public-methods ) @deprecated(deprecated_in="3.0") - def payload_json_and_has_error(self, payload: VizPayload) -> Tuple[str, bool]: + def payload_json_and_has_error(self, payload: VizPayload) -> tuple[str, bool]: return self.json_dumps(payload), self.has_error(payload) @property @deprecated(deprecated_in="3.0") - def data(self) -> Dict[str, Any]: + def data(self) -> dict[str, Any]: """This is the data object serialized to the js layer""" content = { "form_data": self.form_data, @@ -692,7 +680,7 @@ class BaseViz: # pylint: disable=too-many-public-methods return content @deprecated(deprecated_in="3.0") - def get_csv(self) -> Optional[str]: + def get_csv(self) -> str | None: df = self.get_df_payload()["df"] # leverage caching logic include_index = not isinstance(df.index, pd.RangeIndex) return csv.df_to_escaped_csv(df, index=include_index, **config["CSV_EXPORT"]) @@ -766,8 +754,8 @@ class TableViz(BaseViz): else QueryMode.AGGREGATE ) - columns: List[str] # output columns sans time and percent_metric column - percent_columns: List[str] = [] # percent columns that needs extra computation + columns: list[str] # output columns sans time and percent_metric column + percent_columns: list[str] = [] # percent columns that needs extra computation if self.query_mode == QueryMode.RAW: columns = get_metric_names(self.form_data.get("all_columns")) @@ -906,7 +894,7 @@ class TimeTableViz(BaseViz): return None columns = None - values: Union[List[str], str] = self.metric_labels + values: list[str] | str = self.metric_labels if self.form_data.get("groupby"): values = self.metric_labels[0] columns = get_column_names(self.form_data.get("groupby")) @@ -948,10 +936,8 @@ class PivotTableViz(BaseViz): if transpose and not columns: raise QueryObjectValidationError( _( - ( - "Please choose at least one 'Columns' field when " - "select 'Transpose Pivot' option" - ) + "Please choose at least one 'Columns' field when " + "select 'Transpose Pivot' option" ) ) if not metrics: @@ -973,8 +959,8 @@ class PivotTableViz(BaseViz): @staticmethod @deprecated(deprecated_in="3.0") def get_aggfunc( - metric: str, df: pd.DataFrame, form_data: Dict[str, Any] - ) -> Union[str, Callable[[Any], Any]]: + metric: str, df: pd.DataFrame, form_data: dict[str, Any] + ) -> str | Callable[[Any], Any]: aggfunc = form_data.get("pandas_aggfunc") or "sum" if pd.api.types.is_numeric_dtype(df[metric]): # Ensure that Pandas's sum function mimics that of SQL. @@ -985,7 +971,7 @@ class PivotTableViz(BaseViz): @staticmethod @deprecated(deprecated_in="3.0") - def _format_datetime(value: Union[pd.Timestamp, datetime, date, str]) -> str: + def _format_datetime(value: pd.Timestamp | datetime | date | str) -> str: """ Format a timestamp in such a way that the viz will be able to apply the correct formatting in the frontend. @@ -994,7 +980,7 @@ class PivotTableViz(BaseViz): :return: formatted timestamp if it is a valid timestamp, otherwise the original value """ - tstamp: Optional[pd.Timestamp] = None + tstamp: pd.Timestamp | None = None if isinstance(value, pd.Timestamp): tstamp = value if isinstance(value, (date, datetime)): @@ -1018,7 +1004,7 @@ class PivotTableViz(BaseViz): del df[DTTM_ALIAS] metrics = [utils.get_metric_name(m) for m in self.form_data["metrics"]] - aggfuncs: Dict[str, Union[str, Callable[[Any], Any]]] = {} + aggfuncs: dict[str, str | Callable[[Any], Any]] = {} for metric in metrics: aggfuncs[metric] = self.get_aggfunc(metric, df, self.form_data) @@ -1088,7 +1074,7 @@ class TreemapViz(BaseViz): return query_obj @deprecated(deprecated_in="3.0") - def _nest(self, metric: str, df: pd.DataFrame) -> List[Dict[str, Any]]: + def _nest(self, metric: str, df: pd.DataFrame) -> list[dict[str, Any]]: nlevels = df.index.nlevels if nlevels == 1: result = [{"name": n, "value": v} for n, v in zip(df.index, df[metric])] @@ -1200,7 +1186,7 @@ class NVD3Viz(BaseViz): """Base class for all nvd3 vizs""" credits = 'NVD3.org' - viz_type: Optional[str] = None + viz_type: str | None = None verbose_name = "Base NVD3 Viz" is_timeseries = False @@ -1249,7 +1235,7 @@ class BubbleViz(NVD3Viz): df["shape"] = "circle" df["group"] = df[[get_column_name(self.series)]] # type: ignore - series: Dict[Any, List[Any]] = defaultdict(list) + series: dict[Any, list[Any]] = defaultdict(list) for row in df.to_dict(orient="records"): series[row["group"]].append(row) chart_data = [] @@ -1357,7 +1343,7 @@ class NVD3TimeSeriesViz(NVD3Viz): verbose_name = _("Time Series - Line Chart") sort_series = False is_timeseries = True - pivot_fill_value: Optional[int] = None + pivot_fill_value: int | None = None @deprecated(deprecated_in="3.0") def query_obj(self) -> QueryObjectDict: @@ -1376,7 +1362,7 @@ class NVD3TimeSeriesViz(NVD3Viz): @deprecated(deprecated_in="3.0") def to_series( # pylint: disable=too-many-branches self, df: pd.DataFrame, classed: str = "", title_suffix: str = "" - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1393,7 +1379,7 @@ class NVD3TimeSeriesViz(NVD3Viz): ys = series[name] if df[name].dtype.kind not in "biufc": continue - series_title: Union[List[str], str, Tuple[str, ...]] + series_title: list[str] | str | tuple[str, ...] if isinstance(name, list): series_title = [str(title) for title in name] elif isinstance(name, tuple): @@ -1510,7 +1496,7 @@ class NVD3TimeSeriesViz(NVD3Viz): dttm_series = df2[DTTM_ALIAS] + delta df2 = df2.drop(DTTM_ALIAS, axis=1) df2 = pd.concat([dttm_series, df2], axis=1) - label = "{} offset".format(option) + label = f"{option} offset" df2 = self.process_data(df2) self._extra_chart_data.append((label, df2)) @@ -1524,9 +1510,7 @@ class NVD3TimeSeriesViz(NVD3Viz): for i, (label, df2) in enumerate(self._extra_chart_data): chart_data.extend( - self.to_series( - df2, classed="time-shift-{}".format(i), title_suffix=label - ) + self.to_series(df2, classed=f"time-shift-{i}", title_suffix=label) ) else: chart_data = [] @@ -1547,16 +1531,14 @@ class NVD3TimeSeriesViz(NVD3Viz): diff = df / df2 else: raise QueryObjectValidationError( - "Invalid `comparison_type`: {0}".format(comparison_type) + f"Invalid `comparison_type`: {comparison_type}" ) # remove leading/trailing NaNs from the time shift difference diff = diff[diff.first_valid_index() : diff.last_valid_index()] chart_data.extend( - self.to_series( - diff, classed="time-shift-{}".format(i), title_suffix=label - ) + self.to_series(diff, classed=f"time-shift-{i}", title_suffix=label) ) if not self.sort_series: @@ -1670,7 +1652,7 @@ class NVD3DualLineViz(NVD3Viz): return query_obj @deprecated(deprecated_in="3.0") - def to_series(self, df: pd.DataFrame, classed: str = "") -> List[Dict[str, Any]]: + def to_series(self, df: pd.DataFrame, classed: str = "") -> list[dict[str, Any]]: cols = [] for col in df.columns: if col == "": @@ -1823,7 +1805,7 @@ class HistogramViz(BaseViz): return query_obj @deprecated(deprecated_in="3.0") - def labelify(self, keys: Union[List[str], str], column: str) -> str: + def labelify(self, keys: list[str] | str, column: str) -> str: if isinstance(keys, str): keys = [keys] # removing undesirable characters @@ -2033,17 +2015,17 @@ class SankeyViz(BaseViz): df["target"] = df["target"].astype(str) recs = df.to_dict(orient="records") - hierarchy: Dict[str, Set[str]] = defaultdict(set) + hierarchy: dict[str, set[str]] = defaultdict(set) for row in recs: hierarchy[row["source"]].add(row["target"]) @deprecated(deprecated_in="3.0") - def find_cycle(graph: Dict[str, Set[str]]) -> Optional[Tuple[str, str]]: + def find_cycle(graph: dict[str, set[str]]) -> tuple[str, str] | None: """Whether there's a cycle in a directed graph""" path = set() @deprecated(deprecated_in="3.0") - def visit(vertex: str) -> Optional[Tuple[str, str]]: + def visit(vertex: str) -> tuple[str, str] | None: path.add(vertex) for neighbour in graph.get(vertex, ()): if neighbour in path or visit(neighbour): @@ -2214,7 +2196,7 @@ class FilterBoxViz(BaseViz): """A multi filter, multi-choice filter box to make dashboards interactive""" - query_context_factory: Optional[QueryContextFactory] = None + query_context_factory: QueryContextFactory | None = None viz_type = "filter_box" verbose_name = _("Filters") is_timeseries = False @@ -2581,20 +2563,20 @@ class BaseDeckGLViz(BaseViz): is_timeseries = False credits = 'deck.gl' - spatial_control_keys: List[str] = [] + spatial_control_keys: list[str] = [] @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: # pylint: disable=attribute-defined-outside-init self.metric = self.form_data.get("size") return [self.metric] if self.metric else [] @deprecated(deprecated_in="3.0") - def process_spatial_query_obj(self, key: str, group_by: List[str]) -> None: + def process_spatial_query_obj(self, key: str, group_by: list[str]) -> None: group_by.extend(self.get_spatial_columns(key)) @deprecated(deprecated_in="3.0") - def get_spatial_columns(self, key: str) -> List[str]: + def get_spatial_columns(self, key: str) -> list[str]: spatial = self.form_data.get(key) if spatial is None: raise ValueError(_("Bad spatial key")) @@ -2611,7 +2593,7 @@ class BaseDeckGLViz(BaseViz): @staticmethod @deprecated(deprecated_in="3.0") - def parse_coordinates(latlog: Any) -> Optional[Tuple[float, float]]: + def parse_coordinates(latlog: Any) -> tuple[float, float] | None: if not latlog: return None try: @@ -2624,7 +2606,7 @@ class BaseDeckGLViz(BaseViz): @staticmethod @deprecated(deprecated_in="3.0") - def reverse_geohash_decode(geohash_code: str) -> Tuple[str, str]: + def reverse_geohash_decode(geohash_code: str) -> tuple[str, str]: lat, lng = geohash.decode(geohash_code) return (lng, lat) @@ -2692,7 +2674,7 @@ class BaseDeckGLViz(BaseViz): self.add_null_filters() query_obj = super().query_obj() - group_by: List[str] = [] + group_by: list[str] = [] for key in self.spatial_control_keys: self.process_spatial_query_obj(key, group_by) @@ -2720,7 +2702,7 @@ class BaseDeckGLViz(BaseViz): return query_obj @deprecated(deprecated_in="3.0") - def get_js_columns(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_js_columns(self, data: dict[str, Any]) -> dict[str, Any]: cols = self.form_data.get("js_columns") or [] return {col: data.get(col) for col in cols} @@ -2748,7 +2730,7 @@ class BaseDeckGLViz(BaseViz): } @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: raise NotImplementedError() @@ -2774,7 +2756,7 @@ class DeckScatterViz(BaseDeckGLViz): return super().query_obj() @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: # pylint: disable=attribute-defined-outside-init self.metric = None if self.point_radius_fixed.get("type") == "metric": @@ -2783,7 +2765,7 @@ class DeckScatterViz(BaseDeckGLViz): return [] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "metric": data.get(self.metric_label) if self.metric_label else None, "radius": self.fixed_value @@ -2825,7 +2807,7 @@ class DeckScreengrid(BaseDeckGLViz): return super().query_obj() @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2849,7 +2831,7 @@ class DeckGrid(BaseDeckGLViz): spatial_control_keys = ["spatial"] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2864,7 +2846,7 @@ class DeckGrid(BaseDeckGLViz): @deprecated(deprecated_in="3.0") -def geohash_to_json(geohash_code: str) -> List[List[float]]: +def geohash_to_json(geohash_code: str) -> list[list[float]]: bbox = geohash.bbox(geohash_code) return [ [bbox.get("w"), bbox.get("n")], @@ -2907,7 +2889,7 @@ class DeckPathViz(BaseDeckGLViz): return query_obj @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: line_type = self.form_data["line_type"] deser = self.deser_map[line_type] line_column = self.form_data["line_column"] @@ -2946,14 +2928,14 @@ class DeckPolygon(DeckPathViz): return super().query_obj() @deprecated(deprecated_in="3.0") - def get_metrics(self) -> List[str]: + def get_metrics(self) -> list[str]: metrics = [self.form_data.get("metric")] if self.elevation.get("type") == "metric": metrics.append(self.elevation.get("value")) return [metric for metric in metrics if metric] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: super().get_properties(data) elevation = self.form_data["point_radius_fixed"]["value"] type_ = self.form_data["point_radius_fixed"]["type"] @@ -2974,7 +2956,7 @@ class DeckHex(BaseDeckGLViz): spatial_control_keys = ["spatial"] @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -2996,7 +2978,7 @@ class DeckHeatmap(BaseDeckGLViz): verbose_name = _("Deck.gl - Heatmap") spatial_control_keys = ["spatial"] - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: return { "position": data.get("spatial"), "weight": (data.get(self.metric_label) if self.metric_label else None) or 1, @@ -3025,7 +3007,7 @@ class DeckGeoJson(BaseDeckGLViz): return query_obj @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: geojson = data[get_column_name(self.form_data["geojson"])] return json.loads(geojson) @@ -3047,7 +3029,7 @@ class DeckArc(BaseDeckGLViz): return super().query_obj() @deprecated(deprecated_in="3.0") - def get_properties(self, data: Dict[str, Any]) -> Dict[str, Any]: + def get_properties(self, data: dict[str, Any]) -> dict[str, Any]: dim = self.form_data.get("dimension") return { "sourcePosition": data.get("start_spatial"), @@ -3153,7 +3135,7 @@ class PairedTTestViz(BaseViz): else: cols.append(col) df.columns = cols - data: Dict[str, List[Dict[str, Any]]] = {} + data: dict[str, list[dict[str, Any]]] = {} series = df.to_dict("series") for name_set in df.columns: # If no groups are defined, nameSet will be the metric name @@ -3188,7 +3170,7 @@ class RoseViz(NVD3TimeSeriesViz): return None data = super().get_data(df) - result: Dict[str, List[Dict[str, str]]] = {} + result: dict[str, list[dict[str, str]]] = {} for datum in data: key = datum["key"] for val in datum["values"]: @@ -3227,8 +3209,8 @@ class PartitionViz(NVD3TimeSeriesViz): @staticmethod @deprecated(deprecated_in="3.0") def levels_for( - time_op: str, groups: List[str], df: pd.DataFrame - ) -> Dict[int, pd.Series]: + time_op: str, groups: list[str], df: pd.DataFrame + ) -> dict[int, pd.Series]: """ Compute the partition at each `level` from the dataframe. """ @@ -3245,8 +3227,8 @@ class PartitionViz(NVD3TimeSeriesViz): @staticmethod @deprecated(deprecated_in="3.0") def levels_for_diff( - time_op: str, groups: List[str], df: pd.DataFrame - ) -> Dict[int, pd.DataFrame]: + time_op: str, groups: list[str], df: pd.DataFrame + ) -> dict[int, pd.DataFrame]: # Obtain a unique list of the time grains times = list(set(df[DTTM_ALIAS])) times.sort() @@ -3282,8 +3264,8 @@ class PartitionViz(NVD3TimeSeriesViz): @deprecated(deprecated_in="3.0") def levels_for_time( - self, groups: List[str], df: pd.DataFrame - ) -> Dict[int, VizData]: + self, groups: list[str], df: pd.DataFrame + ) -> dict[int, VizData]: procs = {} for i in range(0, len(groups) + 1): self.form_data["groupby"] = groups[:i] @@ -3295,11 +3277,11 @@ class PartitionViz(NVD3TimeSeriesViz): @deprecated(deprecated_in="3.0") def nest_values( self, - levels: Dict[int, pd.DataFrame], + levels: dict[int, pd.DataFrame], level: int = 0, - metric: Optional[str] = None, - dims: Optional[List[str]] = None, - ) -> List[Dict[str, Any]]: + metric: str | None = None, + dims: list[str] | None = None, + ) -> list[dict[str, Any]]: """ Nest values at each level on the back-end with access and setting, instead of summing from the bottom. @@ -3340,11 +3322,11 @@ class PartitionViz(NVD3TimeSeriesViz): @deprecated(deprecated_in="3.0") def nest_procs( self, - procs: Dict[int, pd.DataFrame], + procs: dict[int, pd.DataFrame], level: int = -1, - dims: Optional[Tuple[str, ...]] = None, + dims: tuple[str, ...] | None = None, time: Any = None, - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: if dims is None: dims = () if level == -1: @@ -3395,7 +3377,7 @@ class PartitionViz(NVD3TimeSeriesViz): @deprecated(deprecated_in="3.0") -def get_subclasses(cls: Type[BaseViz]) -> Set[Type[BaseViz]]: +def get_subclasses(cls: type[BaseViz]) -> set[type[BaseViz]]: return set(cls.__subclasses__()).union( [sc for c in cls.__subclasses__() for sc in get_subclasses(c)] ) diff --git a/tests/common/logger_utils.py b/tests/common/logger_utils.py index 98471342b7..8cb443cac8 100644 --- a/tests/common/logger_utils.py +++ b/tests/common/logger_utils.py @@ -29,7 +29,7 @@ from inspect import ( Signature, ) from logging import Logger -from typing import Any, Callable, cast, Optional, Type, Union +from typing import Any, Callable, cast, Union _DEFAULT_ENTER_MSG_PREFIX = "enter to " _DEFAULT_ENTER_MSG_SUFFIX = "" @@ -48,11 +48,11 @@ empty_and_none = {Signature.empty, "None"} Function = Callable[..., Any] -Decorated = Union[Type[Any], Function] +Decorated = Union[type[Any], Function] def log( - decorated: Optional[Decorated] = None, + decorated: Decorated | None = None, *, prefix_enter_msg: str = _DEFAULT_ENTER_MSG_PREFIX, suffix_enter_msg: str = _DEFAULT_ENTER_MSG_SUFFIX, @@ -85,11 +85,11 @@ def _make_decorator( def decorator(decorated: Decorated): decorated_logger = _get_logger(decorated) - def decorator_class(clazz: Type[Any]) -> Type[Any]: + def decorator_class(clazz: type[Any]) -> type[Any]: _decorate_class_members_with_logs(clazz) return clazz - def _decorate_class_members_with_logs(clazz: Type[Any]) -> None: + def _decorate_class_members_with_logs(clazz: type[Any]) -> None: members = getmembers( clazz, predicate=lambda val: ismethod(val) or isfunction(val) ) @@ -160,7 +160,7 @@ def _make_decorator( return _wrapper_func if isclass(decorated): - return decorator_class(cast(Type[Any], decorated)) + return decorator_class(cast(type[Any], decorated)) return decorator_func(cast(Function, decorated)) return decorator diff --git a/tests/common/query_context_generator.py b/tests/common/query_context_generator.py index 15b013dc84..32b4063974 100644 --- a/tests/common/query_context_generator.py +++ b/tests/common/query_context_generator.py @@ -16,7 +16,7 @@ # under the License. import copy import dataclasses -from typing import Any, Dict, List, Optional +from typing import Any, Optional from superset.common.chart_data import ChartDataResultType from superset.utils.core import AnnotationType, DTTM_ALIAS @@ -42,7 +42,7 @@ query_birth_names = { "where": "", } -QUERY_OBJECTS: Dict[str, Dict[str, object]] = { +QUERY_OBJECTS: dict[str, dict[str, object]] = { "birth_names": query_birth_names, # `:suffix` are overrides only "birth_names:include_time": { @@ -205,7 +205,7 @@ def get_query_object( query_name: str, add_postprocessing_operations: bool, add_time_offsets: bool, -) -> Dict[str, Any]: +) -> dict[str, Any]: if query_name not in QUERY_OBJECTS: raise Exception(f"QueryObject fixture not defined for datasource: {query_name}") obj = QUERY_OBJECTS[query_name] @@ -227,7 +227,7 @@ def get_query_object( return query_object -def _get_postprocessing_operation(query_name: str) -> List[Dict[str, Any]]: +def _get_postprocessing_operation(query_name: str) -> list[dict[str, Any]]: if query_name not in QUERY_OBJECTS: raise Exception( f"Post-processing fixture not defined for datasource: {query_name}" @@ -250,8 +250,8 @@ class QueryContextGenerator: add_time_offsets: bool = False, table_id=1, table_type="table", - form_data: Optional[Dict[str, Any]] = None, - ) -> Dict[str, Any]: + form_data: Optional[dict[str, Any]] = None, + ) -> dict[str, Any]: form_data = form_data or {} table_name = query_name.split(":")[0] table = self.get_table(table_name, table_id, table_type) diff --git a/tests/example_data/data_generator/base_generator.py b/tests/example_data/data_generator/base_generator.py index 023b929091..38ab2e5413 100644 --- a/tests/example_data/data_generator/base_generator.py +++ b/tests/example_data/data_generator/base_generator.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable +from collections.abc import Iterable +from typing import Any class ExampleDataGenerator(ABC): @abstractmethod - def generate(self) -> Iterable[Dict[Any, Any]]: + def generate(self) -> Iterable[dict[Any, Any]]: ... diff --git a/tests/example_data/data_generator/birth_names/birth_names_generator.py b/tests/example_data/data_generator/birth_names/birth_names_generator.py index 2b68abbd4f..a8e8c45e28 100644 --- a/tests/example_data/data_generator/birth_names/birth_names_generator.py +++ b/tests/example_data/data_generator/birth_names/birth_names_generator.py @@ -16,9 +16,10 @@ # under the License. from __future__ import annotations +from collections.abc import Iterable from datetime import datetime from random import choice, randint -from typing import Any, Dict, Iterable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.consts.birth_names import ( BOY, @@ -58,7 +59,7 @@ class BirthNamesGenerator(ExampleDataGenerator): self._until_not_include_year = start_year + years_amount self._rows_per_year = rows_per_year - def generate(self) -> Iterable[Dict[Any, Any]]: + def generate(self) -> Iterable[dict[Any, Any]]: for year in range(self._start_year, self._until_not_include_year): ds = self._make_year(year) for _ in range(self._rows_per_year): @@ -67,7 +68,7 @@ class BirthNamesGenerator(ExampleDataGenerator): def _make_year(self, year: int): return datetime(year, 1, 1, 0, 0, 0) - def generate_row(self, dt: datetime) -> Dict[Any, Any]: + def generate_row(self, dt: datetime) -> dict[Any, Any]: gender = choice([BOY, GIRL]) num = randint(1, 100000) return { diff --git a/tests/example_data/data_loading/data_definitions/types.py b/tests/example_data/data_loading/data_definitions/types.py index e393019e01..a1ed104348 100644 --- a/tests/example_data/data_loading/data_definitions/types.py +++ b/tests/example_data/data_loading/data_definitions/types.py @@ -24,8 +24,9 @@ # specific language governing permissions and limitations # under the License. from abc import ABC, abstractmethod +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Dict, Iterable, Optional +from typing import Any, Optional from sqlalchemy.types import TypeEngine @@ -33,14 +34,14 @@ from sqlalchemy.types import TypeEngine @dataclass class TableMetaData: table_name: str - types: Optional[Dict[str, TypeEngine]] + types: Optional[dict[str, TypeEngine]] @dataclass class Table: table_name: str table_metadata: TableMetaData - data: Iterable[Dict[Any, Any]] + data: Iterable[dict[Any, Any]] class TableMetaDataFactory(ABC): @@ -48,6 +49,6 @@ class TableMetaDataFactory(ABC): def make(self) -> TableMetaData: ... - def make_table(self, data: Iterable[Dict[Any, Any]]) -> Table: + def make_table(self, data: Iterable[dict[Any, Any]]) -> Table: metadata = self.make() return Table(metadata.table_name, metadata, data) diff --git a/tests/example_data/data_loading/pandas/pandas_data_loader.py b/tests/example_data/data_loading/pandas/pandas_data_loader.py index 7f41602054..49dcf3b2db 100644 --- a/tests/example_data/data_loading/pandas/pandas_data_loader.py +++ b/tests/example_data/data_loading/pandas/pandas_data_loader.py @@ -17,7 +17,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Dict, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from pandas import DataFrame from sqlalchemy.inspection import inspect @@ -63,10 +63,10 @@ class PandasDataLoader(DataLoader): schema=self._detect_schema_name(), ) - def _detect_schema_name(self) -> Optional[str]: + def _detect_schema_name(self) -> str | None: return inspect(self._db_engine).default_schema_name - def _take_data_types(self, table: Table) -> Optional[Dict[str, str]]: + def _take_data_types(self, table: Table) -> dict[str, str] | None: if metadata_table := table.table_metadata: types = metadata_table.types if types: diff --git a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py index 1c43adc931..8de12b39ef 100644 --- a/tests/example_data/data_loading/pandas/pands_data_loading_conf.py +++ b/tests/example_data/data_loading/pandas/pands_data_loading_conf.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict +from typing import Any default_pandas_data_loader_config = { "if_exists": "replace", @@ -54,7 +54,7 @@ class PandasLoaderConfigurations: self.support_datetime_type = support_datetime_type @classmethod - def make_from_dict(cls, _dict: Dict[str, Any]) -> PandasLoaderConfigurations: + def make_from_dict(cls, _dict: dict[str, Any]) -> PandasLoaderConfigurations: copy_dict = default_pandas_data_loader_config.copy() copy_dict.update(_dict) return PandasLoaderConfigurations(**copy_dict) # type: ignore diff --git a/tests/example_data/data_loading/pandas/table_df_convertor.py b/tests/example_data/data_loading/pandas/table_df_convertor.py index e801c8464e..aad1077ce5 100644 --- a/tests/example_data/data_loading/pandas/table_df_convertor.py +++ b/tests/example_data/data_loading/pandas/table_df_convertor.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from pandas import DataFrame @@ -30,10 +30,10 @@ if TYPE_CHECKING: @log class TableToDfConvertorImpl(TableToDfConvertor): convert_datetime_to_str: bool - _time_format: Optional[str] + _time_format: str | None def __init__( - self, convert_ds_to_datetime: bool, time_format: Optional[str] = None + self, convert_ds_to_datetime: bool, time_format: str | None = None ) -> None: self.convert_datetime_to_str = convert_ds_to_datetime self._time_format = time_format diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index 38fd105240..79fdff6346 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -516,7 +516,7 @@ class TestRequestAccess(SupersetTestCase): ) self.assertEqual( access_request3.roles_with_datasource, - "
    • {}
    ".format(approve_link_3), + f"
    • {approve_link_3}
    ", ) # cleanup diff --git a/tests/integration_tests/advanced_data_type/api_tests.py b/tests/integration_tests/advanced_data_type/api_tests.py index 5bfe308e16..e865069462 100644 --- a/tests/integration_tests/advanced_data_type/api_tests.py +++ b/tests/integration_tests/advanced_data_type/api_tests.py @@ -24,7 +24,7 @@ from superset.utils.core import get_example_default_schema from tests.integration_tests.utils.get_dashboards import get_dashboards_ids from unittest import mock from sqlalchemy import Column -from typing import Any, List +from typing import Any from superset.advanced_data_type.types import ( AdvancedDataType, AdvancedDataTypeRequest, @@ -52,7 +52,7 @@ def translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse: return target_resp -def translate_filter_func(col: Column, op: FilterOperator, values: List[Any]): +def translate_filter_func(col: Column, op: FilterOperator, values: list[Any]): pass diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index f70f0f63bd..fec66f88d2 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -20,7 +20,7 @@ from datetime import datetime import imp import json from contextlib import contextmanager -from typing import Any, Dict, Union, List, Optional +from typing import Any, Union, Optional from unittest.mock import Mock, patch, MagicMock import pandas as pd @@ -67,12 +67,12 @@ def get_resp( else: resp = client.get(url, follow_redirects=follow_redirects) if raise_on_error and resp.status_code > 400: - raise Exception("http request failed with code {}".format(resp.status_code)) + raise Exception(f"http request failed with code {resp.status_code}") return resp.data.decode("utf-8") def post_assert_metric( - client: Any, uri: str, data: Dict[str, Any], func_name: str + client: Any, uri: str, data: dict[str, Any], func_name: str ) -> Response: """ Simple client post with an extra assertion for statsd metrics @@ -121,7 +121,7 @@ class SupersetTestCase(TestCase): @staticmethod def create_user_with_roles( - username: str, roles: List[str], should_create_roles: bool = False + username: str, roles: list[str], should_create_roles: bool = False ): user_to_create = security_manager.find_user(username) if not user_to_create: @@ -485,12 +485,12 @@ class SupersetTestCase(TestCase): return rv def post_assert_metric( - self, uri: str, data: Dict[str, Any], func_name: str + self, uri: str, data: dict[str, Any], func_name: str ) -> Response: return post_assert_metric(self.client, uri, data, func_name) def put_assert_metric( - self, uri: str, data: Dict[str, Any], func_name: str + self, uri: str, data: dict[str, Any], func_name: str ) -> Response: """ Simple client put with an extra assertion for statsd metrics diff --git a/tests/integration_tests/cachekeys/api_tests.py b/tests/integration_tests/cachekeys/api_tests.py index d3552bfc8d..c867ce7f51 100644 --- a/tests/integration_tests/cachekeys/api_tests.py +++ b/tests/integration_tests/cachekeys/api_tests.py @@ -16,7 +16,7 @@ # under the License. # isort:skip_file """Unit tests for Superset""" -from typing import Dict, Any +from typing import Any import pytest @@ -31,7 +31,7 @@ from tests.integration_tests.base_tests import ( @pytest.fixture def invalidate(test_client, login_as_admin): - def _invalidate(params: Dict[str, Any]): + def _invalidate(params: dict[str, Any]): return post_assert_metric( test_client, "api/v1/cachekey/invalidate", params, "invalidate" ) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 99b3275281..fa09e56675 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -17,7 +17,6 @@ # isort:skip_file """Unit tests for Superset""" import json -import logging from io import BytesIO from zipfile import is_zipfile, ZipFile diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index b02ccb5b96..f9e6b5e3b1 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -21,7 +21,7 @@ import unittest import copy from datetime import datetime from io import BytesIO -from typing import Any, Dict, Optional, List +from typing import Any, Optional from unittest import mock from zipfile import ZipFile @@ -740,11 +740,11 @@ class TestPostChartDataApi(BaseTestChartDataApi): data = rv.json["result"][0]["data"] - unique_names = set(row["name"] for row in data) + unique_names = {row["name"] for row in data} self.maxDiff = None self.assertEqual(len(unique_names), SERIES_LIMIT) self.assertEqual( - set(column for column in data[0].keys()), {"state", "name", "sum__num"} + {column for column in data[0].keys()}, {"state", "name", "sum__num"} ) @pytest.mark.usefixtures( @@ -1124,7 +1124,7 @@ class TestGetChartDataApi(BaseTestChartDataApi): @pytest.fixture() -def physical_query_context(physical_dataset) -> Dict[str, Any]: +def physical_query_context(physical_dataset) -> dict[str, Any]: return { "datasource": { "type": physical_dataset.type, @@ -1218,7 +1218,7 @@ def test_data_cache_default_timeout( def test_chart_cache_timeout( - load_energy_table_with_slice: List[Slice], + load_energy_table_with_slice: list[Slice], test_client, login_as_admin, physical_query_context, diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index 0ea5bb5106..28da7b7913 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -19,7 +19,7 @@ from __future__ import annotations import contextlib import functools import os -from typing import Any, Callable, Dict, Optional, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from unittest.mock import patch import pytest @@ -55,7 +55,7 @@ def test_client(app_context: AppContext): @pytest.fixture -def login_as(test_client: "FlaskClient[Any]"): +def login_as(test_client: FlaskClient[Any]): """Fixture with app context and logged in admin user.""" def _login_as(username: str, password: str = "general"): @@ -160,7 +160,7 @@ def drop_from_schema(engine: Engine, schema_name: str): @pytest.fixture(scope="session") def example_db_provider() -> Callable[[], Database]: # type: ignore class _example_db_provider: - _db: Optional[Database] = None + _db: Database | None = None def __call__(self) -> Database: with app.app_context(): @@ -257,7 +257,7 @@ def with_feature_flags(**mock_feature_flags): return decorate -def with_config(override_config: Dict[str, Any]): +def with_config(override_config: dict[str, Any]): """ Use this decorator to mock specific config keys. diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 2e9e287620..f0c72b0680 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -23,7 +23,6 @@ import html import io import json import logging -from typing import Dict, List from urllib.parse import quote import superset.utils.database @@ -37,7 +36,6 @@ from sqlalchemy import Table import pytest import pytz import random -import re import unittest from unittest import mock @@ -496,7 +494,7 @@ class TestCore(SupersetTestCase): assert response.headers["Content-Type"] == "application/json" response_body = json.loads(response.data.decode("utf-8")) expected_body = {"error": "Could not load database driver: broken"} - assert response_body == expected_body, "%s != %s" % ( + assert response_body == expected_body, "{} != {}".format( response_body, expected_body, ) @@ -515,7 +513,7 @@ class TestCore(SupersetTestCase): assert response.headers["Content-Type"] == "application/json" response_body = json.loads(response.data.decode("utf-8")) expected_body = {"error": "Could not load database driver: mssql+pymssql"} - assert response_body == expected_body, "%s != %s" % ( + assert response_body == expected_body, "{} != {}".format( response_body, expected_body, ) @@ -563,7 +561,7 @@ class TestCore(SupersetTestCase): self.login(username=username) database = superset.utils.database.get_example_database() sqlalchemy_uri_decrypted = database.sqlalchemy_uri_decrypted - url = "databaseview/edit/{}".format(database.id) + url = f"databaseview/edit/{database.id}" data = {k: database.__getattribute__(k) for k in DatabaseView.add_columns} data["sqlalchemy_uri"] = database.safe_sqlalchemy_uri() self.client.post(url, data=data) @@ -582,7 +580,7 @@ class TestCore(SupersetTestCase): def test_warm_up_cache(self): self.login() slc = self.get_slice("Girls", db.session) - data = self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(slc.id)) + data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") self.assertEqual( data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] ) @@ -609,7 +607,7 @@ class TestCore(SupersetTestCase): store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True girls_slice = self.get_slice("Girls", db.session) - self.get_json_resp("/superset/warm_up_cache?slice_id={}".format(girls_slice.id)) + self.get_json_resp(f"/superset/warm_up_cache?slice_id={girls_slice.id}") ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() assert ck.datasource_uid == f"{girls_slice.table.id}__table" app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = store_cache_keys @@ -650,7 +648,7 @@ class TestCore(SupersetTestCase): kv_value = kv.value self.assertEqual(json.loads(value), json.loads(kv_value)) - resp = self.client.get("/kv/{}/".format(kv.id)) + resp = self.client.get(f"/kv/{kv.id}/") self.assertEqual(resp.status_code, 200) self.assertEqual(json.loads(value), json.loads(resp.data.decode("utf-8"))) @@ -662,7 +660,7 @@ class TestCore(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_csv_endpoint(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] get_name_sql = """ SELECT name FROM birth_names @@ -676,17 +674,17 @@ class TestCore(SupersetTestCase): WHERE name = '{name}' LIMIT 1 """ - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] self.run_sql(sql, client_id, raise_on_error=True) - resp = self.get_resp("/superset/csv/{}".format(client_id)) + resp = self.get_resp(f"/superset/csv/{client_id}") data = csv.reader(io.StringIO(resp)) expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] self.run_sql(sql, client_id, raise_on_error=True) - resp = self.get_resp("/superset/csv/{}".format(client_id)) + resp = self.get_resp(f"/superset/csv/{client_id}") data = csv.reader(io.StringIO(resp)) expected_data = csv.reader(io.StringIO(f"name\n{name}\n")) @@ -704,7 +702,7 @@ class TestCore(SupersetTestCase): def test_required_params_in_sql_json(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"client_id": client_id} rv = self.client.post( @@ -876,12 +874,12 @@ class TestCore(SupersetTestCase): self.get_resp(slc.slice_url) self.assertEqual(1, qry.count()) - def create_sample_csvfile(self, filename: str, content: List[str]) -> None: + def create_sample_csvfile(self, filename: str, content: list[str]) -> None: with open(filename, "w+") as test_file: for l in content: test_file.write(f"{l}\n") - def create_sample_excelfile(self, filename: str, content: Dict[str, str]) -> None: + def create_sample_excelfile(self, filename: str, content: dict[str, str]) -> None: pd.DataFrame(content).to_excel(filename) def enable_csv_upload(self, database: models.Database) -> None: diff --git a/tests/integration_tests/csv_upload_tests.py b/tests/integration_tests/csv_upload_tests.py index 91a76f97cf..9bc204ff06 100644 --- a/tests/integration_tests/csv_upload_tests.py +++ b/tests/integration_tests/csv_upload_tests.py @@ -20,7 +20,7 @@ import json import logging import os import shutil -from typing import Dict, Optional, Union +from typing import Optional, Union from unittest import mock @@ -132,7 +132,7 @@ def get_upload_db(): def upload_csv( filename: str, table_name: str, - extra: Optional[Dict[str, str]] = None, + extra: Optional[dict[str, str]] = None, dtype: Union[str, None] = None, ): csv_upload_db_id = get_upload_db().id @@ -155,7 +155,7 @@ def upload_csv( def upload_excel( - filename: str, table_name: str, extra: Optional[Dict[str, str]] = None + filename: str, table_name: str, extra: Optional[dict[str, str]] = None ): excel_upload_db_id = get_upload_db().id form_data = { @@ -175,7 +175,7 @@ def upload_excel( def upload_columnar( - filename: str, table_name: str, extra: Optional[Dict[str, str]] = None + filename: str, table_name: str, extra: Optional[dict[str, str]] = None ): columnar_upload_db_id = get_upload_db().id form_data = { @@ -218,7 +218,7 @@ def mock_upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: def escaped_double_quotes(text): - return f"\"{text}\"" + return rf"\"{text}\"" def escaped_parquet(text): diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index d54151db83..669bc93693 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -115,7 +115,7 @@ class TestDashboard(SupersetTestCase): def get_mock_positions(self, dash): positions = {"DASHBOARD_VERSION_KEY": "v2"} for i, slc in enumerate(dash.slices): - id = "DASHBOARD_CHART_TYPE-{}".format(i) + id = f"DASHBOARD_CHART_TYPE-{i}" d = { "type": "CHART", "id": id, @@ -167,7 +167,7 @@ class TestDashboard(SupersetTestCase): # set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -189,7 +189,7 @@ class TestDashboard(SupersetTestCase): "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -217,7 +217,7 @@ class TestDashboard(SupersetTestCase): "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" resp = self.get_resp(url, data=dict(data=json.dumps(data))) self.assertIn("SUCCESS", resp) @@ -239,7 +239,7 @@ class TestDashboard(SupersetTestCase): # set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" self.get_resp(url, data=dict(data=json.dumps(data))) updatedDash = db.session.query(Dashboard).filter_by(slug="births").first() self.assertEqual(updatedDash.dashboard_title, "new title") @@ -264,7 +264,7 @@ class TestDashboard(SupersetTestCase): # set a further modified_time for unit test "last_modified_time": datetime.now().timestamp() + 1000, } - url = "/superset/save_dash/{}/".format(dash.id) + url = f"/superset/save_dash/{dash.id}/" self.get_resp(url, data=dict(data=json.dumps(data))) updatedDash = db.session.query(Dashboard).filter_by(slug="births").first() self.assertIn("color_namespace", updatedDash.json_metadata) @@ -301,13 +301,13 @@ class TestDashboard(SupersetTestCase): # Save changes to Births dashboard and retrieve updated dash dash_id = dash.id - url = "/superset/save_dash/{}/".format(dash_id) + url = f"/superset/save_dash/{dash_id}/" self.client.post(url, data=dict(data=json.dumps(data))) dash = db.session.query(Dashboard).filter_by(id=dash_id).first() orig_json_data = dash.data # Verify that copy matches original - url = "/superset/copy_dash/{}/".format(dash_id) + url = f"/superset/copy_dash/{dash_id}/" resp = self.get_json_resp(url, data=dict(data=json.dumps(data))) self.assertEqual(resp["dashboard_title"], "Copy Of Births") self.assertEqual(resp["position_json"], orig_json_data["position_json"]) @@ -334,7 +334,7 @@ class TestDashboard(SupersetTestCase): data = { "slice_ids": [new_slice.data["slice_id"], existing_slice.data["slice_id"]] } - url = "/superset/add_slices/{}/".format(dash.id) + url = f"/superset/add_slices/{dash.id}/" resp = self.client.post(url, data=dict(data=json.dumps(data))) assert "SLICES ADDED" in resp.data.decode("utf-8") @@ -375,7 +375,7 @@ class TestDashboard(SupersetTestCase): # save dash dash_id = dash.id - url = "/superset/save_dash/{}/".format(dash_id) + url = f"/superset/save_dash/{dash_id}/" self.client.post(url, data=dict(data=json.dumps(data))) dash = db.session.query(Dashboard).filter_by(id=dash_id).first() diff --git a/tests/integration_tests/dashboard_utils.py b/tests/integration_tests/dashboard_utils.py index bea724dafc..6c3d000051 100644 --- a/tests/integration_tests/dashboard_utils.py +++ b/tests/integration_tests/dashboard_utils.py @@ -17,7 +17,7 @@ """Utils to provide dashboards for tests""" import json -from typing import Any, Dict, List, Optional +from typing import Optional from pandas import DataFrame @@ -65,7 +65,7 @@ def create_table_metadata( def create_slice( - title: str, viz_type: str, table: SqlaTable, slices_dict: Dict[str, str] + title: str, viz_type: str, table: SqlaTable, slices_dict: dict[str, str] ) -> Slice: return Slice( slice_name=title, @@ -77,7 +77,7 @@ def create_slice( def create_dashboard( - slug: str, title: str, position: str, slices: List[Slice] + slug: str, title: str, position: str, slices: list[Slice] ) -> Dashboard: dash = db.session.query(Dashboard).filter_by(slug=slug).one_or_none() if dash: diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index f57afd95a6..49a6bbecbc 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -19,7 +19,7 @@ import json from io import BytesIO from time import sleep -from typing import List, Optional +from typing import Optional from unittest.mock import ANY, patch from zipfile import is_zipfile, ZipFile @@ -66,7 +66,7 @@ DASHBOARDS_FIXTURE_COUNT = 10 class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixin): resource_name = "dashboard" - dashboards: List[Dashboard] = [] + dashboards: list[Dashboard] = [] dashboard_data = { "dashboard_title": "title1_changed", "slug": "slug1_changed", @@ -80,10 +80,10 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi self, dashboard_title: str, slug: Optional[str], - owners: List[int], - roles: List[int] = [], + owners: list[int], + roles: list[int] = [], created_by=None, - slices: Optional[List[Slice]] = None, + slices: Optional[list[Slice]] = None, position_json: str = "", css: str = "", json_metadata: str = "", @@ -211,9 +211,9 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi self.assertEqual(response.status_code, 200) data = json.loads(response.data.decode("utf-8")) dashboard = Dashboard.get("world_health") - expected_dataset_ids = set([s.datasource_id for s in dashboard.slices]) + expected_dataset_ids = {s.datasource_id for s in dashboard.slices} result = data["result"] - actual_dataset_ids = set([dataset["id"] for dataset in result]) + actual_dataset_ids = {dataset["id"] for dataset in result} self.assertEqual(actual_dataset_ids, expected_dataset_ids) expected_values = [0, 1] if backend() == "presto" else [0, 1, 2] self.assertEqual(result[0]["column_types"], expected_values) @@ -927,7 +927,7 @@ class TestDashboardApi(SupersetTestCase, ApiOwnersTestCaseMixin, InsertChartMixi buf = BytesIO() with ZipFile(buf, "w") as bundle: with bundle.open("sql/dump.sql", "w") as fp: - fp.write("CREATE TABLE foo (bar INT)".encode()) + fp.write(b"CREATE TABLE foo (bar INT)") buf.seek(0) return buf diff --git a/tests/integration_tests/dashboards/base_case.py b/tests/integration_tests/dashboards/base_case.py index a0a1ff630f..db85cd6409 100644 --- a/tests/integration_tests/dashboards/base_case.py +++ b/tests/integration_tests/dashboards/base_case.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import json -from typing import Any, Dict, Union +from typing import Any, Union import prison from flask import Response @@ -49,13 +49,13 @@ class DashboardTestCase(SupersetTestCase): return self.client.get(DASHBOARDS_API_URL) def save_dashboard_via_view( - self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any] ) -> Response: save_dash_url = SAVE_DASHBOARD_URL_FORMAT.format(dashboard_id) return self.get_resp(save_dash_url, data=dict(data=json.dumps(dashboard_data))) def save_dashboard( - self, dashboard_id: Union[str, int], dashboard_data: Dict[str, Any] + self, dashboard_id: Union[str, int], dashboard_data: dict[str, Any] ) -> Response: return self.save_dashboard_via_view(dashboard_id, dashboard_data) diff --git a/tests/integration_tests/dashboards/dashboard_test_utils.py b/tests/integration_tests/dashboards/dashboard_test_utils.py index df2687fba9..ee8001cdba 100644 --- a/tests/integration_tests/dashboards/dashboard_test_utils.py +++ b/tests/integration_tests/dashboards/dashboard_test_utils.py @@ -17,7 +17,7 @@ import logging import random import string -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Optional from sqlalchemy import func @@ -32,10 +32,10 @@ logger = logging.getLogger(__name__) session = appbuilder.get_session -def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]: +def get_mock_positions(dashboard: Dashboard) -> dict[str, Any]: positions = {"DASHBOARD_VERSION_KEY": "v2"} for i, slc in enumerate(dashboard.slices): - id_ = "DASHBOARD_CHART_TYPE-{}".format(i) + id_ = f"DASHBOARD_CHART_TYPE-{i}" position_data: Any = { "type": "CHART", "id": id_, @@ -48,7 +48,7 @@ def get_mock_positions(dashboard: Dashboard) -> Dict[str, Any]: def build_save_dash_parts( dashboard_slug: Optional[str] = None, dashboard_to_edit: Optional[Dashboard] = None -) -> Tuple[Dashboard, Dict[str, Any], Dict[str, Any]]: +) -> tuple[Dashboard, dict[str, Any], dict[str, Any]]: if not dashboard_to_edit: dashboard_slug = ( dashboard_slug if dashboard_slug else DEFAULT_DASHBOARD_SLUG_TO_TEST @@ -68,7 +68,7 @@ def build_save_dash_parts( return dashboard_to_edit, data_before_change, data_after_change -def get_all_dashboards() -> List[Dashboard]: +def get_all_dashboards() -> list[Dashboard]: return db.session.query(Dashboard).all() diff --git a/tests/integration_tests/dashboards/filter_sets/conftest.py b/tests/integration_tests/dashboards/filter_sets/conftest.py index b7a28273b0..b19e929f9d 100644 --- a/tests/integration_tests/dashboards/filter_sets/conftest.py +++ b/tests/integration_tests/dashboards/filter_sets/conftest.py @@ -17,7 +17,8 @@ from __future__ import annotations import json -from typing import Any, Dict, Generator, List, TYPE_CHECKING +from collections.abc import Generator +from typing import Any, TYPE_CHECKING import pytest @@ -67,7 +68,7 @@ security_manager: BaseSecurityManager = sm @pytest.fixture(autouse=True, scope="module") -def test_users() -> Generator[Dict[str, int], None, None]: +def test_users() -> Generator[dict[str, int], None, None]: usernames = [ ADMIN_USERNAME_FOR_TEST, DASHBOARD_OWNER_USERNAME, @@ -82,16 +83,16 @@ def test_users() -> Generator[Dict[str, int], None, None]: delete_users(usernames_to_ids) -def delete_users(usernames_to_ids: Dict[str, int]) -> None: +def delete_users(usernames_to_ids: dict[str, int]) -> None: for username in usernames_to_ids.keys(): db.session.delete(security_manager.find_user(username)) db.session.commit() def create_test_users( - admin_role: Role, filter_set_role: Role, usernames: List[str] -) -> Dict[str, int]: - users: List[User] = [] + admin_role: Role, filter_set_role: Role, usernames: list[str] +) -> dict[str, int]: + users: list[User] = [] for username in usernames: user = build_user(username, filter_set_role, admin_role) users.append(user) @@ -108,7 +109,7 @@ def build_user(username: str, filter_set_role: Role, admin_role: Role) -> User: if not user: user = security_manager.find_user(username) if user is None: - raise Exception("Failed to build the user {}".format(username)) + raise Exception(f"Failed to build the user {username}") return user @@ -118,7 +119,7 @@ def build_filter_set_role() -> Role: all_datasource_view_name: ViewMenu = security_manager.find_view_menu( "all_datasource_access" ) - pvms: List[PermissionView] = security_manager.find_permissions_view_menu( + pvms: list[PermissionView] = security_manager.find_permissions_view_menu( filterset_view_name ) + security_manager.find_permissions_view_menu(all_datasource_view_name) for pvm in pvms: @@ -167,8 +168,8 @@ def dashboard_id(dashboard: Dashboard) -> Generator[int, None, None]: @pytest.fixture def filtersets( - dashboard_id: int, test_users: Dict[str, int], dumped_valid_json_metadata: str -) -> Generator[Dict[str, List[FilterSet]], None, None]: + dashboard_id: int, test_users: dict[str, int], dumped_valid_json_metadata: str +) -> Generator[dict[str, list[FilterSet]], None, None]: first_filter_set = FilterSet( name="filter_set_1_of_" + str(dashboard_id), dashboard_id=dashboard_id, @@ -216,17 +217,17 @@ def filtersets( @pytest.fixture -def filterset_id(filtersets: Dict[str, List[FilterSet]]) -> int: +def filterset_id(filtersets: dict[str, list[FilterSet]]) -> int: return filtersets["Dashboard"][0].id @pytest.fixture -def valid_json_metadata() -> Dict[str, Any]: +def valid_json_metadata() -> dict[str, Any]: return {"nativeFilters": {}} @pytest.fixture -def dumped_valid_json_metadata(valid_json_metadata: Dict[str, Any]) -> str: +def dumped_valid_json_metadata(valid_json_metadata: dict[str, Any]) -> str: return json.dumps(valid_json_metadata) @@ -238,7 +239,7 @@ def exists_user_id() -> int: @pytest.fixture def valid_filter_set_data_for_create( dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int -) -> Dict[str, Any]: +) -> dict[str, Any]: name = "test_filter_set_of_dashboard_" + str(dashboard_id) return { NAME_FIELD: name, @@ -252,7 +253,7 @@ def valid_filter_set_data_for_create( @pytest.fixture def valid_filter_set_data_for_update( dashboard_id: int, dumped_valid_json_metadata: str, exists_user_id: int -) -> Dict[str, Any]: +) -> dict[str, Any]: name = "name_changed_test_filter_set_of_dashboard_" + str(dashboard_id) return { NAME_FIELD: name, @@ -273,13 +274,13 @@ def not_exists_user_id() -> int: @pytest.fixture() def dashboard_based_filter_set_dict( - filtersets: Dict[str, List[FilterSet]] -) -> Dict[str, Any]: + filtersets: dict[str, list[FilterSet]] +) -> dict[str, Any]: return filtersets["Dashboard"][0].to_dict() @pytest.fixture() def user_based_filter_set_dict( - filtersets: Dict[str, List[FilterSet]] -) -> Dict[str, Any]: + filtersets: dict[str, list[FilterSet]] +) -> dict[str, Any]: return filtersets[FILTER_SET_OWNER_USERNAME][0].to_dict() diff --git a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py index b5d1919dd4..9891266101 100644 --- a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict +from typing import Any from flask.testing import FlaskClient @@ -42,11 +42,11 @@ from tests.integration_tests.dashboards.filter_sets.utils import ( from tests.integration_tests.test_app import login -def assert_filterset_was_not_created(filter_set_data: Dict[str, Any]) -> None: +def assert_filterset_was_not_created(filter_set_data: dict[str, Any]) -> None: assert get_filter_set_by_name(str(filter_set_data["name"])) is None -def assert_filterset_was_created(filter_set_data: Dict[str, Any]) -> None: +def assert_filterset_was_created(filter_set_data: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_data["name"]) is not None @@ -54,7 +54,7 @@ class TestCreateFilterSetsApi: def test_with_extra_field__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -74,7 +74,7 @@ class TestCreateFilterSetsApi: def test_with_id_field__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -94,7 +94,7 @@ class TestCreateFilterSetsApi: def test_with_dashboard_not_exists__404( self, not_exists_dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # act @@ -110,7 +110,7 @@ class TestCreateFilterSetsApi: def test_without_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -129,7 +129,7 @@ class TestCreateFilterSetsApi: def test_with_none_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -148,7 +148,7 @@ class TestCreateFilterSetsApi: def test_with_int_as_name__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -167,7 +167,7 @@ class TestCreateFilterSetsApi: def test_without_description__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -186,7 +186,7 @@ class TestCreateFilterSetsApi: def test_with_none_description__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -205,7 +205,7 @@ class TestCreateFilterSetsApi: def test_with_int_as_description__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -224,7 +224,7 @@ class TestCreateFilterSetsApi: def test_without_json_metadata__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -243,7 +243,7 @@ class TestCreateFilterSetsApi: def test_with_invalid_json_metadata__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -262,7 +262,7 @@ class TestCreateFilterSetsApi: def test_without_owner_type__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -281,7 +281,7 @@ class TestCreateFilterSetsApi: def test_with_invalid_owner_type__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -300,7 +300,7 @@ class TestCreateFilterSetsApi: def test_without_owner_id_when_owner_type_is_user__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -320,7 +320,7 @@ class TestCreateFilterSetsApi: def test_without_owner_id_when_owner_type_is_dashboard__201( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -340,7 +340,7 @@ class TestCreateFilterSetsApi: def test_with_not_exists_owner__400( self, dashboard_id: int, - valid_filter_set_data_for_create: Dict[str, Any], + valid_filter_set_data_for_create: dict[str, Any], not_exists_user_id: int, client: FlaskClient[Any], ): @@ -361,8 +361,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_admin_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -384,8 +384,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_admin_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -407,8 +407,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_admin_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -430,8 +430,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_admin_and_owner_type_is_dashboard__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -451,8 +451,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_dashboard_owner_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -474,8 +474,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_dashboard_owner_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -497,8 +497,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_dashboard_owner_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -520,8 +520,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -541,8 +541,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_regular_user_and_owner_is_admin__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -564,8 +564,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_regular_user_and_owner_is_dashboard_owner__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -587,8 +587,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_regular_user_and_owner_is_regular_user__201( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -610,8 +610,8 @@ class TestCreateFilterSetsApi: def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, dashboard_id: int, - test_users: Dict[str, int], - valid_filter_set_data_for_create: Dict[str, Any], + test_users: dict[str, int], + valid_filter_set_data_for_create: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py index 7011cb5781..41d7ea59f7 100644 --- a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, @@ -36,11 +36,11 @@ if TYPE_CHECKING: from superset.models.filter_set import FilterSet -def assert_filterset_was_not_deleted(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_was_not_deleted(filter_set_dict: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_dict["name"]) is not None -def assert_filterset_deleted(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_deleted(filter_set_dict: dict[str, Any]) -> None: assert get_filter_set_by_name(filter_set_dict["name"]) is None @@ -48,7 +48,7 @@ class TestDeleteFilterSet: def test_with_dashboard_exists_filterset_not_exists__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -62,7 +62,7 @@ class TestDeleteFilterSet: def test_with_dashboard_not_exists_filterset_not_exists__404( self, not_exists_dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -78,7 +78,7 @@ class TestDeleteFilterSet: def test_with_dashboard_not_exists_filterset_exists__404( self, not_exists_dashboard_id: int, - dashboard_based_filter_set_dict: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -94,9 +94,9 @@ class TestDeleteFilterSet: def test_when_caller_is_admin_and_owner_type_is_user__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -110,9 +110,9 @@ class TestDeleteFilterSet: def test_when_caller_is_admin_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -126,9 +126,9 @@ class TestDeleteFilterSet: def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -143,9 +143,9 @@ class TestDeleteFilterSet: def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -160,9 +160,9 @@ class TestDeleteFilterSet: def test_when_caller_is_filterset_owner__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -177,9 +177,9 @@ class TestDeleteFilterSet: def test_when_caller_is_regular_user_and_owner_type_is_user__403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -194,9 +194,9 @@ class TestDeleteFilterSet: def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py index ad40d0e33c..71c985310d 100644 --- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Set, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from tests.integration_tests.dashboards.filter_sets.consts import ( DASHBOARD_OWNER_USERNAME, @@ -66,12 +66,12 @@ class TestGetFilterSetsApi: def test_when_caller_admin__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange login(client, "admin") - expected_ids: Set[int] = collect_all_ids(filtersets) + expected_ids: set[int] = collect_all_ids(filtersets) # act response = call_get_filter_sets(client, dashboard_id) @@ -83,7 +83,7 @@ class TestGetFilterSetsApi: def test_when_caller_dashboard_owner__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -100,7 +100,7 @@ class TestGetFilterSetsApi: def test_when_caller_filterset_owner__200( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -117,12 +117,12 @@ class TestGetFilterSetsApi: def test_when_caller_regular_user__200( self, dashboard_id: int, - filtersets: Dict[str, List[int]], + filtersets: dict[str, list[int]], client: FlaskClient[Any], ): # arrange login(client, REGULAR_USER) - expected_ids: Set[int] = set() + expected_ids: set[int] = set() # act response = call_get_filter_sets(client, dashboard_id) diff --git a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py index 07db98f617..a6e895a460 100644 --- a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py @@ -17,7 +17,7 @@ from __future__ import annotations import json -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from superset.dashboards.filter_sets.consts import ( DESCRIPTION_FIELD, @@ -45,8 +45,8 @@ if TYPE_CHECKING: def merge_two_filter_set_dict( - first: Dict[Any, Any], second: Dict[Any, Any] -) -> Dict[Any, Any]: + first: dict[Any, Any], second: dict[Any, Any] +) -> dict[Any, Any]: for d in [first, second]: if JSON_METADATA_FIELD in d: if PARAMS_PROPERTY not in d: @@ -55,12 +55,12 @@ def merge_two_filter_set_dict( return {**first, **second} -def assert_filterset_was_not_updated(filter_set_dict: Dict[str, Any]) -> None: +def assert_filterset_was_not_updated(filter_set_dict: dict[str, Any]) -> None: assert filter_set_dict == get_filter_set_by_name(filter_set_dict["name"]).to_dict() def assert_filterset_updated( - filter_set_dict_before: Dict[str, Any], data_updated: Dict[str, Any] + filter_set_dict_before: dict[str, Any], data_updated: dict[str, Any] ) -> None: expected_data = merge_two_filter_set_dict(filter_set_dict_before, data_updated) assert expected_data == get_filter_set_by_name(expected_data["name"]).to_dict() @@ -70,7 +70,7 @@ class TestUpdateFilterSet: def test_with_dashboard_exists_filterset_not_exists__404( self, dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -86,7 +86,7 @@ class TestUpdateFilterSet: def test_with_dashboard_not_exists_filterset_not_exists__404( self, not_exists_dashboard_id: int, - filtersets: Dict[str, List[FilterSet]], + filtersets: dict[str, list[FilterSet]], client: FlaskClient[Any], ): # arrange @@ -102,7 +102,7 @@ class TestUpdateFilterSet: def test_with_dashboard_not_exists_filterset_exists__404( self, not_exists_dashboard_id: int, - dashboard_based_filter_set_dict: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -118,8 +118,8 @@ class TestUpdateFilterSet: def test_with_extra_field__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -138,8 +138,8 @@ class TestUpdateFilterSet: def test_with_id_field__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -158,8 +158,8 @@ class TestUpdateFilterSet: def test_with_none_name__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -177,8 +177,8 @@ class TestUpdateFilterSet: def test_with_int_as_name__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -196,8 +196,8 @@ class TestUpdateFilterSet: def test_without_name__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -217,8 +217,8 @@ class TestUpdateFilterSet: def test_with_none_description__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -236,8 +236,8 @@ class TestUpdateFilterSet: def test_with_int_as_description__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -255,8 +255,8 @@ class TestUpdateFilterSet: def test_without_description__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -276,8 +276,8 @@ class TestUpdateFilterSet: def test_with_invalid_json_metadata__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -295,9 +295,9 @@ class TestUpdateFilterSet: def test_with_json_metadata__200( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], - valid_json_metadata: Dict[Any, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], + valid_json_metadata: dict[Any, Any], client: FlaskClient[Any], ): # arrange @@ -320,8 +320,8 @@ class TestUpdateFilterSet: def test_with_invalid_owner_type__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -339,8 +339,8 @@ class TestUpdateFilterSet: def test_with_user_owner_type__400( self, - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -358,8 +358,8 @@ class TestUpdateFilterSet: def test_with_dashboard_owner_type__200( self, - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -382,9 +382,9 @@ class TestUpdateFilterSet: def test_when_caller_is_admin_and_owner_type_is_user__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -402,9 +402,9 @@ class TestUpdateFilterSet: def test_when_caller_is_admin_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -422,9 +422,9 @@ class TestUpdateFilterSet: def test_when_caller_is_dashboard_owner_and_owner_is_other_user_403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -441,9 +441,9 @@ class TestUpdateFilterSet: def test_when_caller_is_dashboard_owner_and_owner_type_is_dashboard__200( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -462,9 +462,9 @@ class TestUpdateFilterSet: def test_when_caller_is_filterset_owner__200( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -483,9 +483,9 @@ class TestUpdateFilterSet: def test_when_caller_is_regular_user_and_owner_type_is_user__403( self, - test_users: Dict[str, int], - user_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + user_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange @@ -502,9 +502,9 @@ class TestUpdateFilterSet: def test_when_caller_is_regular_user_and_owner_type_is_dashboard__403( self, - test_users: Dict[str, int], - dashboard_based_filter_set_dict: Dict[str, Any], - valid_filter_set_data_for_update: Dict[str, Any], + test_users: dict[str, int], + dashboard_based_filter_set_dict: dict[str, Any], + valid_filter_set_data_for_update: dict[str, Any], client: FlaskClient[Any], ): # arrange diff --git a/tests/integration_tests/dashboards/filter_sets/utils.py b/tests/integration_tests/dashboards/filter_sets/utils.py index a63e4164d8..d728bf6fc3 100644 --- a/tests/integration_tests/dashboards/filter_sets/utils.py +++ b/tests/integration_tests/dashboards/filter_sets/utils.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from superset.models.filter_set import FilterSet from tests.integration_tests.dashboards.filter_sets.consts import FILTER_SET_URI @@ -28,7 +28,7 @@ if TYPE_CHECKING: def call_create_filter_set( - client: FlaskClient[Any], dashboard_id: int, data: Dict[str, Any] + client: FlaskClient[Any], dashboard_id: int, data: dict[str, Any] ) -> Response: uri = FILTER_SET_URI.format(dashboard_id=dashboard_id) return client.post(uri, json=data) @@ -41,8 +41,8 @@ def call_get_filter_sets(client: FlaskClient[Any], dashboard_id: int) -> Respons def call_delete_filter_set( client: FlaskClient[Any], - filter_set_dict_to_update: Dict[str, Any], - dashboard_id: Optional[int] = None, + filter_set_dict_to_update: dict[str, Any], + dashboard_id: int | None = None, ) -> Response: dashboard_id = ( dashboard_id @@ -58,9 +58,9 @@ def call_delete_filter_set( def call_update_filter_set( client: FlaskClient[Any], - filter_set_dict_to_update: Dict[str, Any], - data: Dict[str, Any], - dashboard_id: Optional[int] = None, + filter_set_dict_to_update: dict[str, Any], + data: dict[str, Any], + dashboard_id: int | None = None, ) -> Response: dashboard_id = ( dashboard_id @@ -90,12 +90,12 @@ def get_filter_set_by_dashboard_id(dashboard_id: int) -> FilterSet: def collect_all_ids( - filtersets: Union[Dict[str, List[FilterSet]], List[FilterSet]] -) -> Set[int]: + filtersets: dict[str, list[FilterSet]] | list[FilterSet] +) -> set[int]: if isinstance(filtersets, dict): - filtersets_lists: List[List[FilterSet]] = list(filtersets.values()) - ids: Set[int] = set() - lst: List[FilterSet] + filtersets_lists: list[list[FilterSet]] = list(filtersets.values()) + ids: set[int] = set() + lst: list[FilterSet] for lst in filtersets_lists: ids.update(set(map(lambda fs: fs.id, lst))) return ids diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index b20112334d..3c560a4469 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json -from typing import Iterator +from collections.abc import Iterator from unittest.mock import patch from uuid import uuid3 diff --git a/tests/integration_tests/dashboards/security/base_case.py b/tests/integration_tests/dashboards/security/base_case.py index bbb5fad831..e60fa96d44 100644 --- a/tests/integration_tests/dashboards/security/base_case.py +++ b/tests/integration_tests/dashboards/security/base_case.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional +from typing import Optional import pytest from flask import escape, Response @@ -37,8 +37,8 @@ class BaseTestDashboardSecurity(DashboardTestCase): self, response: Response, expected_counts: int, - expected_dashboards: Optional[List[Dashboard]] = None, - not_expected_dashboards: Optional[List[Dashboard]] = None, + expected_dashboards: Optional[list[Dashboard]] = None, + not_expected_dashboards: Optional[list[Dashboard]] = None, ) -> None: self.assert200(response) response_data = response.json diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py index b160a56a33..88495b03b4 100644 --- a/tests/integration_tests/dashboards/superset_factory_util.py +++ b/tests/integration_tests/dashboards/superset_factory_util.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging -from typing import List, Optional +from typing import Optional from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User @@ -50,8 +50,8 @@ def create_dashboard_to_db( dashboard_title: Optional[str] = None, slug: Optional[str] = None, published: bool = False, - owners: Optional[List[User]] = None, - slices: Optional[List[Slice]] = None, + owners: Optional[list[User]] = None, + slices: Optional[list[Slice]] = None, css: str = "", json_metadata: str = "", position_json: str = "", @@ -76,8 +76,8 @@ def create_dashboard( dashboard_title: Optional[str] = None, slug: Optional[str] = None, published: bool = False, - owners: Optional[List[User]] = None, - slices: Optional[List[Slice]] = None, + owners: Optional[list[User]] = None, + slices: Optional[list[Slice]] = None, css: str = "", json_metadata: str = "", position_json: str = "", @@ -107,7 +107,7 @@ def insert_model(dashboard: Model) -> None: def create_slice_to_db( name: Optional[str] = None, datasource_id: Optional[int] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> Slice: slice_ = create_slice(datasource_id, name=name, owners=owners) insert_model(slice_) @@ -119,7 +119,7 @@ def create_slice( datasource_id: Optional[int] = None, datasource: Optional[SqlaTable] = None, name: Optional[str] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> Slice: name = name if name is not None else random_str() owners = owners if owners is not None else [] @@ -149,7 +149,7 @@ def create_slice( def create_datasource_table_to_db( name: Optional[str] = None, db_id: Optional[int] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> SqlaTable: sqltable = create_datasource_table(name, db_id, owners=owners) insert_model(sqltable) @@ -161,7 +161,7 @@ def create_datasource_table( name: Optional[str] = None, db_id: Optional[int] = None, database: Optional[Database] = None, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> SqlaTable: name = name if name is not None else random_str() owners = owners if owners is not None else [] @@ -192,7 +192,7 @@ def delete_all_inserted_objects() -> None: def delete_all_inserted_dashboards(): try: - dashboards_to_delete: List[Dashboard] = ( + dashboards_to_delete: list[Dashboard] = ( session.query(Dashboard) .filter(Dashboard.id.in_(inserted_dashboards_ids)) .all() @@ -241,7 +241,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: def delete_all_inserted_slices(): try: - slices_to_delete: List[Slice] = ( + slices_to_delete: list[Slice] = ( session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all() ) for slice in slices_to_delete: @@ -272,7 +272,7 @@ def delete_slice_users_associations(slice_: Slice) -> None: def delete_all_inserted_tables(): try: - tables_to_delete: List[SqlaTable] = ( + tables_to_delete: list[SqlaTable] = ( session.query(SqlaTable) .filter(SqlaTable.id.in_(inserted_sqltables_ids)) .all() @@ -307,7 +307,7 @@ def delete_table_users_associations(table: SqlaTable) -> None: def delete_all_inserted_dbs(): try: - dbs_to_delete: List[Database] = ( + dbs_to_delete: list[Database] = ( session.query(Database) .filter(Database.id.in_(inserted_databases_ids)) .all() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index a30a951884..6fa1288067 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -23,7 +23,6 @@ from io import BytesIO from unittest import mock from unittest.mock import patch, MagicMock from zipfile import is_zipfile, ZipFile -from operator import itemgetter import prison import pytest @@ -52,7 +51,6 @@ from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, ) -from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.energy_dashboard import ( load_energy_table_with_slice, load_energy_table_data, @@ -1805,10 +1803,10 @@ class TestDatabaseApi(SupersetTestCase): schemas = [ s[0] for s in database.get_all_table_names_in_schema(schema_name) ] - self.assertEquals(response["count"], len(schemas)) + self.assertEqual(response["count"], len(schemas)) for option in response["result"]: - self.assertEquals(option["extra"], None) - self.assertEquals(option["type"], "table") + self.assertEqual(option["extra"], None) + self.assertEqual(option["type"], "table") self.assertTrue(option["value"] in schemas) def test_database_tables_not_found(self): diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index 553fae4fbf..b47d3d89fe 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from unittest import mock, skip +from unittest import skip from unittest.mock import patch import pytest diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py index 86c280b9bb..64bc0d8572 100644 --- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from unittest import mock, skip +from unittest import mock from unittest.mock import patch import pytest diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 2c358d7114..6c99efd358 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -18,7 +18,7 @@ import json import unittest from io import BytesIO -from typing import List, Optional +from typing import Optional from unittest.mock import ANY, patch from zipfile import is_zipfile, ZipFile @@ -68,7 +68,7 @@ class TestDatasetApi(SupersetTestCase): @staticmethod def insert_dataset( table_name: str, - owners: List[int], + owners: list[int], database: Database, sql: Optional[str] = None, schema: Optional[str] = None, @@ -94,7 +94,7 @@ class TestDatasetApi(SupersetTestCase): "ab_permission", [self.get_user("admin").id], get_main_database() ) - def get_fixture_datasets(self) -> List[SqlaTable]: + def get_fixture_datasets(self) -> list[SqlaTable]: return ( db.session.query(SqlaTable) .options(joinedload(SqlaTable.database)) @@ -102,7 +102,7 @@ class TestDatasetApi(SupersetTestCase): .all() ) - def get_fixture_virtual_datasets(self) -> List[SqlaTable]: + def get_fixture_virtual_datasets(self) -> list[SqlaTable]: return ( db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(self.fixture_virtual_table_names)) @@ -410,13 +410,11 @@ class TestDatasetApi(SupersetTestCase): ) all_datasets = db.session.query(SqlaTable).all() schema_values = sorted( - set( - [ - dataset.schema - for dataset in all_datasets - if dataset.schema is not None - ] - ) + { + dataset.schema + for dataset in all_datasets + if dataset.schema is not None + } ) expected_response = { "count": len(schema_values), diff --git a/tests/integration_tests/datasets/commands_tests.py b/tests/integration_tests/datasets/commands_tests.py index 8753b1d273..953c34059f 100644 --- a/tests/integration_tests/datasets/commands_tests.py +++ b/tests/integration_tests/datasets/commands_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from operator import itemgetter -from typing import Any, List +from typing import Any from unittest.mock import patch import pytest @@ -312,7 +312,7 @@ class TestImportDatasetsCommand(SupersetTestCase): assert len(dataset.metrics) == 2 assert dataset.main_dttm_col == "ds" assert dataset.filter_select_enabled - assert set(col.column_name for col in dataset.columns) == { + assert {col.column_name for col in dataset.columns} == { "num_california", "ds", "state", @@ -526,7 +526,7 @@ class TestImportDatasetsCommand(SupersetTestCase): db.session.commit() -def _get_table_from_list_by_name(name: str, tables: List[Any]): +def _get_table_from_list_by_name(name: str, tables: list[Any]): for table in tables: if table.table_name == name: return table diff --git a/tests/integration_tests/db_engine_specs/base_tests.py b/tests/integration_tests/db_engine_specs/base_tests.py index e20ea35ae4..2d4f72c4f4 100644 --- a/tests/integration_tests/db_engine_specs/base_tests.py +++ b/tests/integration_tests/db_engine_specs/base_tests.py @@ -15,8 +15,6 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -from datetime import datetime -from typing import Tuple, Type from tests.integration_tests.test_app import app from tests.integration_tests.base_tests import SupersetTestCase diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py index 2f4f1c70cc..c4f04584fa 100644 --- a/tests/integration_tests/db_engine_specs/bigquery_tests.py +++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import sys import unittest.mock as mock import pytest @@ -95,7 +94,7 @@ class TestBigQueryDbEngineSpec(TestDbEngineSpec): """ # Mock a google.cloud.bigquery.table.Row - class Row(object): + class Row: def __init__(self, value): self._value = value diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index b63f64ab03..341b494927 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -15,9 +15,7 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -from datetime import datetime from unittest import mock -from typing import List import pytest import pandas as pd @@ -377,7 +375,7 @@ def test_where_latest_partition_no_columns_no_values(mock_method): def test__latest_partition_from_df(): - def is_correct_result(data: List, result: List) -> bool: + def is_correct_result(data: list, result: list) -> bool: df = pd.DataFrame({"partition": data}) return HiveEngineSpec._latest_partition_from_df(df) == result diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index de0aa83262..6018e59a92 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -61,7 +61,7 @@ class TestDictImportExport(SupersetTestCase): self, name, schema=None, id=0, cols_names=[], cols_uuids=None, metric_names=[] ): database_name = "main" - name = "{0}{1}".format(NAME_PREFIX, name) + name = f"{NAME_PREFIX}{name}" params = {DBREF: id, "database_name": database_name} if cols_uuids is None: @@ -100,12 +100,12 @@ class TestDictImportExport(SupersetTestCase): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_datasource_equals(self, expected_ds, actual_ds): @@ -114,12 +114,12 @@ class TestDictImportExport(SupersetTestCase): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def test_import_table_no_metadata(self): diff --git a/tests/integration_tests/email_tests.py b/tests/integration_tests/email_tests.py index 381b8cda1b..7c7cc16830 100644 --- a/tests/integration_tests/email_tests.py +++ b/tests/integration_tests/email_tests.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py index fa965ebd7d..3b20f6a918 100644 --- a/tests/integration_tests/event_logger_tests.py +++ b/tests/integration_tests/event_logger_tests.py @@ -17,8 +17,8 @@ import logging import time import unittest -from datetime import datetime, timedelta -from typing import Any, Callable, cast, Dict, Iterator, Optional, Type, Union +from datetime import timedelta +from typing import Any, Optional from unittest.mock import patch from flask import current_app diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index b9b1bfd0fb..81be2f0de8 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. import json -from typing import Any, Dict, Iterator +from collections.abc import Iterator +from typing import Any from uuid import uuid3 import pytest @@ -43,7 +44,7 @@ def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice: @pytest.fixture -def form_data(chart) -> Dict[str, Any]: +def form_data(chart) -> dict[str, Any]: datasource = f"{chart.datasource.id}__{chart.datasource.type}" return { "chart_id": chart.id, @@ -68,7 +69,7 @@ def permalink_salt() -> Iterator[str]: def test_post( - form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin + form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin ): resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) assert resp.status_code == 201 @@ -125,7 +126,7 @@ def test_post_invalid_schema(test_client, login_as_admin) -> None: def test_get( - form_data: Dict[str, Any], permalink_salt: str, test_client, login_as_admin + form_data: dict[str, Any], permalink_salt: str, test_client, login_as_admin ) -> None: resp = test_client.post(f"api/v1/explore/permalink", json={"formData": form_data}) data = json.loads(resp.data.decode("utf-8")) diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 63ed02cd7b..eace978d78 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import json from unittest.mock import patch import pytest diff --git a/tests/integration_tests/fixtures/birth_names_dashboard.py b/tests/integration_tests/fixtures/birth_names_dashboard.py index be680a720d..d9a4a5d9e0 100644 --- a/tests/integration_tests/fixtures/birth_names_dashboard.py +++ b/tests/integration_tests/fixtures/birth_names_dashboard.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Callable, List, Optional +from typing import Callable, Optional import pytest @@ -93,7 +93,7 @@ def _create_table( return table -def _cleanup(dash_id: int, slice_ids: List[int]) -> None: +def _cleanup(dash_id: int, slice_ids: list[int]) -> None: schema = get_example_default_schema() for datasource in db.session.query(SqlaTable).filter_by( table_name="birth_names", schema=schema diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index f394d68a0e..279b67eda0 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -15,7 +15,8 @@ # specific language governing permissions and limitations # under the License. """Fixtures for test_datasource.py""" -from typing import Any, Dict, Generator +from collections.abc import Generator +from typing import Any import pytest from sqlalchemy import Column, create_engine, Date, Integer, MetaData, String, Table @@ -31,7 +32,7 @@ from superset.utils.database import get_example_database from tests.integration_tests.test_app import app -def get_datasource_post() -> Dict[str, Any]: +def get_datasource_post() -> dict[str, Any]: schema = get_example_default_schema() return { diff --git a/tests/integration_tests/fixtures/energy_dashboard.py b/tests/integration_tests/fixtures/energy_dashboard.py index effe59a755..8b597bf3be 100644 --- a/tests/integration_tests/fixtures/energy_dashboard.py +++ b/tests/integration_tests/fixtures/energy_dashboard.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import random -from typing import Dict, List, Set import pandas as pd import pytest @@ -29,7 +28,7 @@ from superset.utils.database import get_example_database from tests.integration_tests.dashboard_utils import create_slice, create_table_metadata from tests.integration_tests.test_app import app -misc_dash_slices: Set[str] = set() +misc_dash_slices: set[str] = set() ENERGY_USAGE_TBL_NAME = "energy_usage" @@ -70,7 +69,7 @@ def _get_dataframe(): return pd.DataFrame.from_dict(data) -def _create_energy_table() -> List[Slice]: +def _create_energy_table() -> list[Slice]: table = create_table_metadata( table_name=ENERGY_USAGE_TBL_NAME, database=get_example_database(), @@ -100,7 +99,7 @@ def _create_energy_table() -> List[Slice]: def _create_and_commit_energy_slice( - table: SqlaTable, title: str, viz_type: str, param: Dict[str, str] + table: SqlaTable, title: str, viz_type: str, param: dict[str, str] ): slice = create_slice(title, viz_type, table, param) existing_slice = ( diff --git a/tests/integration_tests/fixtures/importexport.py b/tests/integration_tests/fixtures/importexport.py index d5c898eba2..5fddb071e2 100644 --- a/tests/integration_tests/fixtures/importexport.py +++ b/tests/integration_tests/fixtures/importexport.py @@ -14,10 +14,10 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List +from typing import Any # example V0 import/export format -dataset_ui_export: List[Dict[str, Any]] = [ +dataset_ui_export: list[dict[str, Any]] = [ { "columns": [ { @@ -48,7 +48,7 @@ dataset_ui_export: List[Dict[str, Any]] = [ } ] -dataset_cli_export: Dict[str, Any] = { +dataset_cli_export: dict[str, Any] = { "databases": [ { "allow_run_async": True, @@ -59,7 +59,7 @@ dataset_cli_export: Dict[str, Any] = { ] } -dashboard_export: Dict[str, Any] = { +dashboard_export: dict[str, Any] = { "dashboards": [ { "__Dashboard__": { @@ -318,35 +318,35 @@ dashboard_export: Dict[str, Any] = { } # example V1 import/export format -database_metadata_config: Dict[str, Any] = { +database_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Database", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -dataset_metadata_config: Dict[str, Any] = { +dataset_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "SqlaTable", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -chart_metadata_config: Dict[str, Any] = { +chart_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Slice", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -dashboard_metadata_config: Dict[str, Any] = { +dashboard_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "Dashboard", "timestamp": "2020-11-04T21:27:44.423819+00:00", } -saved_queries_metadata_config: Dict[str, Any] = { +saved_queries_metadata_config: dict[str, Any] = { "version": "1.0.0", "type": "SavedQuery", "timestamp": "2021-03-30T20:37:54.791187+00:00", } -database_config: Dict[str, Any] = { +database_config: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -361,7 +361,7 @@ database_config: Dict[str, Any] = { "version": "1.0.0", } -database_with_ssh_tunnel_config_private_key: Dict[str, Any] = { +database_with_ssh_tunnel_config_private_key: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -383,7 +383,7 @@ database_with_ssh_tunnel_config_private_key: Dict[str, Any] = { "version": "1.0.0", } -database_with_ssh_tunnel_config_password: Dict[str, Any] = { +database_with_ssh_tunnel_config_password: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -404,7 +404,7 @@ database_with_ssh_tunnel_config_password: Dict[str, Any] = { "version": "1.0.0", } -database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = { +database_with_ssh_tunnel_config_no_credentials: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -424,7 +424,7 @@ database_with_ssh_tunnel_config_no_credentials: Dict[str, Any] = { "version": "1.0.0", } -database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = { +database_with_ssh_tunnel_config_mix_credentials: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -446,7 +446,7 @@ database_with_ssh_tunnel_config_mix_credentials: Dict[str, Any] = { "version": "1.0.0", } -database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = { +database_with_ssh_tunnel_config_private_pass_only: dict[str, Any] = { "allow_csv_upload": True, "allow_ctas": True, "allow_cvas": True, @@ -468,7 +468,7 @@ database_with_ssh_tunnel_config_private_pass_only: Dict[str, Any] = { } -dataset_config: Dict[str, Any] = { +dataset_config: dict[str, Any] = { "table_name": "imported_dataset", "main_dttm_col": None, "description": "This is a dataset that was exported", @@ -513,7 +513,7 @@ dataset_config: Dict[str, Any] = { "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", } -chart_config: Dict[str, Any] = { +chart_config: dict[str, Any] = { "slice_name": "Deck Path", "viz_type": "deck_path", "params": { @@ -557,7 +557,7 @@ chart_config: Dict[str, Any] = { "dataset_uuid": "10808100-158b-42c4-842e-f32b99d88dfb", } -dashboard_config: Dict[str, Any] = { +dashboard_config: dict[str, Any] = { "dashboard_title": "Test dash", "description": None, "css": "", diff --git a/tests/integration_tests/fixtures/query_context.py b/tests/integration_tests/fixtures/query_context.py index 00a3036e01..9efa589ba8 100644 --- a/tests/integration_tests/fixtures/query_context.py +++ b/tests/integration_tests/fixtures/query_context.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from tests.common.query_context_generator import QueryContextGenerator from tests.integration_tests.base_tests import SupersetTestCase @@ -29,8 +29,8 @@ def get_query_context( query_name: str, add_postprocessing_operations: bool = False, add_time_offsets: bool = False, - form_data: Optional[Dict[str, Any]] = None, -) -> Dict[str, Any]: + form_data: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: """ Create a request payload for retrieving a QueryContext object via the `api/v1/chart/data` endpoint. By default returns a payload corresponding to one diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py index 561bbe10b2..18ceba9af2 100644 --- a/tests/integration_tests/fixtures/world_bank_dashboard.py +++ b/tests/integration_tests/fixtures/world_bank_dashboard.py @@ -17,7 +17,7 @@ import json import string from random import choice, randint, random, uniform -from typing import Any, Dict, List +from typing import Any import pandas as pd import pytest @@ -94,7 +94,7 @@ def create_dashboard_for_loaded_data(): return dash_id_to_delete, slices_ids_to_delete -def _create_world_bank_slices(table: SqlaTable) -> List[Slice]: +def _create_world_bank_slices(table: SqlaTable) -> list[Slice]: from superset.examples.world_bank import create_slices slices = create_slices(table) @@ -102,7 +102,7 @@ def _create_world_bank_slices(table: SqlaTable) -> List[Slice]: return slices -def _commit_slices(slices: List[Slice]): +def _commit_slices(slices: list[Slice]): for slice in slices: o = db.session.query(Slice).filter_by(slice_name=slice.slice_name).one_or_none() if o: @@ -128,7 +128,7 @@ def _create_world_bank_dashboard(table: SqlaTable) -> Dashboard: return dash -def _cleanup(dash_id: int, slices_ids: List[int]) -> None: +def _cleanup(dash_id: int, slices_ids: list[int]) -> None: dash = db.session.query(Dashboard).filter_by(id=dash_id).first() db.session.delete(dash) for slice_id in slices_ids: @@ -148,7 +148,7 @@ def _get_dataframe(database: Database) -> DataFrame: return df -def _get_world_bank_data() -> List[Dict[Any, Any]]: +def _get_world_bank_data() -> list[dict[Any, Any]]: data = [] for _ in range(100): data.append( diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index 5bbc985a36..d44745377f 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -115,7 +115,7 @@ class TestImportExport(SupersetTestCase): dashboard_title=title, slices=slcs, position_json='{"size_y": 2, "size_x": 2}', - slug="{}_imported".format(title.lower()), + slug=f"{title.lower()}_imported", json_metadata=json.dumps(json_metadata), ) @@ -160,12 +160,12 @@ class TestImportExport(SupersetTestCase): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_datasource_equals(self, expected_ds, actual_ds): @@ -174,12 +174,12 @@ class TestImportExport(SupersetTestCase): self.assertEqual(len(expected_ds.metrics), len(actual_ds.metrics)) self.assertEqual(len(expected_ds.columns), len(actual_ds.columns)) self.assertEqual( - set([c.column_name for c in expected_ds.columns]), - set([c.column_name for c in actual_ds.columns]), + {c.column_name for c in expected_ds.columns}, + {c.column_name for c in actual_ds.columns}, ) self.assertEqual( - set([m.metric_name for m in expected_ds.metrics]), - set([m.metric_name for m in actual_ds.metrics]), + {m.metric_name for m in expected_ds.metrics}, + {m.metric_name for m in actual_ds.metrics}, ) def assert_slice_equals(self, expected_slc, actual_slc): @@ -404,8 +404,8 @@ class TestImportExport(SupersetTestCase): { "remote_id": 10003, "expanded_slices": { - "{}".format(e_slc.id): True, - "{}".format(b_slc.id): False, + f"{e_slc.id}": True, + f"{b_slc.id}": False, }, # mocked filter_scope metadata "filter_scopes": { @@ -437,8 +437,8 @@ class TestImportExport(SupersetTestCase): } }, "expanded_slices": { - "{}".format(i_e_slc.id): True, - "{}".format(i_b_slc.id): False, + f"{i_e_slc.id}": True, + f"{i_b_slc.id}": False, }, } self.assertEqual( diff --git a/tests/integration_tests/insert_chart_mixin.py b/tests/integration_tests/insert_chart_mixin.py index da05d0c49d..722e387a54 100644 --- a/tests/integration_tests/insert_chart_mixin.py +++ b/tests/integration_tests/insert_chart_mixin.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List, Optional +from typing import Optional from superset import db, security_manager from superset.connectors.sqla.models import SqlaTable @@ -29,7 +29,7 @@ class InsertChartMixin: def insert_chart( self, slice_name: str, - owners: List[int], + owners: list[int], datasource_id: int, created_by=None, datasource_type: str = "table", diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py index 66aea8a4ed..ac33d003e0 100644 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ b/tests/integration_tests/key_value/commands/fixtures.py @@ -18,7 +18,8 @@ from __future__ import annotations import json -from typing import Generator, TYPE_CHECKING +from collections.abc import Generator +from typing import TYPE_CHECKING from uuid import UUID import pytest diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index d5684b1b62..c4bc7aa89b 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -417,14 +417,14 @@ class TestSqlaTableModel(SupersetTestCase): assert str(sqla_literal.compile()) == "ds" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": assert compiled == "DATE(ds)" prev_ds_expr = ds_col.expression ds_col.expression = "DATE_ADD(ds, 1)" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": assert compiled == "DATE(DATE_ADD(ds, 1))" ds_col.expression = prev_ds_expr @@ -437,20 +437,20 @@ class TestSqlaTableModel(SupersetTestCase): ds_col.expression = None ds_col.python_date_format = "epoch_s" sqla_literal = ds_col.get_timestamp_expression(None) - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "from_unixtime(ds)") ds_col.python_date_format = "epoch_s" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "DATE(from_unixtime(ds))") prev_ds_expr = ds_col.expression ds_col.expression = "DATE_ADD(ds, 1)" sqla_literal = ds_col.get_timestamp_expression("P1D") - compiled = "{}".format(sqla_literal.compile()) + compiled = f"{sqla_literal.compile()}" if tbl.database.backend == "mysql": self.assertEqual(compiled, "DATE(from_unixtime(DATE_ADD(ds, 1)))") ds_col.expression = prev_ds_expr diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index 5e5beae345..7a3d4e4a1e 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -16,7 +16,7 @@ # under the License. import re import time -from typing import Any, Dict +from typing import Any import numpy as np import pandas as pd @@ -49,7 +49,7 @@ from tests.integration_tests.fixtures.birth_names_dashboard import ( from tests.integration_tests.fixtures.query_context import get_query_context -def get_sql_text(payload: Dict[str, Any]) -> str: +def get_sql_text(payload: dict[str, Any]) -> str: payload["result_type"] = ChartDataResultType.QUERY.value query_context = ChartDataQueryContextSchema().load(payload) responses = query_context.get_payload() diff --git a/tests/integration_tests/reports/alert_tests.py b/tests/integration_tests/reports/alert_tests.py index 32cc2dcefb..4920a96283 100644 --- a/tests/integration_tests/reports/alert_tests.py +++ b/tests/integration_tests/reports/alert_tests.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=invalid-name, unused-argument, import-outside-toplevel from contextlib import nullcontext -from typing import List, Optional, Tuple, Union +from typing import Optional, Union import pandas as pd import pytest @@ -56,10 +56,10 @@ from tests.integration_tests.test_app import app ], ) def test_execute_query_as_report_executor( - owner_names: List[str], + owner_names: list[str], creator_name: Optional[str], - config: List[ExecutorType], - expected_result: Union[Tuple[ExecutorType, str], Exception], + config: list[ExecutorType], + expected_result: Union[tuple[ExecutorType, str], Exception], mocker: MockFixture, app_context: None, get_user, diff --git a/tests/integration_tests/reports/commands_tests.py b/tests/integration_tests/reports/commands_tests.py index a81bc6fa66..db80079d77 100644 --- a/tests/integration_tests/reports/commands_tests.py +++ b/tests/integration_tests/reports/commands_tests.py @@ -17,7 +17,7 @@ import json from contextlib import contextmanager from datetime import datetime, timedelta, timezone -from typing import List, Optional +from typing import Optional from unittest.mock import call, Mock, patch from uuid import uuid4 @@ -100,7 +100,7 @@ pytestmark = pytest.mark.usefixtures( ) -def get_target_from_report_schedule(report_schedule: ReportSchedule) -> List[str]: +def get_target_from_report_schedule(report_schedule: ReportSchedule) -> list[str]: return [ json.loads(recipient.recipient_config_json)["target"] for recipient in report_schedule.recipients @@ -1976,9 +1976,7 @@ def test__send_with_client_errors(notification_mock, logger_mock): assert excinfo.errisinstance(SupersetException) logger_mock.warning.assert_called_with( - ( - "SupersetError(message='', error_type=, level=, extra=None)" - ) + "SupersetError(message='', error_type=, level=, extra=None)" ) @@ -2021,7 +2019,5 @@ def test__send_with_server_errors(notification_mock, logger_mock): assert excinfo.errisinstance(SupersetException) # it logs the error logger_mock.warning.assert_called_with( - ( - "SupersetError(message='', error_type=, level=, extra=None)" - ) + "SupersetError(message='', error_type=, level=, extra=None)" ) diff --git a/tests/integration_tests/reports/scheduler_tests.py b/tests/integration_tests/reports/scheduler_tests.py index 4b8968592b..3284ee9772 100644 --- a/tests/integration_tests/reports/scheduler_tests.py +++ b/tests/integration_tests/reports/scheduler_tests.py @@ -16,7 +16,6 @@ # under the License. from random import randint -from typing import List from unittest.mock import patch import pytest @@ -32,7 +31,7 @@ from tests.integration_tests.test_app import app @pytest.fixture -def owners(get_user) -> List[User]: +def owners(get_user) -> list[User]: return [get_user("admin")] diff --git a/tests/integration_tests/reports/utils.py b/tests/integration_tests/reports/utils.py index 3801beb1a3..7672c5c940 100644 --- a/tests/integration_tests/reports/utils.py +++ b/tests/integration_tests/reports/utils.py @@ -17,7 +17,7 @@ import json from contextlib import contextmanager -from typing import Any, Dict, List, Optional +from typing import Any, Optional from uuid import uuid4 from flask_appbuilder.security.sqla.models import User @@ -49,7 +49,7 @@ def insert_report_schedule( type: str, name: str, crontab: str, - owners: List[User], + owners: list[User], timezone: Optional[str] = None, sql: Optional[str] = None, description: Optional[str] = None, @@ -61,10 +61,10 @@ def insert_report_schedule( log_retention: Optional[int] = None, last_state: Optional[ReportState] = None, grace_period: Optional[int] = None, - recipients: Optional[List[ReportRecipients]] = None, + recipients: Optional[list[ReportRecipients]] = None, report_format: Optional[ReportDataFormat] = None, - logs: Optional[List[ReportExecutionLog]] = None, - extra: Optional[Dict[Any, Any]] = None, + logs: Optional[list[ReportExecutionLog]] = None, + extra: Optional[dict[Any, Any]] = None, force_screenshot: bool = False, ) -> ReportSchedule: owners = owners or [] @@ -113,9 +113,9 @@ def create_report_notification( grace_period: Optional[int] = None, report_format: Optional[ReportDataFormat] = None, name: Optional[str] = None, - extra: Optional[Dict[str, Any]] = None, + extra: Optional[dict[str, Any]] = None, force_screenshot: bool = False, - owners: Optional[List[User]] = None, + owners: Optional[list[User]] = None, ) -> ReportSchedule: if not owners: owners = [ diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py index a541f00952..ae89fea068 100644 --- a/tests/integration_tests/security/migrate_roles_tests.py +++ b/tests/integration_tests/security/migrate_roles_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. """Unit tests for alerting in Superset""" -import json import logging from contextlib import contextmanager from unittest.mock import patch diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 51aa76ee27..2a28089c3e 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -16,7 +16,7 @@ # under the License. # isort:skip_file import re -from typing import Any, Dict, List, Optional +from typing import Any, Optional from unittest import mock import pytest @@ -55,7 +55,7 @@ class TestRowLevelSecurity(SupersetTestCase): """ rls_entry = None - query_obj: Dict[str, Any] = dict( + query_obj: dict[str, Any] = dict( groupby=[], metrics=None, filter=[], @@ -542,8 +542,8 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase): db_tables = db.session.query(SqlaTable).all() - db_table_names = set([t.name for t in db_tables]) - received_tables = set([table["text"] for table in result]) + db_table_names = {t.name for t in db_tables} + received_tables = {table["text"] for table in result} assert data["count"] == len(db_tables) assert len(result) == len(db_tables) @@ -558,8 +558,8 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase): data = json.loads(rv.data.decode("utf-8")) result = data["result"] - db_role_names = set([r.name for r in security_manager.get_all_roles()]) - received_roles = set([role["text"] for role in result]) + db_role_names = {r.name for r in security_manager.get_all_roles()} + received_roles = {role["text"] for role in result} assert data["count"] == len(db_role_names) assert len(result) == len(db_role_names) @@ -580,7 +580,7 @@ class TestRowLevelSecurityWithRelatedAPI(SupersetTestCase): self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) result = data["result"] - received_tables = set([table["text"].split(".")[-1] for table in result]) + received_tables = {table["text"].split(".")[-1] for table in result} assert data["count"] == 1 assert len(result) == 1 @@ -615,7 +615,7 @@ RLS_GENDER_REGEX = re.compile(r"AND \(gender = 'girl'\)") EMBEDDED_SUPERSET=True, ) class GuestTokenRowLevelSecurityTests(SupersetTestCase): - query_obj: Dict[str, Any] = dict( + query_obj: dict[str, Any] = dict( groupby=[], metrics=None, filter=[], @@ -633,7 +633,7 @@ class GuestTokenRowLevelSecurityTests(SupersetTestCase): "clause": "name = 'Alice'", } - def guest_user_with_rls(self, rules: Optional[List[Any]] = None) -> GuestUser: + def guest_user_with_rls(self, rules: Optional[list[Any]] = None) -> GuestUser: if rules is None: rules = [self.default_rls_rule()] return security_manager.get_guest_user_from_token( diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index a57d24c3e4..89aefdfd09 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -112,7 +112,7 @@ class TestSqlLabApi(SupersetTestCase): @mock.patch("superset.sqllab.commands.results.results_backend_use_msgpack", False) def test_execute_required_params(self): self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"client_id": client_id} rv = self.client.post( @@ -157,7 +157,7 @@ class TestSqlLabApi(SupersetTestCase): core.results_backend.get.return_value = {} self.login() - client_id = "{}".format(random.getrandbits(64))[:10] + client_id = f"{random.getrandbits(64)}"[:10] data = {"sql": "SELECT 1", "database_id": 1, "client_id": client_id} rv = self.client.post( diff --git a/tests/integration_tests/sql_lab/commands_tests.py b/tests/integration_tests/sql_lab/commands_tests.py index 3d505ee2f5..d76924a8fb 100644 --- a/tests/integration_tests/sql_lab/commands_tests.py +++ b/tests/integration_tests/sql_lab/commands_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from unittest import mock, skip +from unittest import mock from unittest.mock import Mock, patch import pandas as pd diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 4003913516..854a0c9be0 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -17,7 +17,8 @@ # isort:skip_file import re from datetime import datetime -from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union +from typing import Any, NamedTuple, Optional, Union +from re import Pattern from unittest.mock import patch import pytest @@ -50,7 +51,7 @@ from tests.integration_tests.test_app import app from .base_tests import SupersetTestCase from .conftest import only_postgresql -VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = { +VIRTUAL_TABLE_INT_TYPES: dict[str, Pattern[str]] = { "hive": re.compile(r"^INT_TYPE$"), "mysql": re.compile("^LONGLONG$"), "postgresql": re.compile(r"^INTEGER$"), @@ -58,7 +59,7 @@ VIRTUAL_TABLE_INT_TYPES: Dict[str, Pattern[str]] = { "sqlite": re.compile(r"^INT$"), } -VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = { +VIRTUAL_TABLE_STRING_TYPES: dict[str, Pattern[str]] = { "hive": re.compile(r"^STRING_TYPE$"), "mysql": re.compile(r"^VAR_STRING$"), "postgresql": re.compile(r"^STRING$"), @@ -70,8 +71,8 @@ VIRTUAL_TABLE_STRING_TYPES: Dict[str, Pattern[str]] = { class FilterTestCase(NamedTuple): column: str operator: str - value: Union[float, int, List[Any], str] - expected: Union[str, List[str]] + value: Union[float, int, list[Any], str] + expected: Union[str, list[str]] class TestDatabaseModel(SupersetTestCase): @@ -101,7 +102,7 @@ class TestDatabaseModel(SupersetTestCase): assert col.is_temporal is True def test_db_column_types(self): - test_cases: Dict[str, GenericDataType] = { + test_cases: dict[str, GenericDataType] = { # string "CHAR": GenericDataType.STRING, "VARCHAR": GenericDataType.STRING, @@ -291,7 +292,7 @@ class TestDatabaseModel(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_where_operators(self): - filters: Tuple[FilterTestCase, ...] = ( + filters: tuple[FilterTestCase, ...] = ( FilterTestCase("num", FilterOperator.IS_NULL, "", "IS NULL"), FilterTestCase("num", FilterOperator.IS_NOT_NULL, "", "IS NOT NULL"), # Some db backends translate true/false to 1/0 @@ -493,7 +494,7 @@ class TestDatabaseModel(SupersetTestCase): "mycase", "expr", } - cols: Dict[str, TableColumn] = {col.column_name: col for col in table.columns} + cols: dict[str, TableColumn] = {col.column_name: col for col in table.columns} # assert that the type for intcol has been updated (asserting CI types) backend = table.database.backend assert VIRTUAL_TABLE_INT_TYPES[backend].match(cols["intcol"].type) @@ -802,7 +803,7 @@ def test__normalize_prequery_result_type( result: Any, ) -> None: def _convert_dttm( - target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None + target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None ) -> Optional[str]: if target_type.upper() == "TIMESTAMP": return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')""" diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 16cc16d264..e9892b1d36 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -359,7 +359,7 @@ class TestSqlLab(SupersetTestCase): db.session.commit() data = self.get_json_resp( - "/superset/queries/{}".format(float(datetime_to_epoch(now)) - 1000) + f"/superset/queries/{float(datetime_to_epoch(now)) - 1000}" ) self.assertEqual(1, len(data)) @@ -391,13 +391,13 @@ class TestSqlLab(SupersetTestCase): # Test search queries on user Id user_id = security_manager.find_user("admin").id - data = self.get_json_resp("/superset/search_queries?user_id={}".format(user_id)) + data = self.get_json_resp(f"/superset/search_queries?user_id={user_id}") self.assertEqual(2, len(data)) user_ids = {k["userId"] for k in data} - self.assertEqual(set([user_id]), user_ids) + self.assertEqual({user_id}, user_ids) user_id = security_manager.find_user("gamma_sqllab").id - resp = self.get_resp("/superset/search_queries?user_id={}".format(user_id)) + resp = self.get_resp(f"/superset/search_queries?user_id={user_id}") data = json.loads(resp) self.assertEqual(1, len(data)) self.assertEqual(data[0]["userId"], user_id) @@ -451,7 +451,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(1, len(data)) user_ids = {k["userId"] for k in data} - self.assertEqual(set([user_id]), user_ids) + self.assertEqual({user_id}, user_ids) def test_alias_duplicate(self): self.run_sql( @@ -593,7 +593,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(len(data["data"]), test_limit) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit), + f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_3", query_limit=test_limit + 1, ) @@ -601,7 +601,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.QUERY) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit + 1), + f"SELECT * FROM birth_names LIMIT {test_limit + 1}", client_id="sql_limit_4", query_limit=test_limit, ) @@ -609,7 +609,7 @@ class TestSqlLab(SupersetTestCase): self.assertEqual(data["query"]["limitingFactor"], LimitingFactor.DROPDOWN) data = self.run_sql( - "SELECT * FROM birth_names LIMIT {}".format(test_limit), + f"SELECT * FROM birth_names LIMIT {test_limit}", client_id="sql_limit_5", query_limit=test_limit, ) diff --git a/tests/integration_tests/strategy_tests.py b/tests/integration_tests/strategy_tests.py index e54ae865e3..f6d664c649 100644 --- a/tests/integration_tests/strategy_tests.py +++ b/tests/integration_tests/strategy_tests.py @@ -16,8 +16,6 @@ # under the License. # isort:skip_file """Unit tests for Superset cache warmup""" -import datetime -import json from unittest.mock import MagicMock from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py index c3f9b350f8..77e007a2dd 100644 --- a/tests/integration_tests/superset_test_config.py +++ b/tests/integration_tests/superset_test_config.py @@ -130,7 +130,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True ALERT_REPORTS_QUERY_EXECUTION_MAX_TRIES = 3 -class CeleryConfig(object): +class CeleryConfig: BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" CELERY_IMPORTS = ("superset.sql_lab",) CELERY_RESULT_BACKEND = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}" diff --git a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py index 9f6dd2ead1..31d14ef71b 100644 --- a/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py +++ b/tests/integration_tests/superset_test_config_sqllab_backend_persist_off.py @@ -16,8 +16,6 @@ # under the License. # flake8: noqa # type: ignore -import os -from copy import copy from .superset_test_config import * diff --git a/tests/integration_tests/superset_test_config_thumbnails.py b/tests/integration_tests/superset_test_config_thumbnails.py index 9f621efabb..5bd02e7b0f 100644 --- a/tests/integration_tests/superset_test_config_thumbnails.py +++ b/tests/integration_tests/superset_test_config_thumbnails.py @@ -61,7 +61,7 @@ REDIS_CELERY_DB = os.environ.get("REDIS_CELERY_DB", 2) REDIS_RESULTS_DB = os.environ.get("REDIS_RESULTS_DB", 3) -class CeleryConfig(object): +class CeleryConfig: BROKER_URL = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}" CELERY_IMPORTS = ("superset.sql_lab", "superset.tasks.thumbnails") CELERY_ANNOTATIONS = {"sql_lab.add": {"rate_limit": "10/s"}} diff --git a/tests/integration_tests/tagging_tests.py b/tests/integration_tests/tagging_tests.py index 71fb7e4e4e..72ba577d9f 100644 --- a/tests/integration_tests/tagging_tests.py +++ b/tests/integration_tests/tagging_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from unittest import mock import pytest diff --git a/tests/integration_tests/tags/api_tests.py b/tests/integration_tests/tags/api_tests.py index 7bf21da4fc..b047388a68 100644 --- a/tests/integration_tests/tags/api_tests.py +++ b/tests/integration_tests/tags/api_tests.py @@ -16,10 +16,7 @@ # under the License. # isort:skip_file """Unit tests for Superset""" -from datetime import datetime, timedelta import json -import random -import string import pytest import prison diff --git a/tests/integration_tests/tags/commands_tests.py b/tests/integration_tests/tags/commands_tests.py index 8f44d2ebda..cd5a024840 100644 --- a/tests/integration_tests/tags/commands_tests.py +++ b/tests/integration_tests/tags/commands_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import itertools -import json from unittest.mock import MagicMock, patch import pytest diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py index f46abaa723..49b22d260b 100644 --- a/tests/integration_tests/tags/dao_tests.py +++ b/tests/integration_tests/tags/dao_tests.py @@ -15,10 +15,7 @@ # specific language governing permissions and limitations # under the License. # isort:skip_file -import copy -import json from operator import and_ -import time from unittest.mock import patch import pytest from superset.dao.exceptions import DAOCreateFailedError, DAOException diff --git a/tests/integration_tests/thumbnails_tests.py b/tests/integration_tests/thumbnails_tests.py index 228da6de79..eb2be859ba 100644 --- a/tests/integration_tests/thumbnails_tests.py +++ b/tests/integration_tests/thumbnails_tests.py @@ -20,7 +20,6 @@ import json import urllib.request from io import BytesIO -from typing import Tuple from unittest import skipUnless from unittest.mock import ANY, call, MagicMock, patch @@ -203,7 +202,7 @@ class TestThumbnails(SupersetTestCase): digest_return_value = "foo_bar" digest_hash = "5c7d96a3dd7a87850a2ef34087565a6e" - def _get_id_and_thumbnail_url(self, url: str) -> Tuple[int, str]: + def _get_id_and_thumbnail_url(self, url: str) -> tuple[int, str]: rv = self.client.get(url) resp = json.loads(rv.data.decode("utf-8")) obj = resp["result"][0] diff --git a/tests/integration_tests/users/__init__.py b/tests/integration_tests/users/__init__.py index fd9417fe5c..13a83393a9 100644 --- a/tests/integration_tests/users/__init__.py +++ b/tests/integration_tests/users/__init__.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/integration_tests/utils/csv_tests.py b/tests/integration_tests/utils/csv_tests.py index e514efb1d2..38c1dd51ac 100644 --- a/tests/integration_tests/utils/csv_tests.py +++ b/tests/integration_tests/utils/csv_tests.py @@ -43,13 +43,13 @@ def test_escape_value(): assert result == "'=value" result = csv.escape_value("|value") - assert result == "'\|value" + assert result == r"'\|value" result = csv.escape_value("%value") assert result == "'%value" result = csv.escape_value("=cmd|' /C calc'!A0") - assert result == "'=cmd\|' /C calc'!A0" + assert result == r"'=cmd\|' /C calc'!A0" result = csv.escape_value('""=10+2') assert result == '\'""=10+2' @@ -74,7 +74,7 @@ def test_df_to_escaped_csv(): assert escaped_csv_rows == [ ["col_a", "'=func()"], - ["-10", "'=cmd\|' /C calc'!A0"], + ["-10", r"'=cmd\|' /C calc'!A0"], ["a", "'=b"], # pandas seems to be removing the leading "" ["' =a", "b"], ] diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py index 2199783529..45fd291ee8 100644 --- a/tests/integration_tests/utils/encrypt_tests.py +++ b/tests/integration_tests/utils/encrypt_tests.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, List, Optional +from typing import Any, Optional from sqlalchemy import String, TypeDecorator from sqlalchemy_utils import EncryptedType @@ -28,9 +28,9 @@ from tests.integration_tests.base_tests import SupersetTestCase class CustomEncFieldAdapter(AbstractEncryptedFieldAdapter): def create( self, - app_config: Optional[Dict[str, Any]], - *args: List[Any], - **kwargs: Optional[Dict[str, Any]] + app_config: Optional[dict[str, Any]], + *args: list[Any], + **kwargs: Optional[dict[str, Any]] ) -> TypeDecorator: if app_config: return StringEncryptedType(*args, app_config["SECRET_KEY"], **kwargs) diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py index 03260fb94d..7012bf08a0 100644 --- a/tests/integration_tests/utils/get_dashboards.py +++ b/tests/integration_tests/utils/get_dashboards.py @@ -14,14 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import List from flask_appbuilder import SQLA from superset.models.dashboard import Dashboard -def get_dashboards_ids(db: SQLA, dashboard_slugs: List[str]) -> List[int]: +def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]: result = ( db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all() ) diff --git a/tests/integration_tests/utils/public_interfaces_test.py b/tests/integration_tests/utils/public_interfaces_test.py index 7b5d671246..af67bb6ca3 100644 --- a/tests/integration_tests/utils/public_interfaces_test.py +++ b/tests/integration_tests/utils/public_interfaces_test.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Callable, Dict +from typing import Any, Callable import pytest @@ -23,7 +23,7 @@ from superset.utils.public_interfaces import compute_hash, get_warning_message # These are public interfaces exposed by Superset. Make sure # to only change the interfaces and update the hashes in new # major versions of Superset. -hashes: Dict[Callable[..., Any], str] = {} +hashes: dict[Callable[..., Any], str] = {} @pytest.mark.parametrize("interface,expected_hash", list(hashes.items())) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 2db008fdb7..b4d750c8d0 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -21,7 +21,7 @@ from decimal import Decimal import json import os import re -from typing import Any, Tuple, List, Optional +from typing import Any, Optional from unittest.mock import Mock, patch from superset.databases.commands.exceptions import DatabaseInvalidError @@ -121,12 +121,12 @@ class TestUtils(SupersetTestCase): assert isinstance(base_json_conv(np.int64(1)), int) assert isinstance(base_json_conv(np.array([1, 2, 3])), list) assert base_json_conv(np.array(None)) is None - assert isinstance(base_json_conv(set([1])), list) + assert isinstance(base_json_conv({1}), list) assert isinstance(base_json_conv(Decimal("1.0")), float) assert isinstance(base_json_conv(uuid.uuid4()), str) assert isinstance(base_json_conv(time()), str) assert isinstance(base_json_conv(timedelta(0)), str) - assert isinstance(base_json_conv(bytes()), str) + assert isinstance(base_json_conv(b""), str) assert base_json_conv(bytes("", encoding="utf-16")) == "[bytes]" with pytest.raises(TypeError): @@ -1054,7 +1054,7 @@ class TestUtils(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_extract_dataframe_dtypes(self): slc = self.get_slice("Girls", db.session) - cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = ( + cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = ( ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]), ( "dttm", diff --git a/tests/integration_tests/viz_tests.py b/tests/integration_tests/viz_tests.py index 137e2a474c..30d4d1e183 100644 --- a/tests/integration_tests/viz_tests.py +++ b/tests/integration_tests/viz_tests.py @@ -19,7 +19,6 @@ from datetime import date, datetime, timezone import logging from math import nan from unittest.mock import Mock, patch -from typing import Any, Dict, List, Set import numpy as np import pandas as pd @@ -1009,7 +1008,7 @@ class TestTimeSeriesTableViz(SupersetTestCase): test_viz = viz.TimeTableViz(datasource, form_data) data = test_viz.get_data(df) # Check method correctly transforms data - self.assertEqual(set(["count", "sum__A"]), set(data["columns"])) + self.assertEqual({"count", "sum__A"}, set(data["columns"])) time_format = "%Y-%m-%d %H:%M:%S" expected = { t1.strftime(time_format): {"sum__A": 15, "count": 6}, @@ -1030,7 +1029,7 @@ class TestTimeSeriesTableViz(SupersetTestCase): test_viz = viz.TimeTableViz(datasource, form_data) data = test_viz.get_data(df) # Check method correctly transforms data - self.assertEqual(set(["a1", "a2", "a3"]), set(data["columns"])) + self.assertEqual({"a1", "a2", "a3"}, set(data["columns"])) time_format = "%Y-%m-%d %H:%M:%S" expected = { t1.strftime(time_format): {"a1": 15, "a2": 20, "a3": 25}, diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py index 72ae9dbba7..b1d5cc6488 100644 --- a/tests/unit_tests/charts/dao/dao_tests.py +++ b/tests/unit_tests/charts/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py index 84496bf1cf..b7cdda6e68 100644 --- a/tests/unit_tests/charts/test_post_processing.py +++ b/tests/unit_tests/charts/test_post_processing.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -import json import pandas as pd import pytest diff --git a/tests/unit_tests/common/test_query_object_factory.py b/tests/unit_tests/common/test_query_object_factory.py index 4fd906f648..02304828dc 100644 --- a/tests/unit_tests/common/test_query_object_factory.py +++ b/tests/unit_tests/common/test_query_object_factory.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import Mock, patch from pytest import fixture, mark @@ -23,7 +23,7 @@ from superset.common.query_object_factory import QueryObjectFactory from tests.common.query_context_generator import QueryContextGenerator -def create_app_config() -> Dict[str, Any]: +def create_app_config() -> dict[str, Any]: return { "ROW_LIMIT": 5000, "DEFAULT_RELATIVE_START_TIME": "today", @@ -34,7 +34,7 @@ def create_app_config() -> Dict[str, Any]: @fixture -def app_config() -> Dict[str, Any]: +def app_config() -> dict[str, Any]: return create_app_config().copy() @@ -58,7 +58,7 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None) -> int: @fixture def query_object_factory( - app_config: Dict[str, Any], connector_registry: Mock, session_factory: Mock + app_config: dict[str, Any], connector_registry: Mock, session_factory: Mock ) -> QueryObjectFactory: import superset.common.query_object_factory as mod @@ -67,7 +67,7 @@ def query_object_factory( @fixture -def raw_query_context() -> Dict[str, Any]: +def raw_query_context() -> dict[str, Any]: return QueryContextGenerator().generate("birth_names") @@ -75,7 +75,7 @@ class TestQueryObjectFactory: def test_query_context_limit_and_offset_defaults( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object.pop("row_limit", None) @@ -89,7 +89,7 @@ class TestQueryObjectFactory: def test_query_context_limit( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object["row_limit"] = 100 @@ -104,7 +104,7 @@ class TestQueryObjectFactory: def test_query_context_null_post_processing_op( self, query_object_factory: QueryObjectFactory, - raw_query_context: Dict[str, Any], + raw_query_context: dict[str, Any], ): raw_query_object = raw_query_context["queries"][0] raw_query_object["post_processing"] = [None] diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py index 021193a6cd..4a62f26e6f 100644 --- a/tests/unit_tests/config_test.py +++ b/tests/unit_tests/config_test.py @@ -17,7 +17,7 @@ # pylint: disable=import-outside-toplevel, unused-argument, redefined-outer-name, invalid-name from functools import partial -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import pytest from pytest_mock import MockerFixture @@ -44,7 +44,7 @@ FULL_DTTM_DEFAULTS_EXAMPLE = { } -def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: Dict[str, Any]) -> None: +def apply_dttm_defaults(table: "SqlaTable", dttm_defaults: dict[str, Any]) -> None: """Applies dttm defaults to the table, mutates in place.""" for dbcol in table.columns: # Set is_dttm is column is listed in dttm_columns. diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index 6740a8b6e2..6a4f1e550c 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -19,7 +19,8 @@ import importlib import os import unittest.mock -from typing import Any, Callable, Iterator +from collections.abc import Iterator +from typing import Any, Callable import pytest from _pytest.fixtures import SubRequest diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py index 62eeff3106..d0ab3ec8a5 100644 --- a/tests/unit_tests/dao/queries_test.py +++ b/tests/unit_tests/dao/queries_test.py @@ -14,9 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -import json from datetime import datetime, timedelta -from typing import Any, Iterator +from typing import Any import pytest from pytest_mock import MockFixture diff --git a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py index 0392acb315..60a659159a 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/utils_test.py @@ -16,7 +16,7 @@ # under the License. # pylint: disable=import-outside-toplevel, unused-argument -from typing import Any, Dict +from typing import Any def test_update_id_refs_immune_missing( # pylint: disable=invalid-name @@ -59,7 +59,7 @@ def test_update_id_refs_immune_missing( # pylint: disable=invalid-name }, } chart_ids = {"uuid1": 1, "uuid2": 2} - dataset_info: Dict[str, Dict[str, Any]] = {} # not used + dataset_info: dict[str, dict[str, Any]] = {} # not used fixed = update_id_refs(config, chart_ids, dataset_info) assert fixed == { @@ -103,7 +103,7 @@ def test_update_native_filter_config_scope_excluded(): }, } chart_ids = {"uuid1": 1, "uuid2": 2} - dataset_info: Dict[str, Dict[str, Any]] = {} # not used + dataset_info: dict[str, dict[str, Any]] = {} # not used fixed = update_id_refs(config, chart_ids, dataset_info) assert fixed == { diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py index a8f93e7513..c94d2ab157 100644 --- a/tests/unit_tests/dashboards/dao_tests.py +++ b/tests/unit_tests/dashboards/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index 47db402670..f085cb53c7 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 2a5738ebd3..fbad104c1d 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index b5adf765fa..de0b70db9c 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from pytest_mock import MockFixture diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 58f90054cc..544cf3434a 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py index ae5b6e9bd3..27f9c3b8ad 100644 --- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 839374425b..9e8690e6e3 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -20,7 +20,7 @@ import copy import json import re import uuid -from typing import Any, Dict +from typing import Any from unittest.mock import Mock, patch import pytest @@ -296,7 +296,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) -> session.flush() dataset_uuid = uuid.uuid4() - yaml_config: Dict[str, Any] = { + yaml_config: dict[str, Any] = { "version": "1.0.0", "table_name": "my_table", "main_dttm_col": "ds", @@ -388,7 +388,7 @@ def test_import_column_allowed_data_url( session.flush() dataset_uuid = uuid.uuid4() - yaml_config: Dict[str, Any] = { + yaml_config: dict[str, Any] = { "version": "1.0.0", "table_name": "my_table", "main_dttm_col": "ds", diff --git a/tests/unit_tests/datasets/conftest.py b/tests/unit_tests/datasets/conftest.py index 8d217ae27a..8bef6945a6 100644 --- a/tests/unit_tests/datasets/conftest.py +++ b/tests/unit_tests/datasets/conftest.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import pytest @@ -23,7 +23,7 @@ if TYPE_CHECKING: @pytest.fixture -def columns_default() -> Dict[str, Any]: +def columns_default() -> dict[str, Any]: """Default props for new columns""" return { "changed_by": 1, @@ -49,7 +49,7 @@ def columns_default() -> Dict[str, Any]: @pytest.fixture -def sample_columns() -> Dict["TableColumn", Dict[str, Any]]: +def sample_columns() -> dict["TableColumn", dict[str, Any]]: from superset.connectors.sqla.models import TableColumn return { @@ -93,7 +93,7 @@ def sample_columns() -> Dict["TableColumn", Dict[str, Any]]: @pytest.fixture -def sample_metrics() -> Dict["SqlMetric", Dict[str, Any]]: +def sample_metrics() -> dict["SqlMetric", dict[str, Any]]: from superset.connectors.sqla.models import SqlMetric return { diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 350425d08e..4eb43cd9de 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index 16334066d7..99a4850301 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Iterator +from collections.abc import Iterator import pytest from sqlalchemy.orm.session import Session diff --git a/tests/unit_tests/db_engine_specs/test_athena.py b/tests/unit_tests/db_engine_specs/test_athena.py index 51ec6656aa..f0811a3e14 100644 --- a/tests/unit_tests/db_engine_specs/test_athena.py +++ b/tests/unit_tests/db_engine_specs/test_athena.py @@ -81,7 +81,7 @@ def test_get_text_clause_with_colon() -> None: from superset.db_engine_specs.athena import AthenaEngineSpec query = ( - "SELECT foo FROM tbl WHERE " "abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'" + "SELECT foo FROM tbl WHERE " r"abc >= TIMESTAMP '2021-11-26T00\:00\:00.000000'" ) text_clause = AthenaEngineSpec.get_text_clause(query) assert text_clause.text == query diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 868a6bbdc3..33083f0399 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -17,7 +17,7 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access from textwrap import dedent -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -130,8 +130,8 @@ def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None: ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_clickhouse.py b/tests/unit_tests/db_engine_specs/test_clickhouse.py index 0c437bc009..6dfeddaf37 100644 --- a/tests/unit_tests/db_engine_specs/test_clickhouse.py +++ b/tests/unit_tests/db_engine_specs/test_clickhouse.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest.mock import Mock import pytest @@ -189,8 +189,8 @@ def test_connect_convert_dttm( ) def test_connect_get_column_spec( native_type: str, - sqla_type: Type[TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_elasticsearch.py b/tests/unit_tests/db_engine_specs/test_elasticsearch.py index de55c63426..0c15977669 100644 --- a/tests/unit_tests/db_engine_specs/test_elasticsearch.py +++ b/tests/unit_tests/db_engine_specs/test_elasticsearch.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional +from typing import Any, Optional from unittest.mock import MagicMock import pytest @@ -49,7 +49,7 @@ from tests.unit_tests.fixtures.common import dttm ) def test_elasticsearch_convert_dttm( target_type: str, - db_extra: Optional[Dict[str, Any]], + db_extra: Optional[dict[str, Any]], expected_result: Optional[str], dttm: datetime, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mssql.py b/tests/unit_tests/db_engine_specs/test_mssql.py index acd35a4ecf..673f4817be 100644 --- a/tests/unit_tests/db_engine_specs/test_mssql.py +++ b/tests/unit_tests/db_engine_specs/test_mssql.py @@ -17,7 +17,7 @@ import unittest.mock as mock from datetime import datetime from textwrap import dedent -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import column, table @@ -50,8 +50,8 @@ from tests.unit_tests.fixtures.common import dttm ) def test_get_column_spec( native_type: str, - sqla_type: Type[TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 07ce6838fc..89abf2321d 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Tuple, Type +from typing import Any, Optional from unittest.mock import Mock, patch import pytest @@ -71,8 +71,8 @@ from tests.unit_tests.fixtures.common import dttm ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -166,7 +166,7 @@ def test_validate_database_uri(sqlalchemy_uri: str, error: bool) -> None: ], ) def test_adjust_engine_params( - sqlalchemy_uri: str, connect_args: Dict[str, Any], returns: Dict[str, Any] + sqlalchemy_uri: str, connect_args: dict[str, Any], returns: dict[str, Any] ) -> None: from superset.db_engine_specs.mysql import MySQLEngineSpec diff --git a/tests/unit_tests/db_engine_specs/test_ocient.py b/tests/unit_tests/db_engine_specs/test_ocient.py index af9fd2ad16..a58f31d242 100644 --- a/tests/unit_tests/db_engine_specs/test_ocient.py +++ b/tests/unit_tests/db_engine_specs/test_ocient.py @@ -17,7 +17,7 @@ # pylint: disable=import-outside-toplevel -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Callable import pytest @@ -33,7 +33,7 @@ def ocient_is_installed() -> bool: # (msg,expected) -MARSHALED_OCIENT_ERRORS: List[Tuple[str, SupersetError]] = [ +MARSHALED_OCIENT_ERRORS: list[tuple[str, SupersetError]] = [ ( "The referenced user does not exist (User 'mj' not found)", SupersetError( @@ -224,7 +224,7 @@ def test_connection_errors(msg: str, expected: SupersetError) -> None: def _generate_gis_type_sanitization_test_cases() -> ( - List[Tuple[str, int, Any, Dict[str, Any]]] + list[tuple[str, int, Any, dict[str, Any]]] ): if not ocient_is_installed(): return [] diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index fef8647959..145d398898 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -16,7 +16,7 @@ # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -82,8 +82,8 @@ def test_convert_dttm( ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_presto.py b/tests/unit_tests/db_engine_specs/test_presto.py index df2ed58c37..7739361cf3 100644 --- a/tests/unit_tests/db_engine_specs/test_presto.py +++ b/tests/unit_tests/db_engine_specs/test_presto.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest import mock import pytest @@ -77,8 +77,8 @@ def test_convert_dttm( ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py index 7812a16830..ac246e3d5b 100644 --- a/tests/unit_tests/db_engine_specs/test_starrocks.py +++ b/tests/unit_tests/db_engine_specs/test_starrocks.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Any, Dict, Optional, Type +from typing import Any, Optional import pytest from sqlalchemy import types @@ -45,8 +45,8 @@ from tests.unit_tests.db_engine_specs.utils import assert_column_spec ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -74,9 +74,9 @@ def test_get_column_spec( ) def test_adjust_engine_params( sqlalchemy_uri: str, - connect_args: Dict[str, Any], + connect_args: dict[str, Any], return_schema: str, - return_connect_args: Dict[str, Any], + return_connect_args: dict[str, Any], ) -> None: from superset.db_engine_specs.starrocks import StarRocksEngineSpec diff --git a/tests/unit_tests/db_engine_specs/test_trino.py b/tests/unit_tests/db_engine_specs/test_trino.py index 0ea296a075..963953d18b 100644 --- a/tests/unit_tests/db_engine_specs/test_trino.py +++ b/tests/unit_tests/db_engine_specs/test_trino.py @@ -17,7 +17,7 @@ # pylint: disable=unused-argument, import-outside-toplevel, protected-access import json from datetime import datetime -from typing import Any, Dict, Optional, Type +from typing import Any, Optional from unittest.mock import Mock, patch import pandas as pd @@ -57,7 +57,7 @@ from tests.unit_tests.fixtures.common import dttm ), ], ) -def test_get_extra_params(extra: Dict[str, Any], expected: Dict[str, Any]) -> None: +def test_get_extra_params(extra: dict[str, Any], expected: dict[str, Any]) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec database = Mock() @@ -95,7 +95,7 @@ def test_auth_basic(mock_auth: Mock) -> None: {"auth_method": "basic", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -117,7 +117,7 @@ def test_auth_kerberos(mock_auth: Mock) -> None: {"auth_method": "kerberos", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -134,7 +134,7 @@ def test_auth_certificate(mock_auth: Mock) -> None: {"auth_method": "certificate", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -152,7 +152,7 @@ def test_auth_jwt(mock_auth: Mock) -> None: {"auth_method": "jwt", "auth_params": auth_params} ) - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) assert connect_args.get("http_scheme") == "https" @@ -176,7 +176,7 @@ def test_auth_custom_auth() -> None: {"trino": {"custom_auth": auth_class}}, clear=True, ): - params: Dict[str, Any] = {} + params: dict[str, Any] = {} TrinoEngineSpec.update_params_from_encrypted_extra(database, params) connect_args = params.setdefault("connect_args", {}) @@ -243,8 +243,8 @@ def test_auth_custom_auth_denied() -> None: ) def test_get_column_spec( native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: Optional[dict[str, Any]], generic_type: GenericDataType, is_dttm: bool, ) -> None: @@ -324,8 +324,8 @@ def test_cancel_query_failed(engine_mock: Mock) -> None: ], ) def test_prepare_cancel_query( - initial_extra: Dict[str, Any], - final_extra: Dict[str, Any], + initial_extra: dict[str, Any], + final_extra: dict[str, Any], mocker: MockerFixture, ) -> None: from superset.db_engine_specs.trino import TrinoEngineSpec diff --git a/tests/unit_tests/db_engine_specs/utils.py b/tests/unit_tests/db_engine_specs/utils.py index 13ae7a34d2..774ca3eaf2 100644 --- a/tests/unit_tests/db_engine_specs/utils.py +++ b/tests/unit_tests/db_engine_specs/utils.py @@ -17,7 +17,7 @@ from __future__ import annotations from datetime import datetime -from typing import Any, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, TYPE_CHECKING from sqlalchemy import types @@ -28,11 +28,11 @@ if TYPE_CHECKING: def assert_convert_dttm( - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], target_type: str, - expected_result: Optional[str], + expected_result: str | None, dttm: datetime, - db_extra: Optional[Dict[str, Any]] = None, + db_extra: dict[str, Any] | None = None, ) -> None: for target in ( target_type, @@ -50,10 +50,10 @@ def assert_convert_dttm( def assert_column_spec( - db_engine_spec: Type[BaseEngineSpec], + db_engine_spec: type[BaseEngineSpec], native_type: str, - sqla_type: Type[types.TypeEngine], - attrs: Optional[Dict[str, Any]], + sqla_type: type[types.TypeEngine], + attrs: dict[str, Any] | None, generic_type: GenericDataType, is_dttm: bool, ) -> None: diff --git a/tests/unit_tests/extensions/ssh_test.py b/tests/unit_tests/extensions/ssh_test.py index 0e997729d9..4538d71969 100644 --- a/tests/unit_tests/extensions/ssh_test.py +++ b/tests/unit_tests/extensions/ssh_test.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any from unittest.mock import Mock, patch import pytest diff --git a/tests/unit_tests/fixtures/assets_configs.py b/tests/unit_tests/fixtures/assets_configs.py index 73bc5921ec..bda84c1335 100644 --- a/tests/unit_tests/fixtures/assets_configs.py +++ b/tests/unit_tests/fixtures/assets_configs.py @@ -14,9 +14,9 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any -databases_config: Dict[str, Any] = { +databases_config: dict[str, Any] = { "databases/examples.yaml": { "database_name": "examples", "sqlalchemy_uri": "sqlite:///test.db", @@ -32,7 +32,7 @@ databases_config: Dict[str, Any] = { "allow_csv_upload": False, }, } -datasets_config: Dict[str, Any] = { +datasets_config: dict[str, Any] = { "datasets/examples/video_game_sales.yaml": { "table_name": "video_game_sales", "main_dttm_col": None, @@ -80,7 +80,7 @@ datasets_config: Dict[str, Any] = { "database_uuid": "a2dc77af-e654-49bb-b321-40f6b559a1ee", }, } -charts_config_1: Dict[str, Any] = { +charts_config_1: dict[str, Any] = { "charts/Games_per_Genre_over_time_95.yaml": { "slice_name": "Games per Genre over time", "viz_type": "line", @@ -100,7 +100,7 @@ charts_config_1: Dict[str, Any] = { "dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283", }, } -dashboards_config_1: Dict[str, Any] = { +dashboards_config_1: dict[str, Any] = { "dashboards/Video_Game_Sales_11.yaml": { "dashboard_title": "Video Game Sales", "description": None, @@ -182,7 +182,7 @@ dashboards_config_1: Dict[str, Any] = { }, } -charts_config_2: Dict[str, Any] = { +charts_config_2: dict[str, Any] = { "charts/Games_per_Genre_131.yaml": { "slice_name": "Games per Genre", "viz_type": "treemap", @@ -193,7 +193,7 @@ charts_config_2: Dict[str, Any] = { "dataset_uuid": "53d47c0c-c03d-47f0-b9ac-81225f808283", }, } -dashboards_config_2: Dict[str, Any] = { +dashboards_config_2: dict[str, Any] = { "dashboards/Video_Game_Sales_11.yaml": { "dashboard_title": "Video Game Sales", "description": None, diff --git a/tests/unit_tests/fixtures/datasets.py b/tests/unit_tests/fixtures/datasets.py index 5d5466a5e8..7bddae0b81 100644 --- a/tests/unit_tests/fixtures/datasets.py +++ b/tests/unit_tests/fixtures/datasets.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from typing import Any, Dict +from typing import Any from unittest.mock import Mock -def get_column_mock(params: Dict[str, Any]) -> Mock: +def get_column_mock(params: dict[str, Any]) -> Mock: mock = Mock() mock.id = params["id"] mock.column_name = params["column_name"] @@ -32,7 +32,7 @@ def get_column_mock(params: Dict[str, Any]) -> Mock: return mock -def get_metric_mock(params: Dict[str, Any]) -> Mock: +def get_metric_mock(params: dict[str, Any]) -> Mock: mock = Mock() mock.id = params["id"] mock.metric_name = params["metric_name"] diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index bf8f589913..d37296447a 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -18,7 +18,7 @@ # pylint: disable=import-outside-toplevel import json from datetime import datetime -from typing import List, Optional +from typing import Optional import pytest from pytest_mock import MockFixture @@ -54,7 +54,7 @@ def test_get_metrics(mocker: MockFixture) -> None: inspector: Inspector, table_name: str, schema: Optional[str], - ) -> List[MetricType]: + ) -> list[MetricType]: return [ { "expression": "COUNT(DISTINCT user_id)", diff --git a/tests/unit_tests/pandas_postprocessing/utils.py b/tests/unit_tests/pandas_postprocessing/utils.py index 07366b1577..fa9fa30d36 100644 --- a/tests/unit_tests/pandas_postprocessing/utils.py +++ b/tests/unit_tests/pandas_postprocessing/utils.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import math -from typing import Any, List, Optional +from typing import Any, Optional from pandas import Series @@ -26,7 +26,7 @@ AGGREGATES_MULTIPLE = { } -def series_to_list(series: Series) -> List[Any]: +def series_to_list(series: Series) -> list[Any]: """ Converts a `Series` to a regular list, and replaces non-numeric values to Nones. @@ -43,8 +43,8 @@ def series_to_list(series: Series) -> List[Any]: def round_floats( - floats: List[Optional[float]], precision: int -) -> List[Optional[float]]: + floats: list[Optional[float]], precision: int +) -> list[Optional[float]]: """ Round list of floats to certain precision diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index cfe6e213b2..e00dc3166e 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -16,8 +16,7 @@ # under the License. # pylint: disable=invalid-name, redefined-outer-name, unused-argument, protected-access, too-many-lines -import unittest -from typing import Optional, Set +from typing import Optional import pytest import sqlparse @@ -40,7 +39,7 @@ from superset.sql_parse import ( ) -def extract_tables(query: str) -> Set[Table]: +def extract_tables(query: str) -> set[Table]: """ Helper function to extract tables referenced in a query. """ diff --git a/tests/unit_tests/tasks/test_cron_util.py b/tests/unit_tests/tasks/test_cron_util.py index 282dc99860..5bc22273f5 100644 --- a/tests/unit_tests/tasks/test_cron_util.py +++ b/tests/unit_tests/tasks/test_cron_util.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from datetime import datetime -from typing import List import pytest import pytz @@ -49,7 +47,7 @@ from superset.tasks.cron_util import cron_schedule_window ], ) def test_cron_schedule_window_los_angeles( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Los_Angeles" @@ -86,7 +84,7 @@ def test_cron_schedule_window_los_angeles( ], ) def test_cron_schedule_window_invalid_timezone( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "invalid timezone" @@ -124,7 +122,7 @@ def test_cron_schedule_window_invalid_timezone( ], ) def test_cron_schedule_window_new_york( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/New_York" @@ -161,7 +159,7 @@ def test_cron_schedule_window_new_york( ], ) def test_cron_schedule_window_chicago( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" @@ -198,7 +196,7 @@ def test_cron_schedule_window_chicago( ], ) def test_cron_schedule_window_chicago_daylight( - current_dttm: str, cron: str, expected: List[FakeDatetime] + current_dttm: str, cron: str, expected: list[FakeDatetime] ) -> None: """ Reports scheduler: Test cron schedule window for "America/Chicago" diff --git a/tests/unit_tests/tasks/test_utils.py b/tests/unit_tests/tasks/test_utils.py index 7854717201..b3fbfca8a2 100644 --- a/tests/unit_tests/tasks/test_utils.py +++ b/tests/unit_tests/tasks/test_utils.py @@ -18,7 +18,7 @@ from contextlib import nullcontext from dataclasses import dataclass from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Optional, Union import pytest from flask_appbuilder.security.sqla.models import User @@ -31,8 +31,8 @@ SELENIUM_USERNAME = "admin" def _get_users( - params: Optional[Union[int, List[int]]] -) -> Optional[Union[User, List[User]]]: + params: Optional[Union[int, list[int]]] +) -> Optional[Union[User, list[User]]]: if params is None: return None if isinstance(params, int): @@ -42,7 +42,7 @@ def _get_users( @dataclass class ModelConfig: - owners: List[int] + owners: list[int] creator: Optional[int] = None modifier: Optional[int] = None @@ -268,18 +268,18 @@ class ModelType(int, Enum): ) def test_get_executor( model_type: ModelType, - executor_types: List[ExecutorType], + executor_types: list[ExecutorType], model_config: ModelConfig, current_user: Optional[int], - expected_result: Tuple[int, ExecutorNotFoundError], + expected_result: tuple[int, ExecutorNotFoundError], ) -> None: from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.reports.models import ReportSchedule from superset.tasks.utils import get_executor - model: Type[Union[Dashboard, ReportSchedule, Slice]] - model_kwargs: Dict[str, Any] = {} + model: type[Union[Dashboard, ReportSchedule, Slice]] + model_kwargs: dict[str, Any] = {} if model_type == ModelType.REPORT_SCHEDULE: model = ReportSchedule model_kwargs = { diff --git a/tests/unit_tests/thumbnails/test_digest.py b/tests/unit_tests/thumbnails/test_digest.py index 04f244e629..68bd7a58f7 100644 --- a/tests/unit_tests/thumbnails/test_digest.py +++ b/tests/unit_tests/thumbnails/test_digest.py @@ -17,7 +17,7 @@ from __future__ import annotations from contextlib import nullcontext -from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING from unittest.mock import patch import pytest @@ -31,7 +31,7 @@ if TYPE_CHECKING: from superset.models.dashboard import Dashboard from superset.models.slice import Slice -_DEFAULT_DASHBOARD_KWARGS: Dict[str, Any] = { +_DEFAULT_DASHBOARD_KWARGS: dict[str, Any] = { "id": 1, "dashboard_title": "My Title", "slices": [{"id": 1, "slice_name": "My Chart"}], @@ -150,11 +150,11 @@ def CUSTOM_CHART_FUNC( ], ) def test_dashboard_digest( - dashboard_overrides: Optional[Dict[str, Any]], - execute_as: List[ExecutorType], + dashboard_overrides: dict[str, Any] | None, + execute_as: list[ExecutorType], has_current_user: bool, use_custom_digest: bool, - expected_result: Union[str, Exception], + expected_result: str | Exception, ) -> None: from superset import app from superset.models.dashboard import Dashboard @@ -167,7 +167,7 @@ def test_dashboard_digest( } slices = [Slice(**slice_kwargs) for slice_kwargs in kwargs.pop("slices")] dashboard = Dashboard(**kwargs, slices=slices) - user: Optional[User] = None + user: User | None = None if has_current_user: user = User(id=1, username="1") func = CUSTOM_DASHBOARD_FUNC if use_custom_digest else None @@ -222,11 +222,11 @@ def test_dashboard_digest( ], ) def test_chart_digest( - chart_overrides: Optional[Dict[str, Any]], - execute_as: List[ExecutorType], + chart_overrides: dict[str, Any] | None, + execute_as: list[ExecutorType], has_current_user: bool, use_custom_digest: bool, - expected_result: Union[str, Exception], + expected_result: str | Exception, ) -> None: from superset import app from superset.models.slice import Slice @@ -237,7 +237,7 @@ def test_chart_digest( **(chart_overrides or {}), } chart = Slice(**kwargs) - user: Optional[User] = None + user: User | None = None if has_current_user: user = User(id=1, username="1") func = CUSTOM_CHART_FUNC if use_custom_digest else None diff --git a/tests/unit_tests/utils/cache_test.py b/tests/unit_tests/utils/cache_test.py index 53650e1d20..bd6179957e 100644 --- a/tests/unit_tests/utils/cache_test.py +++ b/tests/unit_tests/utils/cache_test.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/unit_tests/utils/date_parser_tests.py b/tests/unit_tests/utils/date_parser_tests.py index fb0fe07d29..a2ec20901a 100644 --- a/tests/unit_tests/utils/date_parser_tests.py +++ b/tests/unit_tests/utils/date_parser_tests.py @@ -16,7 +16,7 @@ # under the License. import re from datetime import date, datetime, timedelta -from typing import Optional, Tuple +from typing import Optional from unittest.mock import Mock, patch import pytest @@ -74,8 +74,8 @@ def mock_parse_human_datetime(s: str) -> Optional[datetime]: @patch("superset.utils.date_parser.parse_human_datetime", mock_parse_human_datetime) def test_get_since_until() -> None: - result: Tuple[Optional[datetime], Optional[datetime]] - expected: Tuple[Optional[datetime], Optional[datetime]] + result: tuple[Optional[datetime], Optional[datetime]] + expected: tuple[Optional[datetime], Optional[datetime]] result = get_since_until() expected = None, datetime(2016, 11, 7) diff --git a/tests/unit_tests/utils/test_core.py b/tests/unit_tests/utils/test_core.py index 3636983156..996bd1948f 100644 --- a/tests/unit_tests/utils/test_core.py +++ b/tests/unit_tests/utils/test_core.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -16,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import os -from typing import Any, Dict, Optional +from typing import Any, Optional import pytest @@ -86,7 +85,7 @@ EXTRA_FILTER: QueryObjectFilterClause = { ], ) def test_remove_extra_adhoc_filters( - original: Dict[str, Any], expected: Dict[str, Any] + original: dict[str, Any], expected: dict[str, Any] ) -> None: remove_extra_adhoc_filters(original) assert expected == original diff --git a/tests/unit_tests/utils/test_file.py b/tests/unit_tests/utils/test_file.py index de20402e5c..a2168a7d92 100644 --- a/tests/unit_tests/utils/test_file.py +++ b/tests/unit_tests/utils/test_file.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information diff --git a/tests/unit_tests/utils/urls_tests.py b/tests/unit_tests/utils/urls_tests.py index 208d6caea4..287f346c3d 100644 --- a/tests/unit_tests/utils/urls_tests.py +++ b/tests/unit_tests/utils/urls_tests.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information