chore(pre-commit): Add pyupgrade and pycln hooks (#24197)

This commit is contained in:
John Bodley 2023-06-01 12:01:10 -07:00 committed by GitHub
parent 7d7ce63970
commit a4d5d7c6b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
448 changed files with 3084 additions and 3305 deletions

View File

@ -15,14 +15,28 @@
# limitations under the License.
#
repos:
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/asottile/pyupgrade
rev: v3.4.0
hooks:
- id: pyupgrade
args:
- --py39-plus
- repo: https://github.com/hadialqattan/pycln
rev: v2.1.2
hooks:
- id: pycln
args:
- --disable-all-dunder-policy
- --exclude=superset/config.py
- --extend-exclude=tests/integration_tests/superset_test_config.*.py
- repo: https://github.com/PyCQA/isort
rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/MarcoGorelli/auto-walrus
rev: v0.2.2
hooks:
- id: auto-walrus
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
hooks:

View File

@ -17,8 +17,9 @@ import csv as lib_csv
import os
import re
import sys
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any, Dict, Iterator, List, Optional, Union
from typing import Any, Optional, Union
import click
from click.core import Context
@ -67,15 +68,15 @@ class GitChangeLog:
def __init__(
self,
version: str,
logs: List[GitLog],
logs: list[GitLog],
access_token: Optional[str] = None,
risk: Optional[bool] = False,
) -> None:
self._version = version
self._logs = logs
self._pr_logs_with_details: Dict[int, Dict[str, Any]] = {}
self._github_login_cache: Dict[str, Optional[str]] = {}
self._github_prs: Dict[int, Any] = {}
self._pr_logs_with_details: dict[int, dict[str, Any]] = {}
self._github_login_cache: dict[str, Optional[str]] = {}
self._github_prs: dict[int, Any] = {}
self._wait = 10
github_token = access_token or os.environ.get("GITHUB_TOKEN")
self._github = Github(github_token)
@ -126,7 +127,7 @@ class GitChangeLog:
"superset/migrations/versions/" in file.filename for file in commit.files
)
def _get_pull_request_details(self, git_log: GitLog) -> Dict[str, Any]:
def _get_pull_request_details(self, git_log: GitLog) -> dict[str, Any]:
pr_number = git_log.pr_number
if pr_number:
detail = self._pr_logs_with_details.get(pr_number)
@ -156,7 +157,7 @@ class GitChangeLog:
return detail
def _is_risk_pull_request(self, labels: List[Any]) -> bool:
def _is_risk_pull_request(self, labels: list[Any]) -> bool:
for label in labels:
risk_label = re.match(SUPERSET_RISKY_LABELS, label.name)
if risk_label is not None:
@ -174,8 +175,8 @@ class GitChangeLog:
def _parse_change_log(
self,
changelog: Dict[str, str],
pr_info: Dict[str, str],
changelog: dict[str, str],
pr_info: dict[str, str],
github_login: str,
) -> None:
formatted_pr = (
@ -227,7 +228,7 @@ class GitChangeLog:
result += f"**{key}** {changelog[key]}\n"
return result
def __iter__(self) -> Iterator[Dict[str, Any]]:
def __iter__(self) -> Iterator[dict[str, Any]]:
for log in self._logs:
yield {
"pr_number": log.pr_number,
@ -250,20 +251,20 @@ class GitLogs:
def __init__(self, git_ref: str) -> None:
self._git_ref = git_ref
self._logs: List[GitLog] = []
self._logs: list[GitLog] = []
@property
def git_ref(self) -> str:
return self._git_ref
@property
def logs(self) -> List[GitLog]:
def logs(self) -> list[GitLog]:
return self._logs
def fetch(self) -> None:
self._logs = list(map(self._parse_log, self._git_logs()))[::-1]
def diff(self, git_logs: "GitLogs") -> List[GitLog]:
def diff(self, git_logs: "GitLogs") -> list[GitLog]:
return [log for log in git_logs.logs if log not in self._logs]
def __repr__(self) -> str:
@ -284,7 +285,7 @@ class GitLogs:
print(f"Could not checkout {git_ref}")
sys.exit(1)
def _git_logs(self) -> List[str]:
def _git_logs(self) -> list[str]:
# let's get current git ref so we can revert it back
current_git_ref = self._git_get_current_head()
self._git_checkout(self._git_ref)

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Any, Dict, List
from typing import Any
from click.core import Context
@ -34,7 +34,7 @@ PROJECT_MODULE = "superset"
PROJECT_DESCRIPTION = "Apache Superset is a modern, enterprise-ready business intelligence web application"
def string_comma_to_list(message: str) -> List[str]:
def string_comma_to_list(message: str) -> list[str]:
if not message:
return []
return [element.strip() for element in message.split(",")]
@ -52,7 +52,7 @@ def render_template(template_file: str, **kwargs: Any) -> str:
return template.render(kwargs)
class BaseParameters(object):
class BaseParameters:
def __init__(
self,
version: str,
@ -60,7 +60,7 @@ class BaseParameters(object):
) -> None:
self.version = version
self.version_rc = version_rc
self.template_arguments: Dict[str, Any] = {}
self.template_arguments: dict[str, Any] = {}
def __repr__(self) -> str:
return f"Apache Credentials: {self.version}/{self.version_rc}"

View File

@ -22,7 +22,6 @@
#
import logging
import os
from datetime import timedelta
from typing import Optional
from cachelib.file import FileSystemCache
@ -42,7 +41,7 @@ def get_env_variable(var_name: str, default: Optional[str] = None) -> str:
error_msg = "The environment variable {} was missing, abort...".format(
var_name
)
raise EnvironmentError(error_msg)
raise OSError(error_msg)
DATABASE_DIALECT = get_env_variable("DATABASE_DIALECT")
@ -53,7 +52,7 @@ DATABASE_PORT = get_env_variable("DATABASE_PORT")
DATABASE_DB = get_env_variable("DATABASE_DB")
# The SQLAlchemy connection string.
SQLALCHEMY_DATABASE_URI = "%s://%s:%s@%s:%s/%s" % (
SQLALCHEMY_DATABASE_URI = "{}://{}:{}@{}:{}/{}".format(
DATABASE_DIALECT,
DATABASE_USER,
DATABASE_PASSWORD,
@ -80,7 +79,7 @@ CACHE_CONFIG = {
DATA_CACHE_CONFIG = CACHE_CONFIG
class CeleryConfig(object):
class CeleryConfig:
broker_url = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_CELERY_DB}"
imports = ("superset.sql_lab",)
result_backend = f"redis://{REDIS_HOST}:{REDIS_PORT}/{REDIS_RESULTS_DB}"

View File

@ -23,7 +23,7 @@ from graphlib import TopologicalSorter
from inspect import getsource
from pathlib import Path
from types import ModuleType
from typing import Any, Dict, List, Set, Type
from typing import Any
import click
from flask import current_app
@ -48,12 +48,10 @@ def import_migration_script(filepath: Path) -> ModuleType:
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module) # type: ignore
return module
raise Exception(
"No module spec found in location: `{path}`".format(path=str(filepath))
)
raise Exception(f"No module spec found in location: `{str(filepath)}`")
def extract_modified_tables(module: ModuleType) -> Set[str]:
def extract_modified_tables(module: ModuleType) -> set[str]:
"""
Extract the tables being modified by a migration script.
@ -62,7 +60,7 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
actually traversing the AST.
"""
tables: Set[str] = set()
tables: set[str] = set()
for function in {"upgrade", "downgrade"}:
source = getsource(getattr(module, function))
tables.update(re.findall(r'alter_table\(\s*"(\w+?)"\s*\)', source, re.DOTALL))
@ -72,11 +70,11 @@ def extract_modified_tables(module: ModuleType) -> Set[str]:
return tables
def find_models(module: ModuleType) -> List[Type[Model]]:
def find_models(module: ModuleType) -> list[type[Model]]:
"""
Find all models in a migration script.
"""
models: List[Type[Model]] = []
models: list[type[Model]] = []
tables = extract_modified_tables(module)
# add models defined explicitly in the migration script
@ -123,7 +121,7 @@ def find_models(module: ModuleType) -> List[Type[Model]]:
sorter: TopologicalSorter[Any] = TopologicalSorter()
for model in models:
inspector = inspect(model)
dependent_tables: List[str] = []
dependent_tables: list[str] = []
for column in inspector.columns.values():
for foreign_key in column.foreign_keys:
if foreign_key.column.table.name != model.__tablename__:
@ -174,7 +172,7 @@ def main(
print("\nIdentifying models used in the migration:")
models = find_models(module)
model_rows: Dict[Type[Model], int] = {}
model_rows: dict[type[Model], int] = {}
for model in models:
rows = session.query(model).count()
print(f"- {model.__name__} ({rows} rows in table {model.__tablename__})")
@ -182,7 +180,7 @@ def main(
session.close()
print("Benchmarking migration")
results: Dict[str, float] = {}
results: dict[str, float] = {}
start = time.time()
upgrade(revision=revision)
duration = time.time() - start
@ -190,14 +188,14 @@ def main(
print(f"Migration on current DB took: {duration:.2f} seconds")
min_entities = 10
new_models: Dict[Type[Model], List[Model]] = defaultdict(list)
new_models: dict[type[Model], list[Model]] = defaultdict(list)
while min_entities <= limit:
downgrade(revision=down_revision)
print(f"Running with at least {min_entities} entities of each model")
for model in models:
missing = min_entities - model_rows[model]
if missing > 0:
entities: List[Model] = []
entities: list[Model] = []
print(f"- Adding {missing} entities to the {model.__name__} model")
bar = ChargingBar("Processing", max=missing)
try:

View File

@ -33,13 +33,13 @@ Example:
./cancel_github_workflows.py 1024 --include-last
"""
import os
from typing import Any, Dict, Iterable, Iterator, List, Optional, Union
from collections.abc import Iterable, Iterator
from typing import Any, Literal, Optional, Union
import click
import requests
from click.exceptions import ClickException
from dateutil import parser
from typing_extensions import Literal
github_token = os.environ.get("GITHUB_TOKEN")
github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
@ -47,7 +47,7 @@ github_repo = os.environ.get("GITHUB_REPOSITORY", "apache/superset")
def request(
method: Literal["GET", "POST", "DELETE", "PUT"], endpoint: str, **kwargs: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
resp = requests.request(
method,
f"https://api.github.com/{endpoint.lstrip('/')}",
@ -61,8 +61,8 @@ def request(
def list_runs(
repo: str,
params: Optional[Dict[str, str]] = None,
) -> Iterator[Dict[str, Any]]:
params: Optional[dict[str, str]] = None,
) -> Iterator[dict[str, Any]]:
"""List all github workflow runs.
Returns:
An iterator that will iterate through all pages of matching runs."""
@ -77,16 +77,15 @@ def list_runs(
params={**params, "per_page": 100, "page": page},
)
total_count = result["total_count"]
for item in result["workflow_runs"]:
yield item
yield from result["workflow_runs"]
page += 1
def cancel_run(repo: str, run_id: Union[str, int]) -> Dict[str, Any]:
def cancel_run(repo: str, run_id: Union[str, int]) -> dict[str, Any]:
return request("POST", f"/repos/{repo}/actions/runs/{run_id}/cancel")
def get_pull_request(repo: str, pull_number: Union[str, int]) -> Dict[str, Any]:
def get_pull_request(repo: str, pull_number: Union[str, int]) -> dict[str, Any]:
return request("GET", f"/repos/{repo}/pulls/{pull_number}")
@ -96,7 +95,7 @@ def get_runs(
user: Optional[str] = None,
statuses: Iterable[str] = ("queued", "in_progress"),
events: Iterable[str] = ("pull_request", "push"),
) -> List[Dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Get workflow runs associated with the given branch"""
return [
item
@ -108,7 +107,7 @@ def get_runs(
]
def print_commit(commit: Dict[str, Any], branch: str) -> None:
def print_commit(commit: dict[str, Any], branch: str) -> None:
"""Print out commit message for verification"""
indented_message = " \n".join(commit["message"].split("\n"))
date_str = (
@ -155,7 +154,7 @@ Date: {date_str}
def cancel_github_workflows(
branch_or_pull: Optional[str],
repo: str,
event: List[str],
event: list[str],
include_last: bool,
include_running: bool,
) -> None:

View File

@ -24,7 +24,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
print("# of permission view menus is: {}".format(len(pvms)))
print(f"# of permission view menus is: {len(pvms)}")
pvms_dict = defaultdict(list)
for pvm in pvms:
pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm)
@ -43,7 +43,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
print("Stage 1: # of permission view menus is: {}".format(len(pvms)))
print(f"Stage 1: # of permission view menus is: {len(pvms)}")
# 2. Clean up None permissions or view menus
pvms = security_manager.get_session.query(
@ -57,7 +57,7 @@ def cleanup_permissions() -> None:
pvms = security_manager.get_session.query(
security_manager.permissionview_model
).all()
print("Stage 2: # of permission view menus is: {}".format(len(pvms)))
print(f"Stage 2: # of permission view menus is: {len(pvms)}")
# 3. Delete empty permission view menus from roles
roles = security_manager.get_session.query(security_manager.role_model).all()

View File

@ -14,21 +14,19 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import io
import json
import os
import subprocess
import sys
from setuptools import find_packages, setup
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
PACKAGE_JSON = os.path.join(BASE_DIR, "superset-frontend", "package.json")
with open(PACKAGE_JSON, "r") as package_file:
with open(PACKAGE_JSON) as package_file:
version_string = json.load(package_file)["version"]
with io.open("README.md", "r", encoding="utf-8") as f:
with open("README.md", encoding="utf-8") as f:
long_description = f.read()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import ipaddress
from typing import Any, List
from typing import Any
from sqlalchemy import Column
@ -77,7 +77,7 @@ def cidr_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeResponse:
# Make this return a single clause
def cidr_translate_filter_func(
col: Column, operator: FilterOperator, values: List[Any]
col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator and

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import itertools
from typing import Any, Dict, List
from typing import Any
from sqlalchemy import Column
@ -26,7 +26,7 @@ from superset.advanced_data_type.types import (
)
from superset.utils.core import FilterOperator, FilterStringOperators
port_conversion_dict: Dict[str, List[int]] = {
port_conversion_dict: dict[str, list[int]] = {
"http": [80],
"ssh": [22],
"https": [443],
@ -100,7 +100,7 @@ def port_translation_func(req: AdvancedDataTypeRequest) -> AdvancedDataTypeRespo
def port_translate_filter_func(
col: Column, operator: FilterOperator, values: List[Any]
col: Column, operator: FilterOperator, values: list[Any]
) -> Any:
"""
Convert a passed in column, FilterOperator

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, TypedDict, Union
from typing import Any, Callable, Optional, TypedDict, Union
from sqlalchemy import Column
from sqlalchemy.sql.expression import BinaryExpression
@ -30,7 +30,7 @@ class AdvancedDataTypeRequest(TypedDict):
"""
advanced_data_type: str
values: List[
values: list[
Union[FilterValues, None]
] # unparsed value (usually text when passed from text box)
@ -41,9 +41,9 @@ class AdvancedDataTypeResponse(TypedDict, total=False):
"""
error_message: Optional[str]
values: List[Any] # parsed value (can be any value)
values: list[Any] # parsed value (can be any value)
display_value: str # The string representation of the parsed values
valid_filter_operators: List[FilterStringOperators]
valid_filter_operators: list[FilterStringOperators]
@dataclass
@ -54,6 +54,6 @@ class AdvancedDataType:
verbose_name: str
description: str
valid_data_types: List[str]
valid_data_types: list[str]
translate_type: Callable[[AdvancedDataTypeRequest], AdvancedDataTypeResponse]
translate_filter: Callable[[Column, FilterOperator, Any], BinaryExpression]

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from flask import request, Response
from flask_appbuilder.api import expose, permission_name, protect, rison, safe
@ -127,7 +127,7 @@ class AnnotationRestApi(BaseSupersetModelRestApi):
@staticmethod
def _apply_layered_relation_to_rison( # pylint: disable=invalid-name
layer_id: int, rison_parameters: Dict[str, Any]
layer_id: int, rison_parameters: dict[str, Any]
) -> None:
if "filters" not in rison_parameters:
rison_parameters["filters"] = []

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from superset.annotation_layers.annotations.commands.exceptions import (
AnnotationBulkDeleteFailedError,
@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationCommand(BaseCommand):
def __init__(self, model_ids: List[int]):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[List[Annotation]] = None
self._models: Optional[list[Annotation]] = None
def run(self) -> None:
self.validate()

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationCommand(BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@ -50,7 +50,7 @@ class CreateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
start_dttm: Optional[datetime] = self._properties.get("start_dttm")
end_dttm: Optional[datetime] = self._properties.get("end_dttm")

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -39,7 +39,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Annotation] = None
@ -54,7 +54,7 @@ class UpdateAnnotationCommand(BaseCommand):
return annotation
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
layer_id: Optional[int] = self._properties.get("layer")
short_descr: str = self._properties.get("short_descr", "")

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@ -31,7 +31,7 @@ class AnnotationDAO(BaseDAO):
model_cls = Annotation
@staticmethod
def bulk_delete(models: Optional[List[Annotation]], commit: bool = True) -> None:
def bulk_delete(models: Optional[list[Annotation]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(Annotation).filter(Annotation.id.in_(item_ids)).delete(

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from superset.annotation_layers.commands.exceptions import (
AnnotationLayerBulkDeleteFailedError,
@ -31,9 +31,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteAnnotationLayerCommand(BaseCommand):
def __init__(self, model_ids: List[int]):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[List[AnnotationLayer]] = None
self._models: Optional[list[AnnotationLayer]] = None
def run(self) -> None:
self.validate()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List
from typing import Any
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -33,7 +33,7 @@ logger = logging.getLogger(__name__)
class CreateAnnotationLayerCommand(BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@ -46,7 +46,7 @@ class CreateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
name = self._properties.get("name", "")

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -35,7 +35,7 @@ logger = logging.getLogger(__name__)
class UpdateAnnotationLayerCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[AnnotationLayer] = None
@ -50,7 +50,7 @@ class UpdateAnnotationLayerCommand(BaseCommand):
return annotation_layer
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
name = self._properties.get("name", "")
self._model = AnnotationLayerDAO.find_by_id(self._model_id)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional, Union
from typing import Optional, Union
from sqlalchemy.exc import SQLAlchemyError
@ -32,7 +32,7 @@ class AnnotationLayerDAO(BaseDAO):
@staticmethod
def bulk_delete(
models: Optional[List[AnnotationLayer]], commit: bool = True
models: Optional[list[AnnotationLayer]], commit: bool = True
) -> None:
item_ids = [model.id for model in models] if models else []
try:
@ -46,7 +46,7 @@ class AnnotationLayerDAO(BaseDAO):
raise DAODeleteFailedError() from ex
@staticmethod
def has_annotations(model_id: Union[int, List[int]]) -> bool:
def has_annotations(model_id: Union[int, list[int]]) -> bool:
if isinstance(model_id, list):
return (
db.session.query(AnnotationLayer)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from flask_babel import lazy_gettext as _
@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteChartCommand(BaseCommand):
def __init__(self, model_ids: List[int]):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[List[Slice]] = None
self._models: Optional[list[Slice]] = None
def run(self) -> None:
self.validate()

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@ -37,7 +37,7 @@ logger = logging.getLogger(__name__)
class CreateChartCommand(CreateMixin, BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@ -56,7 +56,7 @@ class CreateChartCommand(CreateMixin, BaseCommand):
datasource_type = self._properties["datasource_type"]
datasource_id = self._properties["datasource_id"]
dashboard_ids = self._properties.get("dashboards", [])
owner_ids: Optional[List[int]] = self._properties.get("owners")
owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/Populate datasource
try:

View File

@ -18,7 +18,7 @@
import json
import logging
from typing import Iterator, Tuple
from collections.abc import Iterator
import yaml
@ -42,7 +42,7 @@ class ExportChartsCommand(ExportModelsCommand):
not_found = ChartNotFoundError
@staticmethod
def _export(model: Slice, export_related: bool = True) -> Iterator[Tuple[str, str]]:
def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.slice_name, model.id)
file_path = f"charts/{file_name}.yaml"

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from marshmallow.exceptions import ValidationError
@ -40,7 +40,7 @@ class ImportChartsCommand(BaseCommand):
until it finds one that matches.
"""
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs

View File

@ -15,8 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import json
from typing import Any, Dict, Set
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@ -40,7 +39,7 @@ class ImportChartsCommand(ImportModelsCommand):
dao = ChartDAO
model_name = "chart"
prefix = "charts/"
schemas: Dict[str, Schema] = {
schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"datasets/": ImportV1DatasetSchema(),
"databases/": ImportV1DatabaseSchema(),
@ -49,29 +48,29 @@ class ImportChartsCommand(ImportModelsCommand):
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with charts
dataset_uuids: Set[str] = set()
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("charts/"):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
database_uuids: Set[str] = set()
database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
database_ids: Dict[str, int] = {}
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
datasets: Dict[str, SqlaTable] = {}
datasets: dict[str, SqlaTable] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")

View File

@ -16,7 +16,7 @@
# under the License.
import json
from typing import Any, Dict
from typing import Any
from flask import g
from sqlalchemy.orm import Session
@ -28,7 +28,7 @@ from superset.models.slice import Slice
def import_chart(
session: Session,
config: Dict[str, Any],
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
@ -42,14 +42,14 @@ from superset.models.slice import Slice
logger = logging.getLogger(__name__)
def is_query_context_update(properties: Dict[str, Any]) -> bool:
def is_query_context_update(properties: dict[str, Any]) -> bool:
return set(properties) == {"query_context", "query_context_generation"} and bool(
properties.get("query_context_generation")
)
class UpdateChartCommand(UpdateMixin, BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Slice] = None
@ -67,9 +67,9 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
return chart
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
dashboard_ids = self._properties.get("dashboards")
owner_ids: Optional[List[int]] = self._properties.get("owners")
owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate if datasource_id is provided datasource_type is required
datasource_id = self._properties.get("datasource_id")

View File

@ -17,7 +17,7 @@
# pylint: disable=arguments-renamed
import logging
from datetime import datetime
from typing import List, Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
from sqlalchemy.exc import SQLAlchemyError
@ -39,7 +39,7 @@ class ChartDAO(BaseDAO):
base_filter = ChartFilter
@staticmethod
def bulk_delete(models: Optional[List[Slice]], commit: bool = True) -> None:
def bulk_delete(models: Optional[list[Slice]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@ -71,7 +71,7 @@ class ChartDAO(BaseDAO):
db.session.commit()
@staticmethod
def favorited_ids(charts: List[Slice]) -> List[FavStar]:
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]
return [
star.obj_id

View File

@ -18,7 +18,7 @@ from __future__ import annotations
import json
import logging
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING
import simplejson
from flask import current_app, g, make_response, request, Response
@ -315,7 +315,7 @@ class ChartDataRestApi(ChartRestApi):
return self._get_data_response(command, True)
def _run_async(
self, form_data: Dict[str, Any], command: ChartDataCommand
self, form_data: dict[str, Any], command: ChartDataCommand
) -> Response:
"""
Execute command as an async query.
@ -344,9 +344,9 @@ class ChartDataRestApi(ChartRestApi):
def _send_chart_response(
self,
result: Dict[Any, Any],
form_data: Optional[Dict[str, Any]] = None,
datasource: Optional[Union[BaseDatasource, Query]] = None,
result: dict[Any, Any],
form_data: dict[str, Any] | None = None,
datasource: BaseDatasource | Query | None = None,
) -> Response:
result_type = result["query_context"].result_type
result_format = result["query_context"].result_format
@ -408,8 +408,8 @@ class ChartDataRestApi(ChartRestApi):
self,
command: ChartDataCommand,
force_cached: bool = False,
form_data: Optional[Dict[str, Any]] = None,
datasource: Optional[Union[BaseDatasource, Query]] = None,
form_data: dict[str, Any] | None = None,
datasource: BaseDatasource | Query | None = None,
) -> Response:
try:
result = command.run(force_cached=force_cached)
@ -421,12 +421,12 @@ class ChartDataRestApi(ChartRestApi):
return self._send_chart_response(result, form_data, datasource)
# pylint: disable=invalid-name, no-self-use
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
return QueryContextCacheLoader.load(cache_key)
# pylint: disable=no-self-use
def _create_query_context_from_form(
self, form_data: Dict[str, Any]
self, form_data: dict[str, Any]
) -> QueryContext:
try:
return ChartDataQueryContextSchema().load(form_data)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask import Request
@ -32,7 +32,7 @@ class CreateAsyncChartDataJobCommand:
jwt_data = async_query_manager.parse_jwt_from_request(request)
self._async_channel_id = jwt_data["channel"]
def run(self, form_data: Dict[str, Any], user_id: Optional[int]) -> Dict[str, Any]:
def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
job_metadata = async_query_manager.init_job(self._async_channel_id, user_id)
load_chart_data_into_cache.delay(job_metadata, form_data)
return job_metadata

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from flask_babel import lazy_gettext as _
@ -36,7 +36,7 @@ class ChartDataCommand(BaseCommand):
def __init__(self, query_context: QueryContext):
self._query_context = query_context
def run(self, **kwargs: Any) -> Dict[str, Any]:
def run(self, **kwargs: Any) -> dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
cache_query_context = kwargs.get("cache", False)

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from typing import Any
from superset import cache
from superset.charts.commands.exceptions import ChartDataCacheLoadError
@ -22,7 +22,7 @@ from superset.charts.commands.exceptions import ChartDataCacheLoadError
class QueryContextCacheLoader: # pylint: disable=too-few-public-methods
@staticmethod
def load(cache_key: str) -> Dict[str, Any]:
def load(cache_key: str) -> dict[str, Any]:
cache_value = cache.get(cache_key)
if not cache_value:
raise ChartDataCacheLoadError("Cached data not found")

View File

@ -27,7 +27,7 @@ for these chart types.
"""
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import Any, Optional, TYPE_CHECKING, Union
import pandas as pd
from flask_babel import gettext as __
@ -45,14 +45,14 @@ if TYPE_CHECKING:
from superset.models.sql_lab import Query
def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
def get_column_key(label: tuple[str, ...], metrics: list[str]) -> tuple[Any, ...]:
"""
Sort columns when combining metrics.
MultiIndex labels have the metric name as the last element in the
tuple. We want to sort these according to the list of passed metrics.
"""
parts: List[Any] = list(label)
parts: list[Any] = list(label)
metric = parts[-1]
parts[-1] = metrics.index(metric)
return tuple(parts)
@ -60,9 +60,9 @@ def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...
def pivot_df( # pylint: disable=too-many-locals, too-many-arguments, too-many-statements, too-many-branches
df: pd.DataFrame,
rows: List[str],
columns: List[str],
metrics: List[str],
rows: list[str],
columns: list[str],
metrics: list[str],
aggfunc: str = "Sum",
transpose_pivot: bool = False,
combine_metrics: bool = False,
@ -194,7 +194,7 @@ def list_unique_values(series: pd.Series) -> str:
"""
List unique values in a series.
"""
return ", ".join(set(str(v) for v in pd.Series.unique(series)))
return ", ".join({str(v) for v in pd.Series.unique(series)})
pivot_v2_aggfunc_map = {
@ -223,7 +223,7 @@ pivot_v2_aggfunc_map = {
def pivot_table_v2(
df: pd.DataFrame,
form_data: Dict[str, Any],
form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@ -249,7 +249,7 @@ def pivot_table_v2(
def pivot_table(
df: pd.DataFrame,
form_data: Dict[str, Any],
form_data: dict[str, Any],
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> pd.DataFrame:
"""
@ -285,7 +285,7 @@ def pivot_table(
def table(
df: pd.DataFrame,
form_data: Dict[str, Any],
form_data: dict[str, Any],
datasource: Optional[ # pylint: disable=unused-argument
Union["BaseDatasource", "Query"]
] = None,
@ -315,10 +315,10 @@ post_processors = {
def apply_post_process(
result: Dict[Any, Any],
form_data: Optional[Dict[str, Any]] = None,
result: dict[Any, Any],
form_data: Optional[dict[str, Any]] = None,
datasource: Optional[Union["BaseDatasource", "Query"]] = None,
) -> Dict[Any, Any]:
) -> dict[Any, Any]:
form_data = form_data or {}
viz_type = form_data.get("viz_type")

View File

@ -18,7 +18,7 @@
from __future__ import annotations
import inspect
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
from flask_babel import gettext as _
from marshmallow import EXCLUDE, fields, post_load, Schema, validate
@ -1383,7 +1383,7 @@ class ChartDataQueryObjectSchema(Schema):
class ChartDataQueryContextSchema(Schema):
query_context_factory: Optional[QueryContextFactory] = None
query_context_factory: QueryContextFactory | None = None
datasource = fields.Nested(ChartDataDatasourceSchema)
queries = fields.List(fields.Nested(ChartDataQueryObjectSchema))
custom_cache_timeout = fields.Integer(
@ -1407,7 +1407,7 @@ class ChartDataQueryContextSchema(Schema):
# pylint: disable=unused-argument
@post_load
def make_query_context(self, data: Dict[str, Any], **kwargs: Any) -> QueryContext:
def make_query_context(self, data: dict[str, Any], **kwargs: Any) -> QueryContext:
query_context = self.get_query_context_factory().create(**data)
return query_context

View File

@ -18,7 +18,7 @@ import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import List, Optional
from typing import Optional
from zipfile import is_zipfile, ZipFile
import click
@ -309,7 +309,7 @@ else:
from superset.dashboards.commands.importers.v0 import ImportDashboardsCommand
path_object = Path(path)
files: List[Path] = []
files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:
@ -363,7 +363,7 @@ else:
sync_metrics = "metrics" in sync_array
path_object = Path(path)
files: List[Path] = []
files: list[Path] = []
if path_object.is_file():
files.append(path_object)
elif path_object.exists() and not recursive:

View File

@ -18,7 +18,7 @@
import importlib
import logging
import pkgutil
from typing import Any, Dict
from typing import Any
import click
from colorama import Fore, Style
@ -40,7 +40,7 @@ def superset() -> None:
"""This is a management script for the Superset application."""
@app.shell_context_processor
def make_shell_context() -> Dict[str, Any]:
def make_shell_context() -> dict[str, Any]:
return dict(app=app, db=db)
@ -79,5 +79,5 @@ def version(verbose: bool) -> None:
)
print(Fore.BLUE + "-=" * 15)
if verbose:
print("[DB] : " + "{}".format(db.engine))
print("[DB] : " + f"{db.engine}")
print(Style.RESET_ALL)

View File

@ -17,7 +17,6 @@
import json
from copy import deepcopy
from textwrap import dedent
from typing import Set, Tuple
import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
@ -102,7 +101,7 @@ def native_filters() -> None:
)
def upgrade(
all_: bool, # pylint: disable=unused-argument
dashboard_ids: Tuple[int, ...],
dashboard_ids: tuple[int, ...],
) -> None:
"""
Upgrade legacy filter-box charts to native dashboard filters.
@ -251,7 +250,7 @@ def upgrade(
)
def downgrade(
all_: bool, # pylint: disable=unused-argument
dashboard_ids: Tuple[int, ...],
dashboard_ids: tuple[int, ...],
) -> None:
"""
Downgrade native dashboard filters to legacy filter-box charts (where applicable).
@ -347,7 +346,7 @@ def downgrade(
)
def cleanup(
all_: bool, # pylint: disable=unused-argument
dashboard_ids: Tuple[int, ...],
dashboard_ids: tuple[int, ...],
) -> None:
"""
Cleanup obsolete legacy filter-box charts and interim metadata.
@ -355,7 +354,7 @@ def cleanup(
Note this operation is irreversible.
"""
slice_ids: Set[int] = set()
slice_ids: set[int] = set()
# Cleanup the dashboard which contains legacy fields used for downgrading.
for dashboard in (

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Type, Union
from typing import Union
import click
from celery.utils.abstract import CallableTask
@ -75,7 +75,7 @@ def compute_thumbnails(
def compute_generic_thumbnail(
friendly_type: str,
model_cls: Union[Type[Dashboard], Type[Slice]],
model_cls: Union[type[Dashboard], type[Slice]],
model_id: int,
compute_func: CallableTask,
) -> None:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from abc import ABC, abstractmethod
from typing import Any, List, Optional
from typing import Any, Optional
from flask_appbuilder.security.sqla.models import User
@ -45,7 +45,7 @@ class BaseCommand(ABC):
class CreateMixin: # pylint: disable=too-few-public-methods
@staticmethod
def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners, defaulting to the current user if `owner_ids` is
undefined or empty. If current user is missing in `owner_ids`, current user
@ -60,7 +60,7 @@ class CreateMixin: # pylint: disable=too-few-public-methods
class UpdateMixin: # pylint: disable=too-few-public-methods
@staticmethod
def populate_owners(owner_ids: Optional[List[int]] = None) -> List[User]:
def populate_owners(owner_ids: Optional[list[int]] = None) -> list[User]:
"""
Populate list of owners. If current user is missing in `owner_ids`, current user
is added unless belonging to the Admin role.

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_babel import lazy_gettext as _
from marshmallow import ValidationError
@ -59,7 +59,7 @@ class CommandInvalidError(CommandException):
def __init__(
self,
message: str = "",
exceptions: Optional[List[ValidationError]] = None,
exceptions: Optional[list[ValidationError]] = None,
) -> None:
self._exceptions = exceptions or []
super().__init__(message)
@ -67,14 +67,14 @@ class CommandInvalidError(CommandException):
def append(self, exception: ValidationError) -> None:
self._exceptions.append(exception)
def extend(self, exceptions: List[ValidationError]) -> None:
def extend(self, exceptions: list[ValidationError]) -> None:
self._exceptions.extend(exceptions)
def get_list_classnames(self) -> List[str]:
def get_list_classnames(self) -> list[str]:
return list(sorted({ex.__class__.__name__ for ex in self._exceptions}))
def normalized_messages(self) -> Dict[Any, Any]:
errors: Dict[Any, Any] = {}
def normalized_messages(self) -> dict[Any, Any]:
errors: dict[Any, Any] = {}
for exception in self._exceptions:
errors.update(exception.normalized_messages())
return errors

View File

@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Iterator, Tuple
import yaml
@ -36,7 +36,7 @@ class ExportAssetsCommand(BaseCommand):
Command that exports all databases, datasets, charts, dashboards and saved queries.
"""
def run(self) -> Iterator[Tuple[str, str]]:
def run(self) -> Iterator[tuple[str, str]]:
metadata = {
"version": EXPORT_VERSION,
"type": "assets",

View File

@ -15,8 +15,8 @@
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
from datetime import datetime, timezone
from typing import Iterator, List, Tuple, Type
import yaml
from flask_appbuilder import Model
@ -30,21 +30,21 @@ METADATA_FILE_NAME = "metadata.yaml"
class ExportModelsCommand(BaseCommand):
dao: Type[BaseDAO] = BaseDAO
not_found: Type[CommandException] = CommandException
dao: type[BaseDAO] = BaseDAO
not_found: type[CommandException] = CommandException
def __init__(self, model_ids: List[int], export_related: bool = True):
def __init__(self, model_ids: list[int], export_related: bool = True):
self.model_ids = model_ids
self.export_related = export_related
# this will be set when calling validate()
self._models: List[Model] = []
self._models: list[Model] = []
@staticmethod
def _export(model: Model, export_related: bool = True) -> Iterator[Tuple[str, str]]:
def _export(model: Model, export_related: bool = True) -> Iterator[tuple[str, str]]:
raise NotImplementedError("Subclasses MUST implement _export")
def run(self) -> Iterator[Tuple[str, str]]:
def run(self) -> Iterator[tuple[str, str]]:
self.validate()
metadata = {

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Set
from typing import Any, Optional
from marshmallow import Schema, validate
from marshmallow.exceptions import ValidationError
@ -40,33 +40,33 @@ class ImportModelsCommand(BaseCommand):
dao = BaseDAO
model_name = "model"
prefix = ""
schemas: Dict[str, Schema] = {}
schemas: dict[str, Schema] = {}
import_error = CommandException
# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: Dict[str, str] = (
self.passwords: dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
self.ssh_tunnel_private_keys: Dict[str, str] = (
self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self.overwrite: bool = kwargs.get("overwrite", False)
self._configs: Dict[str, Any] = {}
self._configs: dict[str, Any] = {}
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
raise NotImplementedError("Subclasses MUST implement _import")
@classmethod
def _get_uuids(cls) -> Set[str]:
def _get_uuids(cls) -> set[str]:
return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()}
def run(self) -> None:
@ -84,11 +84,11 @@ class ImportModelsCommand(BaseCommand):
raise self.import_error() from ex
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None
@ -114,7 +114,7 @@ class ImportModelsCommand(BaseCommand):
)
def _prevent_overwrite_existing_model( # pylint: disable=invalid-name
self, exceptions: List[ValidationError]
self, exceptions: list[ValidationError]
) -> None:
"""check if the object exists and shouldn't be overwritten"""
if not self.overwrite:

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from marshmallow import Schema
from marshmallow.exceptions import ValidationError
@ -56,7 +56,7 @@ class ImportAssetsCommand(BaseCommand):
and will overwrite everything.
"""
schemas: Dict[str, Schema] = {
schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@ -65,24 +65,24 @@ class ImportAssetsCommand(BaseCommand):
}
# pylint: disable=unused-argument
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.passwords: Dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: Dict[str, str] = (
self.passwords: dict[str, str] = kwargs.get("passwords") or {}
self.ssh_tunnel_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_passwords") or {}
)
self.ssh_tunnel_private_keys: Dict[str, str] = (
self.ssh_tunnel_private_keys: dict[str, str] = (
kwargs.get("ssh_tunnel_private_keys") or {}
)
self.ssh_tunnel_priv_key_passwords: Dict[str, str] = (
self.ssh_tunnel_priv_key_passwords: dict[str, str] = (
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self._configs: Dict[str, Any] = {}
self._configs: dict[str, Any] = {}
@staticmethod
def _import(session: Session, configs: Dict[str, Any]) -> None:
def _import(session: Session, configs: dict[str, Any]) -> None:
# import databases first
database_ids: Dict[str, int] = {}
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=True)
@ -95,7 +95,7 @@ class ImportAssetsCommand(BaseCommand):
import_saved_query(session, config, overwrite=True)
# import datasets
dataset_info: Dict[str, Dict[str, Any]] = {}
dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
@ -107,7 +107,7 @@ class ImportAssetsCommand(BaseCommand):
}
# import charts
chart_ids: Dict[str, int] = {}
chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("charts/"):
config.update(dataset_info[config["dataset_uuid"]])
@ -121,7 +121,7 @@ class ImportAssetsCommand(BaseCommand):
dashboard = import_dashboard(session, config, overwrite=True)
# set ref in the dashboard_slices table
dashboard_chart_ids: List[Dict[str, int]] = []
dashboard_chart_ids: list[dict[str, int]] = []
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
break
@ -151,11 +151,11 @@ class ImportAssetsCommand(BaseCommand):
raise ImportFailedError() from ex
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
# verify that the metadata file is present and valid
try:
metadata: Optional[Dict[str, str]] = load_metadata(self.contents)
metadata: Optional[dict[str, str]] = load_metadata(self.contents)
except ValidationError as exc:
exceptions.append(exc)
metadata = None

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Set, Tuple
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@ -52,7 +52,7 @@ class ImportExamplesCommand(ImportModelsCommand):
dao = BaseDAO
model_name = "model"
schemas: Dict[str, Schema] = {
schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@ -60,7 +60,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
import_error = CommandException
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
super().__init__(contents, *args, **kwargs)
self.force_data = kwargs.get("force_data", False)
@ -81,7 +81,7 @@ class ImportExamplesCommand(ImportModelsCommand):
raise self.import_error() from ex
@classmethod
def _get_uuids(cls) -> Set[str]:
def _get_uuids(cls) -> set[str]:
# pylint: disable=protected-access
return (
ImportDatabasesCommand._get_uuids()
@ -93,12 +93,12 @@ class ImportExamplesCommand(ImportModelsCommand):
@staticmethod
def _import( # pylint: disable=arguments-differ, too-many-locals, too-many-branches
session: Session,
configs: Dict[str, Any],
configs: dict[str, Any],
overwrite: bool = False,
force_data: bool = False,
) -> None:
# import databases
database_ids: Dict[str, int] = {}
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(
@ -114,7 +114,7 @@ class ImportExamplesCommand(ImportModelsCommand):
# database was created before its UUID was frozen, so it has a random UUID.
# We need to determine its ID so we can point the dataset to it.
examples_db = get_example_database()
dataset_info: Dict[str, Dict[str, Any]] = {}
dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
# find the ID of the corresponding database
@ -153,7 +153,7 @@ class ImportExamplesCommand(ImportModelsCommand):
}
# import charts
chart_ids: Dict[str, int] = {}
chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@ -175,7 +175,7 @@ class ImportExamplesCommand(ImportModelsCommand):
).fetchall()
# import dashboards
dashboard_chart_ids: List[Tuple[int, int]] = []
dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
try:

View File

@ -15,7 +15,7 @@
import logging
from pathlib import Path, PurePosixPath
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from zipfile import ZipFile
import yaml
@ -46,7 +46,7 @@ class MetadataSchema(Schema):
timestamp = fields.DateTime()
def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
def load_yaml(file_name: str, content: str) -> dict[str, Any]:
"""Try to load a YAML file"""
try:
return yaml.safe_load(content)
@ -55,7 +55,7 @@ def load_yaml(file_name: str, content: str) -> Dict[str, Any]:
raise ValidationError({file_name: "Not a valid YAML file"}) from ex
def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
def load_metadata(contents: dict[str, str]) -> dict[str, str]:
"""Apply validation and load a metadata file"""
if METADATA_FILE_NAME not in contents:
# if the contents have no METADATA_FILE_NAME this is probably
@ -80,9 +80,9 @@ def load_metadata(contents: Dict[str, str]) -> Dict[str, str]:
def validate_metadata_type(
metadata: Optional[Dict[str, str]],
metadata: Optional[dict[str, str]],
type_: str,
exceptions: List[ValidationError],
exceptions: list[ValidationError],
) -> None:
"""Validate that the type declared in METADATA_FILE_NAME is correct"""
if metadata and "type" in metadata:
@ -96,35 +96,35 @@ def validate_metadata_type(
# pylint: disable=too-many-locals,too-many-arguments
def load_configs(
contents: Dict[str, str],
schemas: Dict[str, Schema],
passwords: Dict[str, str],
exceptions: List[ValidationError],
ssh_tunnel_passwords: Dict[str, str],
ssh_tunnel_private_keys: Dict[str, str],
ssh_tunnel_priv_key_passwords: Dict[str, str],
) -> Dict[str, Any]:
configs: Dict[str, Any] = {}
contents: dict[str, str],
schemas: dict[str, Schema],
passwords: dict[str, str],
exceptions: list[ValidationError],
ssh_tunnel_passwords: dict[str, str],
ssh_tunnel_private_keys: dict[str, str],
ssh_tunnel_priv_key_passwords: dict[str, str],
) -> dict[str, Any]:
configs: dict[str, Any] = {}
# load existing databases so we can apply the password validation
db_passwords: Dict[str, str] = {
db_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(Database.uuid, Database.password).all()
}
# load existing ssh_tunnels so we can apply the password validation
db_ssh_tunnel_passwords: Dict[str, str] = {
db_ssh_tunnel_passwords: dict[str, str] = {
str(uuid): password
for uuid, password in db.session.query(SSHTunnel.uuid, SSHTunnel.password).all()
}
# load existing ssh_tunnels so we can apply the private_key validation
db_ssh_tunnel_private_keys: Dict[str, str] = {
db_ssh_tunnel_private_keys: dict[str, str] = {
str(uuid): private_key
for uuid, private_key in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key
).all()
}
# load existing ssh_tunnels so we can apply the private_key_password validation
db_ssh_tunnel_priv_key_passws: Dict[str, str] = {
db_ssh_tunnel_priv_key_passws: dict[str, str] = {
str(uuid): private_key_password
for uuid, private_key_password in db.session.query(
SSHTunnel.uuid, SSHTunnel.private_key_password
@ -206,7 +206,7 @@ def is_valid_config(file_name: str) -> bool:
return True
def get_contents_from_bundle(bundle: ZipFile) -> Dict[str, str]:
def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
return {
remove_root(file_name): bundle.read(file_name).decode()
for file_name in bundle.namelist()

View File

@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
from typing import List, Optional, TYPE_CHECKING
from typing import TYPE_CHECKING
from flask import g
from flask_appbuilder.security.sqla.models import Role, User
@ -37,9 +37,9 @@ if TYPE_CHECKING:
def populate_owners(
owner_ids: Optional[List[int]],
owner_ids: list[int] | None,
default_to_user: bool,
) -> List[User]:
) -> list[User]:
"""
Helper function for commands, will fetch all users from owners id's
@ -63,13 +63,13 @@ def populate_owners(
return owners
def populate_roles(role_ids: Optional[List[int]] = None) -> List[Role]:
def populate_roles(role_ids: list[int] | None = None) -> list[Role]:
"""
Helper function for commands, will fetch all roles from roles id's
:raises RolesNotFoundValidationError: If a role in the input list is not found
:param role_ids: A List of roles by id's
"""
roles: List[Role] = []
roles: list[Role] = []
if role_ids:
roles = security_manager.find_roles_by_id(role_ids)
if len(roles) != len(role_ids):

View File

@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
from enum import Enum
from typing import Set
class ChartDataResultFormat(str, Enum):
@ -28,7 +27,7 @@ class ChartDataResultFormat(str, Enum):
XLSX = "xlsx"
@classmethod
def table_like(cls) -> Set["ChartDataResultFormat"]:
def table_like(cls) -> set["ChartDataResultFormat"]:
return {cls.CSV} | {cls.XLSX}

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import copy
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
from typing import Any, Callable, TYPE_CHECKING
from flask_babel import _
@ -49,7 +49,7 @@ def _get_datasource(
def _get_columns(
query_context: QueryContext, query_obj: QueryObject, _: bool
) -> Dict[str, Any]:
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@ -65,7 +65,7 @@ def _get_columns(
def _get_timegrains(
query_context: QueryContext, query_obj: QueryObject, _: bool
) -> Dict[str, Any]:
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
return {
"data": [
@ -83,7 +83,7 @@ def _get_query(
query_context: QueryContext,
query_obj: QueryObject,
_: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result = {"language": datasource.query_language}
try:
@ -96,8 +96,8 @@ def _get_query(
def _get_full(
query_context: QueryContext,
query_obj: QueryObject,
force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
force_cached: bool | None = False,
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
result_type = query_obj.result_type or query_context.result_type
payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
@ -141,7 +141,7 @@ def _get_full(
def _get_samples(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
) -> dict[str, Any]:
datasource = _get_datasource(query_context, query_obj)
query_obj = copy.copy(query_obj)
query_obj.is_timeseries = False
@ -162,7 +162,7 @@ def _get_samples(
def _get_drill_detail(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
) -> dict[str, Any]:
# todo(yongjie): Remove this function,
# when determining whether samples should be applied to the time filter.
datasource = _get_datasource(query_context, query_obj)
@ -183,13 +183,13 @@ def _get_drill_detail(
def _get_results(
query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> Dict[str, Any]:
) -> dict[str, Any]:
payload = _get_full(query_context, query_obj, force_cached)
return payload
_result_type_functions: Dict[
ChartDataResultType, Callable[[QueryContext, QueryObject, bool], Dict[str, Any]]
_result_type_functions: dict[
ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
] = {
ChartDataResultType.COLUMNS: _get_columns,
ChartDataResultType.TIMEGRAINS: _get_timegrains,
@ -210,7 +210,7 @@ def get_query_results(
query_context: QueryContext,
query_obj: QueryObject,
force_cached: bool,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Return result payload for a chart data request.

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import logging
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, ClassVar, TYPE_CHECKING
import pandas as pd
@ -47,15 +47,15 @@ class QueryContext:
enforce_numerical_metrics: ClassVar[bool] = True
datasource: BaseDatasource
slice_: Optional[Slice] = None
queries: List[QueryObject]
form_data: Optional[Dict[str, Any]]
slice_: Slice | None = None
queries: list[QueryObject]
form_data: dict[str, Any] | None
result_type: ChartDataResultType
result_format: ChartDataResultFormat
force: bool
custom_cache_timeout: Optional[int]
custom_cache_timeout: int | None
cache_values: Dict[str, Any]
cache_values: dict[str, Any]
_processor: QueryContextProcessor
@ -65,14 +65,14 @@ class QueryContext:
self,
*,
datasource: BaseDatasource,
queries: List[QueryObject],
slice_: Optional[Slice],
form_data: Optional[Dict[str, Any]],
queries: list[QueryObject],
slice_: Slice | None,
form_data: dict[str, Any] | None,
result_type: ChartDataResultType,
result_format: ChartDataResultFormat,
force: bool = False,
custom_cache_timeout: Optional[int] = None,
cache_values: Dict[str, Any],
custom_cache_timeout: int | None = None,
cache_values: dict[str, Any],
) -> None:
self.datasource = datasource
self.slice_ = slice_
@ -88,18 +88,18 @@ class QueryContext:
def get_data(
self,
df: pd.DataFrame,
) -> Union[str, List[Dict[str, Any]]]:
) -> str | list[dict[str, Any]]:
return self._processor.get_data(df)
def get_payload(
self,
cache_query_context: Optional[bool] = False,
cache_query_context: bool | None = False,
force_cached: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
return self._processor.get_payload(cache_query_context, force_cached)
def get_cache_timeout(self) -> Optional[int]:
def get_cache_timeout(self) -> int | None:
if self.custom_cache_timeout is not None:
return self.custom_cache_timeout
if self.slice_ and self.slice_.cache_timeout is not None:
@ -110,14 +110,14 @@ class QueryContext:
return self.datasource.database.cache_timeout
return None
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
return self._processor.query_cache_key(query_obj, **kwargs)
def get_df_payload(
self,
query_obj: QueryObject,
force_cached: Optional[bool] = False,
) -> Dict[str, Any]:
force_cached: bool | None = False,
) -> dict[str, Any]:
return self._processor.get_df_payload(
query_obj=query_obj,
force_cached=force_cached,

View File

@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
from superset import app, db
from superset.charts.dao import ChartDAO
@ -48,12 +48,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
self,
*,
datasource: DatasourceDict,
queries: List[Dict[str, Any]],
form_data: Optional[Dict[str, Any]] = None,
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
queries: list[dict[str, Any]],
form_data: dict[str, Any] | None = None,
result_type: ChartDataResultType | None = None,
result_format: ChartDataResultFormat | None = None,
force: bool = False,
custom_cache_timeout: Optional[int] = None,
custom_cache_timeout: int | None = None,
) -> QueryContext:
datasource_model_instance = None
if datasource:
@ -101,13 +101,13 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
datasource_id=int(datasource["id"]),
)
def _get_slice(self, slice_id: Any) -> Optional[Slice]:
def _get_slice(self, slice_id: Any) -> Slice | None:
return ChartDAO.find_by_id(slice_id)
def _process_query_object(
self,
datasource: BaseDatasource,
form_data: Optional[Dict[str, Any]],
form_data: dict[str, Any] | None,
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
@ -117,7 +117,7 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _apply_granularity(
self,
query_object: QueryObject,
form_data: Optional[Dict[str, Any]],
form_data: dict[str, Any] | None,
datasource: BaseDatasource,
) -> None:
temporal_columns = {

View File

@ -19,7 +19,7 @@ from __future__ import annotations
import copy
import logging
import re
from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Union
from typing import Any, ClassVar, TYPE_CHECKING
import numpy as np
import pandas as pd
@ -77,8 +77,8 @@ logger = logging.getLogger(__name__)
class CachedTimeOffset(TypedDict):
df: pd.DataFrame
queries: List[str]
cache_keys: List[Optional[str]]
queries: list[str]
cache_keys: list[str | None]
class QueryContextProcessor:
@ -102,8 +102,8 @@ class QueryContextProcessor:
enforce_numerical_metrics: ClassVar[bool] = True
def get_df_payload(
self, query_obj: QueryObject, force_cached: Optional[bool] = False
) -> Dict[str, Any]:
self, query_obj: QueryObject, force_cached: bool | None = False
) -> dict[str, Any]:
"""Handles caching around the df payload retrieval"""
cache_key = self.query_cache_key(query_obj)
timeout = self.get_cache_timeout()
@ -181,7 +181,7 @@ class QueryContextProcessor:
"label_map": label_map,
}
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> Optional[str]:
def query_cache_key(self, query_obj: QueryObject, **kwargs: Any) -> str | None:
"""
Returns a QueryObject cache key for objects in self.queries
"""
@ -248,8 +248,8 @@ class QueryContextProcessor:
def normalize_df(self, df: pd.DataFrame, query_object: QueryObject) -> pd.DataFrame:
# todo: should support "python_date_format" and "get_column" in each datasource
def _get_timestamp_format(
source: BaseDatasource, column: Optional[str]
) -> Optional[str]:
source: BaseDatasource, column: str | None
) -> str | None:
column_obj = source.get_column(column)
if (
column_obj
@ -315,9 +315,9 @@ class QueryContextProcessor:
query_context = self._query_context
# ensure query_object is immutable
query_object_clone = copy.copy(query_object)
queries: List[str] = []
cache_keys: List[Optional[str]] = []
rv_dfs: List[pd.DataFrame] = [df]
queries: list[str] = []
cache_keys: list[str | None] = []
rv_dfs: list[pd.DataFrame] = [df]
time_offsets = query_object.time_offsets
outer_from_dttm, outer_to_dttm = get_since_until_from_query_object(query_object)
@ -449,7 +449,7 @@ class QueryContextProcessor:
rv_df = pd.concat(rv_dfs, axis=1, copy=False) if time_offsets else df
return CachedTimeOffset(df=rv_df, queries=queries, cache_keys=cache_keys)
def get_data(self, df: pd.DataFrame) -> Union[str, List[Dict[str, Any]]]:
def get_data(self, df: pd.DataFrame) -> str | list[dict[str, Any]]:
if self._query_context.result_format in ChartDataResultFormat.table_like():
include_index = not isinstance(df.index, pd.RangeIndex)
columns = list(df.columns)
@ -470,9 +470,9 @@ class QueryContextProcessor:
def get_payload(
self,
cache_query_context: Optional[bool] = False,
cache_query_context: bool | None = False,
force_cached: bool = False,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
# Get all the payloads from the QueryObjects
@ -522,13 +522,13 @@ class QueryContextProcessor:
return generate_cache_key(cache_dict, key_prefix)
def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]:
def get_annotation_data(self, query_obj: QueryObject) -> dict[str, Any]:
"""
:param query_context:
:param query_obj:
:return:
"""
annotation_data: Dict[str, Any] = self.get_native_annotation_data(query_obj)
annotation_data: dict[str, Any] = self.get_native_annotation_data(query_obj)
for annotation_layer in [
layer
for layer in query_obj.annotation_layers
@ -541,7 +541,7 @@ class QueryContextProcessor:
return annotation_data
@staticmethod
def get_native_annotation_data(query_obj: QueryObject) -> Dict[str, Any]:
def get_native_annotation_data(query_obj: QueryObject) -> dict[str, Any]:
annotation_data = {}
annotation_layers = [
layer
@ -576,8 +576,8 @@ class QueryContextProcessor:
@staticmethod
def get_viz_annotation_data(
annotation_layer: Dict[str, Any], force: bool
) -> Dict[str, Any]:
annotation_layer: dict[str, Any], force: bool
) -> dict[str, Any]:
chart = ChartDAO.find_by_id(annotation_layer["value"])
if not chart:
raise QueryObjectValidationError(_("The chart does not exist"))

View File

@ -21,7 +21,7 @@ import json
import logging
from datetime import datetime
from pprint import pformat
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
from typing import Any, NamedTuple, TYPE_CHECKING
from flask import g
from flask_babel import gettext as _
@ -81,58 +81,58 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
and druid. The query objects are constructed on the client.
"""
annotation_layers: List[Dict[str, Any]]
applied_time_extras: Dict[str, str]
annotation_layers: list[dict[str, Any]]
applied_time_extras: dict[str, str]
apply_fetch_values_predicate: bool
columns: List[Column]
datasource: Optional[BaseDatasource]
extras: Dict[str, Any]
filter: List[QueryObjectFilterClause]
from_dttm: Optional[datetime]
granularity: Optional[str]
inner_from_dttm: Optional[datetime]
inner_to_dttm: Optional[datetime]
columns: list[Column]
datasource: BaseDatasource | None
extras: dict[str, Any]
filter: list[QueryObjectFilterClause]
from_dttm: datetime | None
granularity: str | None
inner_from_dttm: datetime | None
inner_to_dttm: datetime | None
is_rowcount: bool
is_timeseries: bool
metrics: Optional[List[Metric]]
metrics: list[Metric] | None
order_desc: bool
orderby: List[OrderBy]
post_processing: List[Dict[str, Any]]
result_type: Optional[ChartDataResultType]
row_limit: Optional[int]
orderby: list[OrderBy]
post_processing: list[dict[str, Any]]
result_type: ChartDataResultType | None
row_limit: int | None
row_offset: int
series_columns: List[Column]
series_columns: list[Column]
series_limit: int
series_limit_metric: Optional[Metric]
time_offsets: List[str]
time_shift: Optional[str]
time_range: Optional[str]
to_dttm: Optional[datetime]
series_limit_metric: Metric | None
time_offsets: list[str]
time_shift: str | None
time_range: str | None
to_dttm: datetime | None
def __init__( # pylint: disable=too-many-locals
self,
*,
annotation_layers: Optional[List[Dict[str, Any]]] = None,
applied_time_extras: Optional[Dict[str, str]] = None,
annotation_layers: list[dict[str, Any]] | None = None,
applied_time_extras: dict[str, str] | None = None,
apply_fetch_values_predicate: bool = False,
columns: Optional[List[Column]] = None,
datasource: Optional[BaseDatasource] = None,
extras: Optional[Dict[str, Any]] = None,
filters: Optional[List[QueryObjectFilterClause]] = None,
granularity: Optional[str] = None,
columns: list[Column] | None = None,
datasource: BaseDatasource | None = None,
extras: dict[str, Any] | None = None,
filters: list[QueryObjectFilterClause] | None = None,
granularity: str | None = None,
is_rowcount: bool = False,
is_timeseries: Optional[bool] = None,
metrics: Optional[List[Metric]] = None,
is_timeseries: bool | None = None,
metrics: list[Metric] | None = None,
order_desc: bool = True,
orderby: Optional[List[OrderBy]] = None,
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
row_limit: Optional[int],
row_offset: Optional[int] = None,
series_columns: Optional[List[Column]] = None,
orderby: list[OrderBy] | None = None,
post_processing: list[dict[str, Any] | None] | None = None,
row_limit: int | None,
row_offset: int | None = None,
series_columns: list[Column] | None = None,
series_limit: int = 0,
series_limit_metric: Optional[Metric] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
series_limit_metric: Metric | None = None,
time_range: str | None = None,
time_shift: str | None = None,
**kwargs: Any,
):
self._set_annotation_layers(annotation_layers)
@ -166,7 +166,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self._move_deprecated_extra_fields(kwargs)
def _set_annotation_layers(
self, annotation_layers: Optional[List[Dict[str, Any]]]
self, annotation_layers: list[dict[str, Any]] | None
) -> None:
self.annotation_layers = [
layer
@ -175,14 +175,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
if layer["annotationType"] != "FORMULA"
]
def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
def _set_is_timeseries(self, is_timeseries: bool | None) -> None:
# is_timeseries is True if time column is in either columns or groupby
# (both are dimensions)
self.is_timeseries = (
is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns
)
def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None:
def _set_metrics(self, metrics: list[Metric] | None = None) -> None:
# Support metric reference/definition in the format of
# 1. 'metric_name' - name of predefined metric
# 2. { label: 'label_name' } - legacy format for a predefined metric
@ -195,16 +195,16 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
]
def _set_post_processing(
self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
self, post_processing: list[dict[str, Any] | None] | None
) -> None:
post_processing = post_processing or []
self.post_processing = [post_proc for post_proc in post_processing if post_proc]
def _init_series_columns(
self,
series_columns: Optional[List[Column]],
metrics: Optional[List[Metric]],
is_timeseries: Optional[bool],
series_columns: list[Column] | None,
metrics: list[Metric] | None,
is_timeseries: bool | None,
) -> None:
if series_columns:
self.series_columns = series_columns
@ -213,7 +213,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
else:
self.series_columns = []
def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
def _rename_deprecated_fields(self, kwargs: dict[str, Any]) -> None:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
if field.old_name in kwargs:
@ -233,7 +233,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
setattr(self, field.new_name, value)
def _move_deprecated_extra_fields(self, kwargs: Dict[str, Any]) -> None:
def _move_deprecated_extra_fields(self, kwargs: dict[str, Any]) -> None:
# move deprecated extras fields to extras
for field in DEPRECATED_EXTRAS_FIELDS:
if field.old_name in kwargs:
@ -256,19 +256,19 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
self.extras[field.new_name] = value
@property
def metric_names(self) -> List[str]:
def metric_names(self) -> list[str]:
"""Return metrics names (labels), coerce adhoc metrics to strings."""
return get_metric_names(self.metrics or [])
@property
def column_names(self) -> List[str]:
def column_names(self) -> list[str]:
"""Return column names (labels). Gives priority to groupbys if both groupbys
and metrics are non-empty, otherwise returns column labels."""
return get_column_names(self.columns)
def validate(
self, raise_exceptions: Optional[bool] = True
) -> Optional[QueryObjectValidationError]:
self, raise_exceptions: bool | None = True
) -> QueryObjectValidationError | None:
"""Validate query object"""
try:
self._validate_there_are_no_missing_series()
@ -314,7 +314,7 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
)
)
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
query_object_dict = {
"apply_fetch_values_predicate": self.apply_fetch_values_predicate,
"columns": self.columns,

View File

@ -16,7 +16,7 @@
# under the License.
from __future__ import annotations
from typing import Any, Dict, Optional, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
from superset.common.chart_data import ChartDataResultType
from superset.common.query_object import QueryObject
@ -31,13 +31,13 @@ if TYPE_CHECKING:
class QueryObjectFactory: # pylint: disable=too-few-public-methods
_config: Dict[str, Any]
_config: dict[str, Any]
_datasource_dao: DatasourceDAO
_session_maker: sessionmaker
def __init__(
self,
app_configurations: Dict[str, Any],
app_configurations: dict[str, Any],
_datasource_dao: DatasourceDAO,
session_maker: sessionmaker,
):
@ -48,11 +48,11 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
datasource: Optional[DatasourceDict] = None,
extras: Optional[Dict[str, Any]] = None,
row_limit: Optional[int] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
datasource: DatasourceDict | None = None,
extras: dict[str, Any] | None = None,
row_limit: int | None = None,
time_range: str | None = None,
time_shift: str | None = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
@ -84,13 +84,13 @@ class QueryObjectFactory: # pylint: disable=too-few-public-methods
def _process_extras( # pylint: disable=no-self-use
self,
extras: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
extras: dict[str, Any] | None,
) -> dict[str, Any]:
extras = extras or {}
return extras
def _process_row_limit(
self, row_limit: Optional[int], result_type: ChartDataResultType
self, row_limit: int | None, result_type: ChartDataResultType
) -> int:
default_row_limit = (
self._config["SAMPLES_ROW_LIMIT"]

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, List
from typing import Any
from sqlalchemy import MetaData
from sqlalchemy.exc import IntegrityError
@ -25,7 +25,7 @@ from superset.tags.models import ObjectTypes, TagTypes
def add_types_to_charts(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@ -57,7 +57,7 @@ def add_types_to_charts(
def add_types_to_dashboards(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@ -89,7 +89,7 @@ def add_types_to_dashboards(
def add_types_to_saved_queries(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@ -121,7 +121,7 @@ def add_types_to_saved_queries(
def add_types_to_datasets(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]
@ -237,7 +237,7 @@ def add_types(metadata: MetaData) -> None:
def add_owners_to_charts(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
slices = metadata.tables["slices"]
@ -273,7 +273,7 @@ def add_owners_to_charts(
def add_owners_to_dashboards(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
dashboard_table = metadata.tables["dashboards"]
@ -309,7 +309,7 @@ def add_owners_to_dashboards(
def add_owners_to_saved_queries(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
saved_query = metadata.tables["saved_query"]
@ -345,7 +345,7 @@ def add_owners_to_saved_queries(
def add_owners_to_datasets(
metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str]
metadata: MetaData, tag: Any, tagged_object: Any, columns: list[str]
) -> None:
tables = metadata.tables["tables"]

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import datetime
from typing import Any, List, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import numpy as np
import pandas as pd
@ -29,7 +29,7 @@ if TYPE_CHECKING:
def left_join_df(
left_df: pd.DataFrame,
right_df: pd.DataFrame,
join_keys: List[str],
join_keys: list[str],
) -> pd.DataFrame:
df = left_df.set_index(join_keys).join(right_df.set_index(join_keys))
df.reset_index(inplace=True)

View File

@ -17,7 +17,7 @@
from __future__ import annotations
import logging
from typing import Any, Dict, List, Optional
from typing import Any
from flask_caching import Cache
from pandas import DataFrame
@ -37,7 +37,7 @@ config = app.config
stats_logger: BaseStatsLogger = config["STATS_LOGGER"]
logger = logging.getLogger(__name__)
_cache: Dict[CacheRegion, Cache] = {
_cache: dict[CacheRegion, Cache] = {
CacheRegion.DEFAULT: cache_manager.cache,
CacheRegion.DATA: cache_manager.data_cache,
}
@ -53,17 +53,17 @@ class QueryCacheManager:
self,
df: DataFrame = DataFrame(),
query: str = "",
annotation_data: Optional[Dict[str, Any]] = None,
applied_template_filters: Optional[List[str]] = None,
applied_filter_columns: Optional[List[Column]] = None,
rejected_filter_columns: Optional[List[Column]] = None,
status: Optional[str] = None,
error_message: Optional[str] = None,
annotation_data: dict[str, Any] | None = None,
applied_template_filters: list[str] | None = None,
applied_filter_columns: list[Column] | None = None,
rejected_filter_columns: list[Column] | None = None,
status: str | None = None,
error_message: str | None = None,
is_loaded: bool = False,
stacktrace: Optional[str] = None,
is_cached: Optional[bool] = None,
cache_dttm: Optional[str] = None,
cache_value: Optional[Dict[str, Any]] = None,
stacktrace: str | None = None,
is_cached: bool | None = None,
cache_dttm: str | None = None,
cache_value: dict[str, Any] | None = None,
) -> None:
self.df = df
self.query = query
@ -85,10 +85,10 @@ class QueryCacheManager:
self,
key: str,
query_result: QueryResult,
annotation_data: Optional[Dict[str, Any]] = None,
force_query: Optional[bool] = False,
timeout: Optional[int] = None,
datasource_uid: Optional[str] = None,
annotation_data: dict[str, Any] | None = None,
force_query: bool | None = False,
timeout: int | None = None,
datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@ -136,11 +136,11 @@ class QueryCacheManager:
@classmethod
def get(
cls,
key: Optional[str],
key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
force_query: Optional[bool] = False,
force_cached: Optional[bool] = False,
) -> "QueryCacheManager":
force_query: bool | None = False,
force_cached: bool | None = False,
) -> QueryCacheManager:
"""
Initialize QueryCacheManager by query-cache key
"""
@ -190,10 +190,10 @@ class QueryCacheManager:
@staticmethod
def set(
key: Optional[str],
value: Dict[str, Any],
timeout: Optional[int] = None,
datasource_uid: Optional[str] = None,
key: str | None,
value: dict[str, Any],
timeout: int | None = None,
datasource_uid: str | None = None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
"""
@ -204,7 +204,7 @@ class QueryCacheManager:
@staticmethod
def delete(
key: Optional[str],
key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> None:
if key:
@ -212,7 +212,7 @@ class QueryCacheManager:
@staticmethod
def has(
key: Optional[str],
key: str | None,
region: CacheRegion = CacheRegion.DEFAULT,
) -> bool:
return bool(_cache[region].get(key)) if key else False

View File

@ -17,7 +17,7 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, cast, Dict, Optional, Tuple
from typing import Any, cast
from superset import app
from superset.common.query_object import QueryObject
@ -26,10 +26,10 @@ from superset.utils.date_parser import get_since_until
def get_since_until_from_time_range(
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
extras: Optional[Dict[str, Any]] = None,
) -> Tuple[Optional[datetime], Optional[datetime]]:
time_range: str | None = None,
time_shift: str | None = None,
extras: dict[str, Any] | None = None,
) -> tuple[datetime | None, datetime | None]:
return get_since_until(
relative_start=(extras or {}).get(
"relative_start", app.config["DEFAULT_RELATIVE_START_TIME"]
@ -45,7 +45,7 @@ def get_since_until_from_time_range(
# pylint: disable=invalid-name
def get_since_until_from_query_object(
query_object: QueryObject,
) -> Tuple[Optional[datetime], Optional[datetime]]:
) -> tuple[datetime | None, datetime | None]:
"""
this function will return since and until by tuple if
1) the time_range is in the query object.

View File

@ -33,20 +33,7 @@ import sys
from collections import OrderedDict
from datetime import timedelta
from email.mime.multipart import MIMEMultipart
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
TypedDict,
Union,
)
from typing import Any, Callable, Literal, TYPE_CHECKING, TypedDict
import pkg_resources
from cachelib.base import BaseCache
@ -114,17 +101,17 @@ PACKAGE_JSON_FILE = pkg_resources.resource_filename(
FAVICONS = [{"href": "/static/assets/images/favicon.png"}]
def _try_json_readversion(filepath: str) -> Optional[str]:
def _try_json_readversion(filepath: str) -> str | None:
try:
with open(filepath, "r") as f:
with open(filepath) as f:
return json.load(f).get("version")
except Exception: # pylint: disable=broad-except
return None
def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
def _try_json_readsha(filepath: str, length: int) -> str | None:
try:
with open(filepath, "r") as f:
with open(filepath) as f:
return json.load(f).get("GIT_SHA")[:length]
except Exception: # pylint: disable=broad-except
return None
@ -275,7 +262,7 @@ ENABLE_PROXY_FIX = False
PROXY_FIX_CONFIG = {"x_for": 1, "x_proto": 1, "x_host": 1, "x_port": 1, "x_prefix": 1}
# Configuration for scheduling queries from SQL Lab.
SCHEDULED_QUERIES: Dict[str, Any] = {}
SCHEDULED_QUERIES: dict[str, Any] = {}
# ------------------------------
# GLOBALS FOR APP Builder
@ -294,7 +281,7 @@ LOGO_TARGET_PATH = None
LOGO_TOOLTIP = ""
# Specify any text that should appear to the right of the logo
LOGO_RIGHT_TEXT: Union[Callable[[], str], str] = ""
LOGO_RIGHT_TEXT: Callable[[], str] | str = ""
# Enables SWAGGER UI for superset openapi spec
# ex: http://localhost:8080/swagger/v1
@ -347,7 +334,7 @@ AUTH_TYPE = AUTH_DB
# Grant public role the same set of permissions as for a selected builtin role.
# This is useful if one wants to enable anonymous users to view
# dashboards. Explicit grant on specific datasets is still required.
PUBLIC_ROLE_LIKE: Optional[str] = None
PUBLIC_ROLE_LIKE: str | None = None
# ---------------------------------------------------
# Babel config for translations
@ -390,8 +377,8 @@ LANGUAGES = {}
class D3Format(TypedDict, total=False):
decimal: str
thousands: str
grouping: List[int]
currency: List[str]
grouping: list[int]
currency: list[str]
D3_FORMAT: D3Format = {}
@ -404,7 +391,7 @@ D3_FORMAT: D3Format = {}
# For example, DEFAULT_FEATURE_FLAGS = { 'FOO': True, 'BAR': False } here
# and FEATURE_FLAGS = { 'BAR': True, 'BAZ': True } in superset_config.py
# will result in combined feature flags of { 'FOO': True, 'BAR': True, 'BAZ': True }
DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# Experimental feature introducing a client (browser) cache
"CLIENT_CACHE": False, # deprecated
"DISABLE_DATASET_SOURCE_EDIT": False, # deprecated
@ -527,7 +514,7 @@ DEFAULT_FEATURE_FLAGS.update(
)
# This is merely a default.
FEATURE_FLAGS: Dict[str, bool] = {}
FEATURE_FLAGS: dict[str, bool] = {}
# A function that receives a dict of all feature flags
# (DEFAULT_FEATURE_FLAGS merged with FEATURE_FLAGS)
@ -543,7 +530,7 @@ FEATURE_FLAGS: Dict[str, bool] = {}
# if hasattr(g, "user") and g.user.is_active:
# feature_flags_dict['some_feature'] = g.user and g.user.get_id() == 5
# return feature_flags_dict
GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] = None
GET_FEATURE_FLAGS_FUNC: Callable[[dict[str, bool]], dict[str, bool]] | None = None
# A function that receives a feature flag name and an optional default value.
# Has a similar utility to GET_FEATURE_FLAGS_FUNC but it's useful to not force the
# evaluation of all feature flags when just evaluating a single one.
@ -551,7 +538,7 @@ GET_FEATURE_FLAGS_FUNC: Optional[Callable[[Dict[str, bool]], Dict[str, bool]]] =
# Note that the default `get_feature_flags` will evaluate each feature with this
# callable when the config key is set, so don't use both GET_FEATURE_FLAGS_FUNC
# and IS_FEATURE_ENABLED_FUNC in conjunction.
IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
IS_FEATURE_ENABLED_FUNC: Callable[[str, bool | None], bool] | None = None
# A function that expands/overrides the frontend `bootstrap_data.common` object.
# Can be used to implement custom frontend functionality,
# or dynamically change certain configs.
@ -563,7 +550,7 @@ IS_FEATURE_ENABLED_FUNC: Optional[Callable[[str, Optional[bool]], bool]] = None
# Takes as a parameter the common bootstrap payload before transformations.
# Returns a dict containing data that should be added or overridden to the payload.
COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
[Dict[str, Any]], Dict[str, Any]
[dict[str, Any]], dict[str, Any]
] = lambda data: {} # default: empty dict
# EXTRA_CATEGORICAL_COLOR_SCHEMES is used for adding custom categorical color schemes
@ -580,7 +567,7 @@ COMMON_BOOTSTRAP_OVERRIDES_FUNC: Callable[
# }]
# This is merely a default
EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
EXTRA_CATEGORICAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# THEME_OVERRIDES is used for adding custom theme to superset
# example code for "My theme" custom scheme
@ -599,7 +586,7 @@ EXTRA_CATEGORICAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# }
# }
THEME_OVERRIDES: Dict[str, Any] = {}
THEME_OVERRIDES: dict[str, Any] = {}
# EXTRA_SEQUENTIAL_COLOR_SCHEMES is used for adding custom sequential color schemes
# EXTRA_SEQUENTIAL_COLOR_SCHEMES = [
@ -615,7 +602,7 @@ THEME_OVERRIDES: Dict[str, Any] = {}
# }]
# This is merely a default
EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
EXTRA_SEQUENTIAL_COLOR_SCHEMES: list[dict[str, Any]] = []
# ---------------------------------------------------
# Thumbnail config (behind feature flag)
@ -626,7 +613,7 @@ EXTRA_SEQUENTIAL_COLOR_SCHEMES: List[Dict[str, Any]] = []
# `superset.tasks.types.ExecutorType` for a full list of executor options.
# To always use a fixed user account, use the following configuration:
# THUMBNAIL_EXECUTE_AS = [ExecutorType.SELENIUM]
THUMBNAIL_SELENIUM_USER: Optional[str] = "admin"
THUMBNAIL_SELENIUM_USER: str | None = "admin"
THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# By default, thumbnail digests are calculated based on various parameters in the
@ -639,10 +626,10 @@ THUMBNAIL_EXECUTE_AS = [ExecutorType.CURRENT_USER, ExecutorType.SELENIUM]
# `THUMBNAIL_EXECUTE_AS`; the executor is only equal to the currently logged in
# user if the executor type is equal to `ExecutorType.CURRENT_USER`)
# and return the final digest string:
THUMBNAIL_DASHBOARD_DIGEST_FUNC: Optional[
THUMBNAIL_DASHBOARD_DIGEST_FUNC: None | (
Callable[[Dashboard, ExecutorType, str], str]
] = None
THUMBNAIL_CHART_DIGEST_FUNC: Optional[Callable[[Slice, ExecutorType, str], str]] = None
) = None
THUMBNAIL_CHART_DIGEST_FUNC: Callable[[Slice, ExecutorType, str], str] | None = None
THUMBNAIL_CACHE_CONFIG: CacheConfig = {
"CACHE_TYPE": "NullCache",
@ -714,7 +701,7 @@ STORE_CACHE_KEYS_IN_METADATA_DB = False
# CORS Options
ENABLE_CORS = False
CORS_OPTIONS: Dict[Any, Any] = {}
CORS_OPTIONS: dict[Any, Any] = {}
# Sanitizes the HTML content used in markdowns to allow its rendering in a safe manner.
# Disabling this option is not recommended for security reasons. If you wish to allow
@ -736,7 +723,7 @@ HTML_SANITIZATION = True
# }
# }
# Be careful when extending the default schema to avoid XSS attacks.
HTML_SANITIZATION_SCHEMA_EXTENSIONS: Dict[str, Any] = {}
HTML_SANITIZATION_SCHEMA_EXTENSIONS: dict[str, Any] = {}
# Chrome allows up to 6 open connections per domain at a time. When there are more
# than 6 slices in dashboard, a lot of time fetch requests are queued up and wait for
@ -768,13 +755,13 @@ EXCEL_EXPORT = {"encoding": "utf-8"}
# time grains in superset/db_engine_specs/base.py).
# For example: to disable 1 second time grain:
# TIME_GRAIN_DENYLIST = ['PT1S']
TIME_GRAIN_DENYLIST: List[str] = []
TIME_GRAIN_DENYLIST: list[str] = []
# Additional time grains to be supported using similar definitions as in
# superset/db_engine_specs/base.py.
# For example: To add a new 2 second time grain:
# TIME_GRAIN_ADDONS = {'PT2S': '2 second'}
TIME_GRAIN_ADDONS: Dict[str, str] = {}
TIME_GRAIN_ADDONS: dict[str, str] = {}
# Implementation of additional time grains per engine.
# The column to be truncated is denoted `{col}` in the expression.
@ -784,7 +771,7 @@ TIME_GRAIN_ADDONS: Dict[str, str] = {}
# 'PT2S': 'toDateTime(intDiv(toUInt32(toDateTime({col})), 2)*2)'
# }
# }
TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
TIME_GRAIN_ADDON_EXPRESSIONS: dict[str, dict[str, str]] = {}
# ---------------------------------------------------
# List of viz_types not allowed in your environment
@ -792,7 +779,7 @@ TIME_GRAIN_ADDON_EXPRESSIONS: Dict[str, Dict[str, str]] = {}
# VIZ_TYPE_DENYLIST = ['pivot_table', 'treemap']
# ---------------------------------------------------
VIZ_TYPE_DENYLIST: List[str] = []
VIZ_TYPE_DENYLIST: list[str] = []
# --------------------------------------------------
# Modules, datasources and middleware to be registered
@ -802,8 +789,8 @@ DEFAULT_MODULE_DS_MAP = OrderedDict(
("superset.connectors.sqla.models", ["SqlaTable"]),
]
)
ADDITIONAL_MODULE_DS_MAP: Dict[str, List[str]] = {}
ADDITIONAL_MIDDLEWARE: List[Callable[..., Any]] = []
ADDITIONAL_MODULE_DS_MAP: dict[str, list[str]] = {}
ADDITIONAL_MIDDLEWARE: list[Callable[..., Any]] = []
# 1) https://docs.python-guide.org/writing/logging/
# 2) https://docs.python.org/2/library/logging.config.html
@ -925,9 +912,9 @@ CELERY_CONFIG = CeleryConfig # pylint: disable=invalid-name
# within the app
# OVERRIDE_HTTP_HEADERS: sets override values for HTTP headers. These values will
# override anything set within the app
DEFAULT_HTTP_HEADERS: Dict[str, Any] = {}
OVERRIDE_HTTP_HEADERS: Dict[str, Any] = {}
HTTP_HEADERS: Dict[str, Any] = {}
DEFAULT_HTTP_HEADERS: dict[str, Any] = {}
OVERRIDE_HTTP_HEADERS: dict[str, Any] = {}
HTTP_HEADERS: dict[str, Any] = {}
# The db id here results in selecting this one as a default in SQL Lab
DEFAULT_DB_ID = None
@ -974,8 +961,8 @@ SQLLAB_QUERY_COST_ESTIMATE_TIMEOUT = int(timedelta(seconds=10).total_seconds())
# return out
#
# QUERY_COST_FORMATTERS_BY_ENGINE: {"postgresql": postgres_query_cost_formatter}
QUERY_COST_FORMATTERS_BY_ENGINE: Dict[
str, Callable[[List[Dict[str, Any]]], List[Dict[str, Any]]]
QUERY_COST_FORMATTERS_BY_ENGINE: dict[
str, Callable[[list[dict[str, Any]]], list[dict[str, Any]]]
] = {}
# Flag that controls if limit should be enforced on the CTA (create table as queries).
@ -1000,13 +987,13 @@ SQLLAB_CTAS_NO_LIMIT = False
# else:
# return f'tmp_{schema}'
# Function accepts database object, user object, schema name and sql that will be run.
SQLLAB_CTAS_SCHEMA_NAME_FUNC: Optional[
SQLLAB_CTAS_SCHEMA_NAME_FUNC: None | (
Callable[[Database, models.User, str, str], str]
] = None
) = None
# If enabled, it can be used to store the results of long-running queries
# in SQL Lab by using the "Run Async" button/feature
RESULTS_BACKEND: Optional[BaseCache] = None
RESULTS_BACKEND: BaseCache | None = None
# Use PyArrow and MessagePack for async query results serialization,
# rather than JSON. This feature requires additional testing from the
@ -1028,7 +1015,7 @@ CSV_TO_HIVE_UPLOAD_DIRECTORY = "EXTERNAL_HIVE_TABLES/"
def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
database: Database,
user: models.User, # pylint: disable=unused-argument
schema: Optional[str],
schema: str | None,
) -> str:
# Note the final empty path enforces a trailing slash.
return os.path.join(
@ -1038,14 +1025,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name
# The namespace within hive where the tables created from
# uploading CSVs will be stored.
UPLOADED_CSV_HIVE_NAMESPACE: Optional[str] = None
UPLOADED_CSV_HIVE_NAMESPACE: str | None = None
# Function that computes the allowed schemas for the CSV uploads.
# Allowed schemas will be a union of schemas_allowed_for_file_upload
# db configuration and a result of this function.
# mypy doesn't catch that if case ensures list content being always str
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], List[str]] = (
ALLOWED_USER_CSV_SCHEMA_FUNC: Callable[[Database, models.User], list[str]] = (
lambda database, user: [UPLOADED_CSV_HIVE_NAMESPACE]
if UPLOADED_CSV_HIVE_NAMESPACE
else []
@ -1062,7 +1049,7 @@ CSV_DEFAULT_NA_NAMES = list(STR_NA_VALUES)
# It's important to make sure that the objects exposed (as well as objects attached
# to those objets) are harmless. We recommend only exposing simple/pure functions that
# return native types.
JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
JINJA_CONTEXT_ADDONS: dict[str, Callable[..., Any]] = {}
# A dictionary of macro template processors (by engine) that gets merged into global
# template processors. The existing template processors get updated with this
@ -1070,7 +1057,7 @@ JINJA_CONTEXT_ADDONS: Dict[str, Callable[..., Any]] = {}
# dictionary. The customized addons don't necessarily need to use Jinja templating
# language. This allows you to define custom logic to process templates on a per-engine
# basis. Example value = `{"presto": CustomPrestoTemplateProcessor}`
CUSTOM_TEMPLATE_PROCESSORS: Dict[str, Type[BaseTemplateProcessor]] = {}
CUSTOM_TEMPLATE_PROCESSORS: dict[str, type[BaseTemplateProcessor]] = {}
# Roles that are controlled by the API / Superset and should not be changes
# by humans.
@ -1125,7 +1112,7 @@ PERMISSION_INSTRUCTIONS_LINK = ""
# Integrate external Blueprints to the app by passing them to your
# configuration. These blueprints will get integrated in the app
BLUEPRINTS: List[Blueprint] = []
BLUEPRINTS: list[Blueprint] = []
# Provide a callable that receives a tracking_url and returns another
# URL. This is used to translate internal Hadoop job tracker URL
@ -1142,7 +1129,7 @@ TRACKING_URL_TRANSFORMER = lambda url: url
# customize the polling time of each engine
DB_POLL_INTERVAL_SECONDS: Dict[str, int] = {}
DB_POLL_INTERVAL_SECONDS: dict[str, int] = {}
# Interval between consecutive polls when using Presto Engine
# See here: https://github.com/dropbox/PyHive/blob/8eb0aeab8ca300f3024655419b93dad926c1a351/pyhive/presto.py#L93 # pylint: disable=line-too-long,useless-suppression
@ -1159,7 +1146,7 @@ PRESTO_POLL_INTERVAL = int(timedelta(seconds=1).total_seconds())
# "another_auth_method": auth_method,
# },
# }
ALLOWED_EXTRA_AUTHENTICATIONS: Dict[str, Dict[str, Callable[..., Any]]] = {}
ALLOWED_EXTRA_AUTHENTICATIONS: dict[str, dict[str, Callable[..., Any]]] = {}
# The id of a template dashboard that should be copied to every new user
DASHBOARD_TEMPLATE_ID = None
@ -1224,14 +1211,14 @@ def EMAIL_HEADER_MUTATOR( # pylint: disable=invalid-name,unused-argument
# Owners, filters for created_by, etc.
# The users can also be excluded by overriding the get_exclude_users_from_lists method
# in security manager
EXCLUDE_USERS_FROM_LISTS: Optional[List[str]] = None
EXCLUDE_USERS_FROM_LISTS: list[str] | None = None
# For database connections, this dictionary will remove engines from the available
# list/dropdown if you do not want these dbs to show as available.
# The available list is generated by driver installed, and some engines have multiple
# drivers.
# e.g., DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {"databricks": {"pyhive", "pyodbc"}}
DBS_AVAILABLE_DENYLIST: Dict[str, Set[str]] = {}
DBS_AVAILABLE_DENYLIST: dict[str, set[str]] = {}
# This auth provider is used by background (offline) tasks that need to access
# protected resources. Can be overridden by end users in order to support
@ -1261,7 +1248,7 @@ ALERT_REPORTS_WORKING_TIME_OUT_KILL = True
# ExecutorType.OWNER,
# ExecutorType.SELENIUM,
# ]
ALERT_REPORTS_EXECUTE_AS: List[ExecutorType] = [ExecutorType.OWNER]
ALERT_REPORTS_EXECUTE_AS: list[ExecutorType] = [ExecutorType.OWNER]
# if ALERT_REPORTS_WORKING_TIME_OUT_KILL is True, set a celery hard timeout
# Equal to working timeout + ALERT_REPORTS_WORKING_TIME_OUT_LAG
ALERT_REPORTS_WORKING_TIME_OUT_LAG = int(timedelta(seconds=10).total_seconds())
@ -1286,7 +1273,7 @@ EMAIL_REPORTS_SUBJECT_PREFIX = "[Report] "
EMAIL_REPORTS_CTA = "Explore in Superset"
# Slack API token for the superset reports, either string or callable
SLACK_API_TOKEN: Optional[Union[Callable[[], str], str]] = None
SLACK_API_TOKEN: Callable[[], str] | str | None = None
SLACK_PROXY = None
# The webdriver to use for generating reports. Use one of the following
@ -1310,7 +1297,7 @@ WEBDRIVER_WINDOW = {
WEBDRIVER_AUTH_FUNC = None
# Any config options to be passed as-is to the webdriver
WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {"service_log_path": "/dev/null"}
WEBDRIVER_CONFIGURATION: dict[Any, Any] = {"service_log_path": "/dev/null"}
# Additional args to be passed as arguments to the config object
# Note: If using Chrome, you'll want to add the "--marionette" arg.
@ -1353,7 +1340,7 @@ SQL_VALIDATORS_BY_ENGINE = {
# displayed prominently in the "Add Database" dialog. You should
# use the "engine_name" attribute of the corresponding DB engine spec
# in `superset/db_engine_specs/`.
PREFERRED_DATABASES: List[str] = [
PREFERRED_DATABASES: list[str] = [
"PostgreSQL",
"Presto",
"MySQL",
@ -1386,7 +1373,7 @@ TALISMAN_CONFIG = {
#
SESSION_COOKIE_HTTPONLY = True # Prevent cookie from being read by frontend JS?
SESSION_COOKIE_SECURE = False # Prevent cookie from being transmitted over non-tls?
SESSION_COOKIE_SAMESITE: Optional[Literal["None", "Lax", "Strict"]] = "Lax"
SESSION_COOKIE_SAMESITE: Literal["None", "Lax", "Strict"] | None = "Lax"
# Accepts None, "basic" and "strong", more details on: https://flask-login.readthedocs.io/en/latest/#session-protection
SESSION_PROTECTION = "strong"
@ -1418,7 +1405,7 @@ DATASET_IMPORT_ALLOWED_DATA_URLS = [r".*"]
# Path used to store SSL certificates that are generated when using custom certs.
# Defaults to temporary directory.
# Example: SSL_CERT_PATH = "/certs"
SSL_CERT_PATH: Optional[str] = None
SSL_CERT_PATH: str | None = None
# SQLA table mutator, every time we fetch the metadata for a certain table
# (superset.connectors.sqla.models.SqlaTable), we call this hook
@ -1443,9 +1430,9 @@ GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT = 1000
GLOBAL_ASYNC_QUERIES_REDIS_STREAM_LIMIT_FIREHOSE = 1000000
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_NAME = "async-token"
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SECURE = False
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: Optional[
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_SAMESITE: None | (
Literal["None", "Lax", "Strict"]
] = None
) = None
GLOBAL_ASYNC_QUERIES_JWT_COOKIE_DOMAIN = None
GLOBAL_ASYNC_QUERIES_JWT_SECRET = "test-secret-change-me"
GLOBAL_ASYNC_QUERIES_TRANSPORT = "polling"
@ -1461,7 +1448,7 @@ GUEST_TOKEN_JWT_ALGO = "HS256"
GUEST_TOKEN_HEADER_NAME = "X-GuestToken"
GUEST_TOKEN_JWT_EXP_SECONDS = 300 # 5 minutes
# Guest token audience for the embedded superset, either string or callable
GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
GUEST_TOKEN_JWT_AUDIENCE: Callable[[], str] | str | None = None
# A SQL dataset health check. Note if enabled it is strongly advised that the callable
# be memoized to aid with performance, i.e.,
@ -1492,7 +1479,7 @@ GUEST_TOKEN_JWT_AUDIENCE: Optional[Union[Callable[[], str], str]] = None
# cache_manager.cache.delete_memoized(func)
# cache_manager.cache.set(name, code, timeout=0)
#
DATASET_HEALTH_CHECK: Optional[Callable[["SqlaTable"], str]] = None
DATASET_HEALTH_CHECK: Callable[[SqlaTable], str] | None = None
# Do not show user info or profile in the menu
MENU_HIDE_USER_INFO = False
@ -1502,7 +1489,7 @@ MENU_HIDE_USER_INFO = False
ENABLE_BROAD_ACTIVITY_ACCESS = True
# the advanced data type key should correspond to that set in the column metadata
ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
ADVANCED_DATA_TYPES: dict[str, AdvancedDataType] = {
"internet_address": internet_address,
"port": internet_port,
}
@ -1514,9 +1501,9 @@ ADVANCED_DATA_TYPES: Dict[str, AdvancedDataType] = {
# "Xyz",
# [{"col": 'created_by', "opr": 'rel_o_m', "value": 10}],
# )
WELCOME_PAGE_LAST_TAB: Union[
Literal["examples", "all"], Tuple[str, List[Dict[str, Any]]]
] = "all"
WELCOME_PAGE_LAST_TAB: (
Literal["examples", "all"] | tuple[str, list[dict[str, Any]]]
) = "all"
# Configuration for environment tag shown on the navbar. Setting 'text' to '' will hide the tag.
# 'color' can either be a hex color code, or a dot-indexed theme color (e.g. error.base)

View File

@ -16,21 +16,12 @@
# under the License.
from __future__ import annotations
import builtins
import json
from collections.abc import Hashable
from datetime import datetime
from enum import Enum
from typing import (
Any,
Dict,
Hashable,
List,
Optional,
Set,
Tuple,
Type,
TYPE_CHECKING,
Union,
)
from typing import Any, TYPE_CHECKING
from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as __
@ -89,23 +80,23 @@ class BaseDatasource(
# ---------------------------------------------------------------
# class attributes to define when deriving BaseDatasource
# ---------------------------------------------------------------
__tablename__: Optional[str] = None # {connector_name}_datasource
baselink: Optional[str] = None # url portion pointing to ModelView endpoint
__tablename__: str | None = None # {connector_name}_datasource
baselink: str | None = None # url portion pointing to ModelView endpoint
@property
def column_class(self) -> Type["BaseColumn"]:
def column_class(self) -> type[BaseColumn]:
# link to derivative of BaseColumn
raise NotImplementedError()
@property
def metric_class(self) -> Type["BaseMetric"]:
def metric_class(self) -> type[BaseMetric]:
# link to derivative of BaseMetric
raise NotImplementedError()
owner_class: Optional[User] = None
owner_class: User | None = None
# Used to do code highlighting when displaying the query in the UI
query_language: Optional[str] = None
query_language: str | None = None
# Only some datasources support Row Level Security
is_rls_supported: bool = False
@ -131,9 +122,9 @@ class BaseDatasource(
is_managed_externally = Column(Boolean, nullable=False, default=False)
external_url = Column(Text, nullable=True)
sql: Optional[str] = None
owners: List[User]
update_from_object_fields: List[str]
sql: str | None = None
owners: list[User]
update_from_object_fields: list[str]
extra_import_fields = ["is_managed_externally", "external_url"]
@ -142,7 +133,7 @@ class BaseDatasource(
return DatasourceKind.VIRTUAL if self.sql else DatasourceKind.PHYSICAL
@property
def owners_data(self) -> List[Dict[str, Any]]:
def owners_data(self) -> list[dict[str, Any]]:
return [
{
"first_name": o.first_name,
@ -167,8 +158,8 @@ class BaseDatasource(
),
)
columns: List["BaseColumn"] = []
metrics: List["BaseMetric"] = []
columns: list[BaseColumn] = []
metrics: list[BaseMetric] = []
@property
def type(self) -> str:
@ -180,11 +171,11 @@ class BaseDatasource(
return f"{self.id}__{self.type}"
@property
def column_names(self) -> List[str]:
def column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns], key=lambda x: x or "")
@property
def columns_types(self) -> Dict[str, str]:
def columns_types(self) -> dict[str, str]:
return {c.column_name: c.type for c in self.columns}
@property
@ -196,26 +187,26 @@ class BaseDatasource(
raise NotImplementedError()
@property
def connection(self) -> Optional[str]:
def connection(self) -> str | None:
"""String representing the context of the Datasource"""
return None
@property
def schema(self) -> Optional[str]:
def schema(self) -> str | None:
"""String representing the schema of the Datasource (if it applies)"""
return None
@property
def filterable_column_names(self) -> List[str]:
def filterable_column_names(self) -> list[str]:
return sorted([c.column_name for c in self.columns if c.filterable])
@property
def dttm_cols(self) -> List[str]:
def dttm_cols(self) -> list[str]:
return []
@property
def url(self) -> str:
return "/{}/edit/{}".format(self.baselink, self.id)
return f"/{self.baselink}/edit/{self.id}"
@property
def explore_url(self) -> str:
@ -224,10 +215,10 @@ class BaseDatasource(
return f"/explore/?datasource_type={self.type}&datasource_id={self.id}"
@property
def column_formats(self) -> Dict[str, Optional[str]]:
def column_formats(self) -> dict[str, str | None]:
return {m.metric_name: m.d3format for m in self.metrics if m.d3format}
def add_missing_metrics(self, metrics: List["BaseMetric"]) -> None:
def add_missing_metrics(self, metrics: list[BaseMetric]) -> None:
existing_metrics = {m.metric_name for m in self.metrics}
for metric in metrics:
if metric.metric_name not in existing_metrics:
@ -235,7 +226,7 @@ class BaseDatasource(
self.metrics.append(metric)
@property
def short_data(self) -> Dict[str, Any]:
def short_data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
"edit_url": self.url,
@ -249,11 +240,11 @@ class BaseDatasource(
}
@property
def select_star(self) -> Optional[str]:
def select_star(self) -> str | None:
pass
@property
def order_by_choices(self) -> List[Tuple[str, str]]:
def order_by_choices(self) -> list[tuple[str, str]]:
choices = []
# self.column_names return sorted column_names
for column_name in self.column_names:
@ -267,7 +258,7 @@ class BaseDatasource(
return choices
@property
def verbose_map(self) -> Dict[str, str]:
def verbose_map(self) -> dict[str, str]:
verb_map = {"__timestamp": "Time"}
verb_map.update(
{o.metric_name: o.verbose_name or o.metric_name for o in self.metrics}
@ -278,7 +269,7 @@ class BaseDatasource(
return verb_map
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
"""Data representation of the datasource sent to the frontend"""
return {
# simple fields
@ -313,8 +304,8 @@ class BaseDatasource(
}
def data_for_slices( # pylint: disable=too-many-locals
self, slices: List[Slice]
) -> Dict[str, Any]:
self, slices: list[Slice]
) -> dict[str, Any]:
"""
The representation of the datasource containing only the required data
to render the provided slices.
@ -381,8 +372,8 @@ class BaseDatasource(
if metric["metric_name"] in metric_names
]
filtered_columns: List[Column] = []
column_types: Set[GenericDataType] = set()
filtered_columns: list[Column] = []
column_types: set[GenericDataType] = set()
for column in data["columns"]:
generic_type = column.get("type_generic")
if generic_type is not None:
@ -413,18 +404,18 @@ class BaseDatasource(
@staticmethod
def filter_values_handler( # pylint: disable=too-many-arguments
values: Optional[FilterValues],
values: FilterValues | None,
operator: str,
target_generic_type: GenericDataType,
target_native_type: Optional[str] = None,
target_native_type: str | None = None,
is_list_target: bool = False,
db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
db_extra: Optional[Dict[str, Any]] = None,
) -> Optional[FilterValues]:
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
db_extra: dict[str, Any] | None = None,
) -> FilterValues | None:
if values is None:
return None
def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
def handle_single_value(value: FilterValue | None) -> FilterValue | None:
if operator == utils.FilterOperator.TEMPORAL_RANGE:
return value
if (
@ -464,7 +455,7 @@ class BaseDatasource(
values = values[0] if values else None
return values
def external_metadata(self) -> List[Dict[str, str]]:
def external_metadata(self) -> list[dict[str, str]]:
"""Returns column information from the external system"""
raise NotImplementedError()
@ -483,7 +474,7 @@ class BaseDatasource(
"""
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Given a column, returns an iterable of distinct values
This is used to populate the dropdown showing a list of
@ -494,7 +485,7 @@ class BaseDatasource(
def default_query(qry: Query) -> Query:
return qry
def get_column(self, column_name: Optional[str]) -> Optional["BaseColumn"]:
def get_column(self, column_name: str | None) -> BaseColumn | None:
if not column_name:
return None
for col in self.columns:
@ -504,11 +495,11 @@ class BaseDatasource(
@staticmethod
def get_fk_many_from_list(
object_list: List[Any],
fkmany: List[Column],
fkmany_class: Type[Union["BaseColumn", "BaseMetric"]],
object_list: list[Any],
fkmany: list[Column],
fkmany_class: builtins.type[BaseColumn | BaseMetric],
key_attr: str,
) -> List[Column]:
) -> list[Column]:
"""Update ORM one-to-many list from object list
Used for syncing metrics and columns using the same code"""
@ -541,7 +532,7 @@ class BaseDatasource(
fkmany += new_fks
return fkmany
def update_from_object(self, obj: Dict[str, Any]) -> None:
def update_from_object(self, obj: dict[str, Any]) -> None:
"""Update datasource from a data structure
The UI's table editor crafts a complex data structure that
@ -578,7 +569,7 @@ class BaseDatasource(
def get_extra_cache_keys( # pylint: disable=no-self-use
self, query_obj: QueryObjectDict # pylint: disable=unused-argument
) -> List[Hashable]:
) -> list[Hashable]:
"""If a datasource needs to provide additional keys for calculation of
cache keys, those can be provided via this method
@ -607,14 +598,14 @@ class BaseDatasource(
@classmethod
def get_datasource_by_name(
cls, session: Session, datasource_name: str, schema: str, database_name: str
) -> Optional["BaseDatasource"]:
) -> BaseDatasource | None:
raise NotImplementedError()
class BaseColumn(AuditMixinNullable, ImportExportMixin):
"""Interface for column"""
__tablename__: Optional[str] = None # {connector_name}_column
__tablename__: str | None = None # {connector_name}_column
id = Column(Integer, primary_key=True)
column_name = Column(String(255), nullable=False)
@ -628,7 +619,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
is_dttm = None
# [optional] Set this to support import/export functionality
export_fields: List[Any] = []
export_fields: list[Any] = []
def __repr__(self) -> str:
return str(self.column_name)
@ -666,7 +657,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
return self.type and any(map(lambda t: t in self.type.upper(), self.bool_types))
@property
def type_generic(self) -> Optional[utils.GenericDataType]:
def type_generic(self) -> utils.GenericDataType | None:
if self.is_string:
return utils.GenericDataType.STRING
if self.is_boolean:
@ -686,7 +677,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@ -705,7 +696,7 @@ class BaseColumn(AuditMixinNullable, ImportExportMixin):
class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""Interface for Metrics"""
__tablename__: Optional[str] = None # {connector_name}_metric
__tablename__: str | None = None # {connector_name}_metric
id = Column(Integer, primary_key=True)
metric_name = Column(String(255), nullable=False)
@ -730,7 +721,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
"""
@property
def perm(self) -> Optional[str]:
def perm(self) -> str | None:
raise NotImplementedError()
@property
@ -738,7 +729,7 @@ class BaseMetric(AuditMixinNullable, ImportExportMixin):
raise NotImplementedError()
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
attrs = (
"id",
"metric_name",

View File

@ -22,21 +22,10 @@ import json
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import (
Any,
Callable,
cast,
Dict,
Hashable,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
from typing import Any, Callable, cast
import dateutil.parser
import numpy as np
@ -136,9 +125,9 @@ ADDITIVE_METRIC_TYPES_LOWER = {op.lower() for op in ADDITIVE_METRIC_TYPES}
@dataclass
class MetadataResult:
added: List[str] = field(default_factory=list)
removed: List[str] = field(default_factory=list)
modified: List[str] = field(default_factory=list)
added: list[str] = field(default_factory=list)
removed: list[str] = field(default_factory=list)
modified: list[str] = field(default_factory=list)
class AnnotationDatasource(BaseDatasource):
@ -190,7 +179,7 @@ class AnnotationDatasource(BaseDatasource):
def get_query_str(self, query_obj: QueryObjectDict) -> str:
raise NotImplementedError()
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
raise NotImplementedError()
@ -201,7 +190,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
__tablename__ = "table_columns"
__table_args__ = (UniqueConstraint("table_id", "column_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table: Mapped["SqlaTable"] = relationship(
table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="columns",
)
@ -263,15 +252,15 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.type_generic == GenericDataType.TEMPORAL
@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
def db_engine_spec(self) -> type[BaseEngineSpec]:
return self.table.db_engine_spec
@property
def db_extra(self) -> Dict[str, Any]:
def db_extra(self) -> dict[str, Any]:
return self.table.database.get_extra()
@property
def type_generic(self) -> Optional[utils.GenericDataType]:
def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL
@ -310,8 +299,8 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_sqla_col(
self,
label: Optional[str] = None,
template_processor: Optional[BaseTemplateProcessor] = None,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.column_name
db_engine_spec = self.db_engine_spec
@ -332,10 +321,10 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
def get_timestamp_expression(
self,
time_grain: Optional[str],
label: Optional[str] = None,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Union[TimestampExpression, Label]:
time_grain: str | None,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> TimestampExpression | Label:
"""
Return a SQLAlchemy Core element representation of self to be used in a query.
@ -365,7 +354,7 @@ class TableColumn(Model, BaseColumn, CertificationMixin):
return self.table.make_sqla_column_compatible(time_expr, label)
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
attrs = (
"id",
"column_name",
@ -399,7 +388,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
__tablename__ = "sql_metrics"
__table_args__ = (UniqueConstraint("table_id", "metric_name"),)
table_id = Column(Integer, ForeignKey("tables.id"))
table: Mapped["SqlaTable"] = relationship(
table: Mapped[SqlaTable] = relationship(
"SqlaTable",
back_populates="metrics",
)
@ -425,8 +414,8 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
def get_sqla_col(
self,
label: Optional[str] = None,
template_processor: Optional[BaseTemplateProcessor] = None,
label: str | None = None,
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
label = label or self.metric_name
expression = self.expression
@ -437,7 +426,7 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
return self.table.make_sqla_column_compatible(sqla_col, label)
@property
def perm(self) -> Optional[str]:
def perm(self) -> str | None:
return (
("{parent_name}.[{obj.metric_name}](id:{obj.id})").format(
obj=self, parent_name=self.table.full_name
@ -446,11 +435,11 @@ class SqlMetric(Model, BaseMetric, CertificationMixin):
else None
)
def get_perm(self) -> Optional[str]:
def get_perm(self) -> str | None:
return self.perm
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
attrs = (
"is_certified",
"certified_by",
@ -473,11 +462,11 @@ sqlatable_user = Table(
def _process_sql_expression(
expression: Optional[str],
expression: str | None,
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]:
template_processor: BaseTemplateProcessor | None = None,
) -> str | None:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
@ -501,12 +490,12 @@ class SqlaTable(
type = "table"
query_language = "sql"
is_rls_supported = True
columns: Mapped[List[TableColumn]] = relationship(
columns: Mapped[list[TableColumn]] = relationship(
TableColumn,
back_populates="table",
cascade="all, delete-orphan",
)
metrics: Mapped[List[SqlMetric]] = relationship(
metrics: Mapped[list[SqlMetric]] = relationship(
SqlMetric,
back_populates="table",
cascade="all, delete-orphan",
@ -577,11 +566,11 @@ class SqlaTable(
return self.name
@property
def db_extra(self) -> Dict[str, Any]:
def db_extra(self) -> dict[str, Any]:
return self.database.get_extra()
@staticmethod
def _apply_cte(sql: str, cte: Optional[str]) -> str:
def _apply_cte(sql: str, cte: str | None) -> str:
"""
Append a CTE before the SELECT statement if defined
@ -594,7 +583,7 @@ class SqlaTable(
return sql
@property
def db_engine_spec(self) -> Type[BaseEngineSpec]:
def db_engine_spec(self) -> __builtins__.type[BaseEngineSpec]:
return self.database.db_engine_spec
@property
@ -637,9 +626,9 @@ class SqlaTable(
cls,
session: Session,
datasource_name: str,
schema: Optional[str],
schema: str | None,
database_name: str,
) -> Optional[SqlaTable]:
) -> SqlaTable | None:
schema = schema or None
query = (
session.query(cls)
@ -660,7 +649,7 @@ class SqlaTable(
anchor = f'<a target="_blank" href="{self.explore_url}">{name}</a>'
return Markup(anchor)
def get_schema_perm(self) -> Optional[str]:
def get_schema_perm(self) -> str | None:
"""Returns schema permission if present, database one otherwise."""
return security_manager.get_schema_perm(self.database, self.schema)
@ -685,18 +674,18 @@ class SqlaTable(
)
@property
def dttm_cols(self) -> List[str]:
def dttm_cols(self) -> list[str]:
l = [c.column_name for c in self.columns if c.is_dttm]
if self.main_dttm_col and self.main_dttm_col not in l:
l.append(self.main_dttm_col)
return l
@property
def num_cols(self) -> List[str]:
def num_cols(self) -> list[str]:
return [c.column_name for c in self.columns if c.is_numeric]
@property
def any_dttm_col(self) -> Optional[str]:
def any_dttm_col(self) -> str | None:
cols = self.dttm_cols
return cols[0] if cols else None
@ -713,7 +702,7 @@ class SqlaTable(
def sql_url(self) -> str:
return self.database.sql_url + "?table_name=" + str(self.table_name)
def external_metadata(self) -> List[Dict[str, str]]:
def external_metadata(self) -> list[dict[str, str]]:
# todo(yongjie): create a physical table column type in a separate PR
if self.sql:
return get_virtual_table_metadata(dataset=self) # type: ignore
@ -724,14 +713,14 @@ class SqlaTable(
)
@property
def time_column_grains(self) -> Dict[str, Any]:
def time_column_grains(self) -> dict[str, Any]:
return {
"time_columns": self.dttm_cols,
"time_grains": [grain.name for grain in self.database.grains()],
}
@property
def select_star(self) -> Optional[str]:
def select_star(self) -> str | None:
# show_cols and latest_partition set to false to avoid
# the expensive cost of inspecting the DB
return self.database.select_star(
@ -739,20 +728,20 @@ class SqlaTable(
)
@property
def health_check_message(self) -> Optional[str]:
def health_check_message(self) -> str | None:
check = config["DATASET_HEALTH_CHECK"]
return check(self) if check else None
@property
def granularity_sqla(self) -> List[Tuple[Any, Any]]:
def granularity_sqla(self) -> list[tuple[Any, Any]]:
return utils.choicify(self.dttm_cols)
@property
def time_grain_sqla(self) -> List[Tuple[Any, Any]]:
def time_grain_sqla(self) -> list[tuple[Any, Any]]:
return [(g.duration, g.name) for g in self.database.grains() or []]
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
data_ = super().data
if self.type == "table":
data_["granularity_sqla"] = self.granularity_sqla
@ -767,7 +756,7 @@ class SqlaTable(
return data_
@property
def extra_dict(self) -> Dict[str, Any]:
def extra_dict(self) -> dict[str, Any]:
try:
return json.loads(self.extra)
except (TypeError, json.JSONDecodeError):
@ -775,7 +764,7 @@ class SqlaTable(
def get_fetch_values_predicate(
self,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: BaseTemplateProcessor | None = None,
) -> TextClause:
fetch_values_predicate = self.fetch_values_predicate
if template_processor:
@ -792,7 +781,7 @@ class SqlaTable(
)
) from ex
def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
def values_for_column(self, column_name: str, limit: int = 10000) -> list[Any]:
"""Runs query against sqla to retrieve some
sample values for the given column.
"""
@ -869,8 +858,8 @@ class SqlaTable(
return tbl
def get_from_clause(
self, template_processor: Optional[BaseTemplateProcessor] = None
) -> Tuple[Union[TableClause, Alias], Optional[str]]:
self, template_processor: BaseTemplateProcessor | None = None
) -> tuple[TableClause | Alias, str | None]:
"""
Return where to select the columns and metrics from. Either a physical table
or a virtual table with it's own subquery. If the FROM is referencing a
@ -899,7 +888,7 @@ class SqlaTable(
return from_clause, cte
def get_rendered_sql(
self, template_processor: Optional[BaseTemplateProcessor] = None
self, template_processor: BaseTemplateProcessor | None = None
) -> str:
"""
Render sql with template engine (Jinja).
@ -928,8 +917,8 @@ class SqlaTable(
def adhoc_metric_to_sqla(
self,
metric: AdhocMetric,
columns_by_name: Dict[str, TableColumn],
template_processor: Optional[BaseTemplateProcessor] = None,
columns_by_name: dict[str, TableColumn],
template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.
@ -946,7 +935,7 @@ class SqlaTable(
if expression_type == utils.AdhocMetricExpressionType.SIMPLE:
metric_column = metric.get("column") or {}
column_name = cast(str, metric_column.get("column_name"))
table_column: Optional[TableColumn] = columns_by_name.get(column_name)
table_column: TableColumn | None = columns_by_name.get(column_name)
if table_column:
sqla_column = table_column.get_sqla_col(
template_processor=template_processor
@ -971,7 +960,7 @@ class SqlaTable(
self,
col: AdhocColumn,
force_type_check: bool = False,
template_processor: Optional[BaseTemplateProcessor] = None,
template_processor: BaseTemplateProcessor | None = None,
) -> ColumnElement:
"""
Turn an adhoc column into a sqlalchemy column.
@ -1021,7 +1010,7 @@ class SqlaTable(
return self.make_sqla_column_compatible(sqla_column, label)
def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
self, sqla_col: ColumnElement, label: str | None = None
) -> ColumnElement:
"""Takes a sqlalchemy column object and adds label info if supported by engine.
:param sqla_col: sqlalchemy column instance
@ -1038,7 +1027,7 @@ class SqlaTable(
return sqla_col
def make_orderby_compatible(
self, select_exprs: List[ColumnElement], orderby_exprs: List[ColumnElement]
self, select_exprs: list[ColumnElement], orderby_exprs: list[ColumnElement]
) -> None:
"""
If needed, make sure aliases for selected columns are not used in
@ -1069,7 +1058,7 @@ class SqlaTable(
def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
) -> List[TextClause]:
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
@ -1078,8 +1067,8 @@ class SqlaTable(
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
all_filters: List[TextClause] = []
filter_groups: Dict[Union[int, str], List[TextClause]] = defaultdict(list)
all_filters: list[TextClause] = []
filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
@ -1114,9 +1103,9 @@ class SqlaTable(
def _get_series_orderby(
self,
series_limit_metric: Metric,
metrics_by_name: Dict[str, SqlMetric],
columns_by_name: Dict[str, TableColumn],
template_processor: Optional[BaseTemplateProcessor] = None,
metrics_by_name: dict[str, SqlMetric],
columns_by_name: dict[str, TableColumn],
template_processor: BaseTemplateProcessor | None = None,
) -> Column:
if utils.is_adhoc_metric(series_limit_metric):
assert isinstance(series_limit_metric, dict)
@ -1138,8 +1127,8 @@ class SqlaTable(
self,
row: pd.Series,
dimension: str,
columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
columns_by_name: dict[str, TableColumn],
) -> str | int | float | bool | Text:
"""
Convert a prequery result type to its equivalent Python type.
@ -1159,7 +1148,7 @@ class SqlaTable(
value = value.item()
column_ = columns_by_name[dimension]
db_extra: Dict[str, Any] = self.database.get_extra()
db_extra: dict[str, Any] = self.database.get_extra()
if column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
@ -1174,9 +1163,9 @@ class SqlaTable(
def _get_top_groups(
self,
df: pd.DataFrame,
dimensions: List[str],
groupby_exprs: Dict[str, Any],
columns_by_name: Dict[str, TableColumn],
dimensions: list[str],
groupby_exprs: dict[str, Any],
columns_by_name: dict[str, TableColumn],
) -> ColumnElement:
groups = []
for _unused, row in df.iterrows():
@ -1201,7 +1190,7 @@ class SqlaTable(
errors = None
error_message = None
def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
def assign_column_label(df: pd.DataFrame) -> pd.DataFrame | None:
"""
Some engines change the case or generate bespoke column names, either by
default or due to lack of support for aliasing. This function ensures that
@ -1283,7 +1272,7 @@ class SqlaTable(
else self.columns
)
old_columns_by_name: Dict[str, TableColumn] = {
old_columns_by_name: dict[str, TableColumn] = {
col.column_name: col for col in old_columns
}
results = MetadataResult(
@ -1341,8 +1330,8 @@ class SqlaTable(
session: Session,
database: Database,
datasource_name: str,
schema: Optional[str] = None,
) -> List[SqlaTable]:
schema: str | None = None,
) -> list[SqlaTable]:
query = (
session.query(cls)
.filter_by(database_id=database.id)
@ -1357,9 +1346,9 @@ class SqlaTable(
cls,
session: Session,
database: Database,
permissions: Set[str],
schema_perms: Set[str],
) -> List[SqlaTable]:
permissions: set[str],
schema_perms: set[str],
) -> list[SqlaTable]:
# TODO(hughhhh): add unit test
return (
session.query(cls)
@ -1389,7 +1378,7 @@ class SqlaTable(
)
@classmethod
def get_all_datasources(cls, session: Session) -> List[SqlaTable]:
def get_all_datasources(cls, session: Session) -> list[SqlaTable]:
qry = session.query(cls)
qry = cls.default_query(qry)
return qry.all()
@ -1409,7 +1398,7 @@ class SqlaTable(
:param query_obj: query object to analyze
:return: True if there are call(s) to an `ExtraCache` method, False otherwise
"""
templatable_statements: List[str] = []
templatable_statements: list[str] = []
if self.sql:
templatable_statements.append(self.sql)
if self.fetch_values_predicate:
@ -1428,7 +1417,7 @@ class SqlaTable(
return True
return False
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> List[Hashable]:
def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]:
"""
The cache key of a SqlaTable needs to consider any keys added by the parent
class and any keys added via `ExtraCache`.
@ -1489,7 +1478,7 @@ class SqlaTable(
@staticmethod
def update_column( # pylint: disable=unused-argument
mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn]
mapper: Mapper, connection: Connection, target: SqlMetric | TableColumn
) -> None:
"""
:param mapper: Unused.

View File

@ -17,19 +17,9 @@
from __future__ import annotations
import logging
from collections.abc import Iterable, Iterator
from functools import lru_cache
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
)
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from uuid import UUID
from flask_babel import lazy_gettext as _
@ -58,8 +48,8 @@ if TYPE_CHECKING:
def get_physical_table_metadata(
database: Database,
table_name: str,
schema_name: Optional[str] = None,
) -> List[Dict[str, Any]]:
schema_name: str | None = None,
) -> list[dict[str, Any]]:
"""Use SQLAlchemy inspector to get table metadata"""
db_engine_spec = database.db_engine_spec
db_dialect = database.get_dialect()
@ -103,7 +93,7 @@ def get_physical_table_metadata(
return cols
def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
def get_virtual_table_metadata(dataset: SqlaTable) -> list[ResultSetColumnType]:
"""Use SQLparser to get virtual dataset metadata"""
if not dataset.sql:
raise SupersetGenericDBErrorException(
@ -150,7 +140,7 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
def get_columns_description(
database: Database,
query: str,
) -> List[ResultSetColumnType]:
) -> list[ResultSetColumnType]:
db_engine_spec = database.db_engine_spec
try:
with database.get_raw_connection() as conn:
@ -171,7 +161,7 @@ def get_dialect_name(drivername: str) -> str:
@lru_cache(maxsize=LRU_CACHE_MAX_SIZE)
def get_identifier_quoter(drivername: str) -> Dict[str, Callable[[str], str]]:
def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]:
return SqlaURL.create(drivername).get_dialect()().identifier_preparer.quote
@ -181,9 +171,9 @@ logger = logging.getLogger(__name__)
def find_cached_objects_in_session(
session: Session,
cls: Type[DeclarativeModel],
ids: Optional[Iterable[int]] = None,
uuids: Optional[Iterable[UUID]] = None,
cls: type[DeclarativeModel],
ids: Iterable[int] | None = None,
uuids: Iterable[UUID] | None = None,
) -> Iterator[DeclarativeModel]:
"""Find known ORM instances in cached SQLA session states.

View File

@ -447,7 +447,7 @@ class TableModelView( # pylint: disable=too-many-ancestors
resp = super().edit(pk)
if isinstance(resp, str):
return resp
return redirect("/explore/?datasource_type=table&datasource_id={}".format(pk))
return redirect(f"/explore/?datasource_type=table&datasource_id={pk}")
@expose("/list/")
@has_access

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from superset.commands.base import BaseCommand
from superset.css_templates.commands.exceptions import (
@ -30,9 +30,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteCssTemplateCommand(BaseCommand):
def __init__(self, model_ids: List[int]):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[List[CssTemplate]] = None
self._models: Optional[list[CssTemplate]] = None
def run(self) -> None:
self.validate()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
@ -31,7 +31,7 @@ class CssTemplateDAO(BaseDAO):
model_cls = CssTemplate
@staticmethod
def bulk_delete(models: Optional[List[CssTemplate]], commit: bool = True) -> None:
def bulk_delete(models: Optional[list[CssTemplate]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
try:
db.session.query(CssTemplate).filter(CssTemplate.id.in_(item_ids)).delete(

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=isinstance-second-argument-not-valid-type
from typing import Any, Dict, List, Optional, Type, Union
from typing import Any, Optional, Union
from flask_appbuilder.models.filters import BaseFilter
from flask_appbuilder.models.sqla import Model
@ -37,7 +37,7 @@ class BaseDAO:
Base DAO, implement base CRUD sqlalchemy operations
"""
model_cls: Optional[Type[Model]] = None
model_cls: Optional[type[Model]] = None
"""
Child classes need to state the Model class so they don't need to implement basic
create, update and delete methods
@ -75,10 +75,10 @@ class BaseDAO:
@classmethod
def find_by_ids(
cls,
model_ids: Union[List[str], List[int]],
model_ids: Union[list[str], list[int]],
session: Session = None,
skip_base_filter: bool = False,
) -> List[Model]:
) -> list[Model]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
@ -95,7 +95,7 @@ class BaseDAO:
return query.all()
@classmethod
def find_all(cls) -> List[Model]:
def find_all(cls) -> list[Model]:
"""
Get all that fit the `base_filter`
"""
@ -121,7 +121,7 @@ class BaseDAO:
return query.filter_by(**filter_by).one_or_none()
@classmethod
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
"""
Generic for creating models
:raises: DAOCreateFailedError
@ -163,7 +163,7 @@ class BaseDAO:
@classmethod
def update(
cls, model: Model, properties: Dict[str, Any], commit: bool = True
cls, model: Model, properties: dict[str, Any], commit: bool = True
) -> Model:
"""
Generic update a model
@ -196,7 +196,7 @@ class BaseDAO:
return model
@classmethod
def bulk_delete(cls, models: List[Model], commit: bool = True) -> None:
def bulk_delete(cls, models: list[Model], commit: bool = True) -> None:
try:
for model in models:
cls.delete(model, False)

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import List, Optional
from typing import Optional
from flask_babel import lazy_gettext as _
@ -37,9 +37,9 @@ logger = logging.getLogger(__name__)
class BulkDeleteDashboardCommand(BaseCommand):
def __init__(self, model_ids: List[int]):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[List[Dashboard]] = None
self._models: Optional[list[Dashboard]] = None
def run(self) -> None:
self.validate()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateDashboardCommand(CreateMixin, BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@ -48,9 +48,9 @@ class CreateDashboardCommand(CreateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
exceptions: List[ValidationError] = []
owner_ids: Optional[List[int]] = self._properties.get("owners")
role_ids: Optional[List[int]] = self._properties.get("roles")
exceptions: list[ValidationError] = []
owner_ids: Optional[list[int]] = self._properties.get("owners")
role_ids: Optional[list[int]] = self._properties.get("roles")
slug: str = self._properties.get("slug", "")
# Validate slug uniqueness

View File

@ -20,7 +20,8 @@ import json
import logging
import random
import string
from typing import Any, Dict, Iterator, Optional, Set, Tuple
from typing import Any, Optional
from collections.abc import Iterator
import yaml
@ -52,7 +53,7 @@ def suffix(length: int = 8) -> str:
)
def get_default_position(title: str) -> Dict[str, Any]:
def get_default_position(title: str) -> dict[str, Any]:
return {
"DASHBOARD_VERSION_KEY": "v2",
"ROOT_ID": {"children": ["GRID_ID"], "id": "ROOT_ID", "type": "ROOT"},
@ -66,7 +67,7 @@ def get_default_position(title: str) -> Dict[str, Any]:
}
def append_charts(position: Dict[str, Any], charts: Set[Slice]) -> Dict[str, Any]:
def append_charts(position: dict[str, Any], charts: set[Slice]) -> dict[str, Any]:
chart_hashes = [f"CHART-{suffix()}" for _ in charts]
# if we have ROOT_ID/GRID_ID, append orphan charts to a new row inside the grid
@ -109,7 +110,7 @@ class ExportDashboardsCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Dashboard, export_related: bool = True
) -> Iterator[Tuple[str, str]]:
) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.dashboard_title, model.id)
file_path = f"dashboards/{file_name}.yaml"

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from marshmallow.exceptions import ValidationError
@ -43,7 +43,7 @@ class ImportDashboardsCommand(BaseCommand):
until it finds one that matches.
"""
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs

View File

@ -19,7 +19,7 @@ import logging
import time
from copy import copy
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask_babel import lazy_gettext as _
from sqlalchemy.orm import make_transient, Session
@ -83,7 +83,7 @@ def import_chart(
def import_dashboard(
# pylint: disable=too-many-locals,too-many-statements
dashboard_to_import: Dashboard,
dataset_id_mapping: Optional[Dict[int, int]] = None,
dataset_id_mapping: Optional[dict[int, int]] = None,
import_time: Optional[int] = None,
) -> int:
"""Imports the dashboard from the object to the database.
@ -97,7 +97,7 @@ def import_dashboard(
"""
def alter_positions(
dashboard: Dashboard, old_to_new_slc_id_dict: Dict[int, int]
dashboard: Dashboard, old_to_new_slc_id_dict: dict[int, int]
) -> None:
"""Updates slice_ids in the position json.
@ -166,7 +166,7 @@ def import_dashboard(
dashboard_to_import.slug = None
old_json_metadata = json.loads(dashboard_to_import.json_metadata or "{}")
old_to_new_slc_id_dict: Dict[int, int] = {}
old_to_new_slc_id_dict: dict[int, int] = {}
new_timed_refresh_immune_slices = []
new_expanded_slices = {}
new_filter_scopes = {}
@ -268,7 +268,7 @@ def import_dashboard(
return dashboard_to_import.id # type: ignore
def decode_dashboards(o: Dict[str, Any]) -> Any:
def decode_dashboards(o: dict[str, Any]) -> Any:
"""
Function to be passed into json.loads obj_hook parameter
Recreates the dashboard object from a json representation.
@ -302,7 +302,7 @@ def import_dashboards(
data = json.loads(content, object_hook=decode_dashboards)
if not data:
raise DashboardImportException(_("No data in file"))
dataset_id_mapping: Dict[int, int] = {}
dataset_id_mapping: dict[int, int] = {}
for table in data["datasources"]:
new_dataset_id = import_dataset(table, database_id, import_time=import_time)
params = json.loads(table.params)
@ -324,7 +324,7 @@ class ImportDashboardsCommand(BaseCommand):
# pylint: disable=unused-argument
def __init__(
self, contents: Dict[str, str], database_id: Optional[int] = None, **kwargs: Any
self, contents: dict[str, str], database_id: Optional[int] = None, **kwargs: Any
):
self.contents = contents
self.database_id = database_id

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Set, Tuple
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@ -47,7 +47,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
dao = DashboardDAO
model_name = "dashboard"
prefix = "dashboards/"
schemas: Dict[str, Schema] = {
schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"dashboards/": ImportV1DashboardSchema(),
"datasets/": ImportV1DatasetSchema(),
@ -59,11 +59,11 @@ class ImportDashboardsCommand(ImportModelsCommand):
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover charts and datasets associated with dashboards
chart_uuids: Set[str] = set()
dataset_uuids: Set[str] = set()
chart_uuids: set[str] = set()
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
chart_uuids.update(find_chart_uuids(config["position"]))
@ -77,20 +77,20 @@ class ImportDashboardsCommand(ImportModelsCommand):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
database_uuids: Set[str] = set()
database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
database_ids: Dict[str, int] = {}
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
dataset_info: Dict[str, Dict[str, Any]] = {}
dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
@ -105,7 +105,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
}
# import charts with the correct parent ref
chart_ids: Dict[str, int] = {}
chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if (
file_name.startswith("charts/")
@ -129,7 +129,7 @@ class ImportDashboardsCommand(ImportModelsCommand):
).fetchall()
# import dashboards
dashboard_chart_ids: List[Tuple[int, int]] = []
dashboard_chart_ids: list[tuple[int, int]] = []
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)

View File

@ -17,7 +17,7 @@
import json
import logging
from typing import Any, Dict, Set
from typing import Any
from flask import g
from sqlalchemy.orm import Session
@ -32,12 +32,12 @@ logger = logging.getLogger(__name__)
JSON_KEYS = {"position": "position_json", "metadata": "json_metadata"}
def find_chart_uuids(position: Dict[str, Any]) -> Set[str]:
def find_chart_uuids(position: dict[str, Any]) -> set[str]:
return set(build_uuid_to_id_map(position))
def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
uuids: Set[str] = set()
def find_native_filter_datasets(metadata: dict[str, Any]) -> set[str]:
uuids: set[str] = set()
for native_filter in metadata.get("native_filter_configuration", []):
targets = native_filter.get("targets", [])
for target in targets:
@ -47,7 +47,7 @@ def find_native_filter_datasets(metadata: Dict[str, Any]) -> Set[str]:
return uuids
def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
def build_uuid_to_id_map(position: dict[str, Any]) -> dict[str, int]:
return {
child["meta"]["uuid"]: child["meta"]["chartId"]
for child in position.values()
@ -60,10 +60,10 @@ def build_uuid_to_id_map(position: Dict[str, Any]) -> Dict[str, int]:
def update_id_refs( # pylint: disable=too-many-locals
config: Dict[str, Any],
chart_ids: Dict[str, int],
dataset_info: Dict[str, Dict[str, Any]],
) -> Dict[str, Any]:
config: dict[str, Any],
chart_ids: dict[str, int],
dataset_info: dict[str, dict[str, Any]],
) -> dict[str, Any]:
"""Update dashboard metadata to use new IDs"""
fixed = config.copy()
@ -147,7 +147,7 @@ def update_id_refs( # pylint: disable=too-many-locals
def import_dashboard(
session: Session,
config: Dict[str, Any],
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Dashboard:

View File

@ -16,7 +16,7 @@
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -41,7 +41,7 @@ logger = logging.getLogger(__name__)
class UpdateDashboardCommand(UpdateMixin, BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Dashboard] = None
@ -64,9 +64,9 @@ class UpdateDashboardCommand(UpdateMixin, BaseCommand):
return dashboard
def validate(self) -> None:
exceptions: List[ValidationError] = []
owners_ids: Optional[List[int]] = self._properties.get("owners")
roles_ids: Optional[List[int]] = self._properties.get("roles")
exceptions: list[ValidationError] = []
owners_ids: Optional[list[int]] = self._properties.get("owners")
roles_ids: Optional[list[int]] = self._properties.get("roles")
slug: Optional[str] = self._properties.get("slug")
# Validate/populate model exists

View File

@ -17,7 +17,7 @@
import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from typing import Any, Optional, Union
from flask import g
from flask_appbuilder.models.sqla.interface import SQLAInterface
@ -68,12 +68,12 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
def get_datasets_for_dashboard(id_or_slug: str) -> List[Any]:
def get_datasets_for_dashboard(id_or_slug: str) -> list[Any]:
dashboard = DashboardDAO.get_by_id_or_slug(id_or_slug)
return dashboard.datasets_trimmed_for_slices()
@staticmethod
def get_charts_for_dashboard(id_or_slug: str) -> List[Slice]:
def get_charts_for_dashboard(id_or_slug: str) -> list[Slice]:
return DashboardDAO.get_by_id_or_slug(id_or_slug).slices
@staticmethod
@ -173,7 +173,7 @@ class DashboardDAO(BaseDAO):
return model
@staticmethod
def bulk_delete(models: Optional[List[Dashboard]], commit: bool = True) -> None:
def bulk_delete(models: Optional[list[Dashboard]], commit: bool = True) -> None:
item_ids = [model.id for model in models] if models else []
# bulk delete, first delete related data
if models:
@ -196,8 +196,8 @@ class DashboardDAO(BaseDAO):
@staticmethod
def set_dash_metadata( # pylint: disable=too-many-locals
dashboard: Dashboard,
data: Dict[Any, Any],
old_to_new_slice_ids: Optional[Dict[int, int]] = None,
data: dict[Any, Any],
old_to_new_slice_ids: Optional[dict[int, int]] = None,
commit: bool = False,
) -> Dashboard:
new_filter_scopes = {}
@ -235,7 +235,7 @@ class DashboardDAO(BaseDAO):
if "filter_scopes" in data:
# replace filter_id and immune ids from old slice id to new slice id:
# and remove slice ids that are not in dash anymore
slc_id_dict: Dict[int, int] = {}
slc_id_dict: dict[int, int] = {}
if old_to_new_slice_ids:
slc_id_dict = {
old: new
@ -288,7 +288,7 @@ class DashboardDAO(BaseDAO):
return dashboard
@staticmethod
def favorited_ids(dashboards: List[Dashboard]) -> List[FavStar]:
def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]:
ids = [dash.id for dash in dashboards]
return [
star.obj_id
@ -303,7 +303,7 @@ class DashboardDAO(BaseDAO):
@classmethod
def copy_dashboard(
cls, original_dash: Dashboard, data: Dict[str, Any]
cls, original_dash: Dashboard, data: dict[str, Any]
) -> Dashboard:
dash = Dashboard()
dash.owners = [g.user] if g.user else []
@ -311,7 +311,7 @@ class DashboardDAO(BaseDAO):
dash.css = data.get("css")
metadata = json.loads(data["json_metadata"])
old_to_new_slice_ids: Dict[int, int] = {}
old_to_new_slice_ids: dict[int, int] = {}
if data.get("duplicate_slices"):
# Duplicating slices as well, mapping old ids to new ones
for slc in original_dash.slices:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from flask_appbuilder.models.sqla import Model
@ -40,7 +40,7 @@ logger = logging.getLogger(__name__)
class CreateFilterSetCommand(BaseFilterSetCommand):
# pylint: disable=C0103
def __init__(self, dashboard_id: int, data: Dict[str, Any]):
def __init__(self, dashboard_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from flask_appbuilder.models.sqla import Model
@ -31,7 +31,7 @@ logger = logging.getLogger(__name__)
class UpdateFilterSetCommand(BaseFilterSetCommand):
def __init__(self, dashboard_id: int, filter_set_id: int, data: Dict[str, Any]):
def __init__(self, dashboard_id: int, filter_set_id: int, data: dict[str, Any]):
super().__init__(dashboard_id)
self._filter_set_id = filter_set_id
self._properties = data.copy()

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from flask_appbuilder.models.sqla import Model
from sqlalchemy.exc import SQLAlchemyError
@ -40,7 +40,7 @@ class FilterSetDAO(BaseDAO):
model_cls = FilterSet
@classmethod
def create(cls, properties: Dict[str, Any], commit: bool = True) -> Model:
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()

View File

@ -14,7 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, cast, Dict, Mapping
from collections.abc import Mapping
from typing import Any, cast
from marshmallow import fields, post_load, Schema, ValidationError
from marshmallow.validate import Length, OneOf
@ -64,11 +65,11 @@ class FilterSetPostSchema(FilterSetSchema):
@post_load
def validate(
self, data: Mapping[Any, Any], *, many: Any, partial: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
if data[OWNER_TYPE_FIELD] == USER_OWNER_TYPE and OWNER_ID_FIELD not in data:
raise ValidationError("owner_id is mandatory when owner_type is User")
return cast(Dict[str, Any], data)
return cast(dict[str, Any], data)
class FilterSetPutSchema(FilterSetSchema):
@ -84,14 +85,14 @@ class FilterSetPutSchema(FilterSetSchema):
@post_load
def validate( # pylint: disable=unused-argument
self, data: Mapping[Any, Any], *, many: Any, partial: Any
) -> Dict[str, Any]:
) -> dict[str, Any]:
if JSON_METADATA_FIELD in data:
self._validate_json_meta_data(data[JSON_METADATA_FIELD])
return cast(Dict[str, Any], data)
return cast(dict[str, Any], data)
def validate_pair(first_field: str, second_field: str, data: Dict[str, Any]) -> None:
def validate_pair(first_field: str, second_field: str, data: dict[str, Any]) -> None:
if first_field in data and second_field not in data:
raise ValidationError(
"{} must be included alongside {}".format(first_field, second_field)
f"{first_field} must be included alongside {second_field}"
)

View File

@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Type
from flask import Response
from flask_appbuilder.api import expose, protect, safe
@ -35,16 +34,16 @@ class DashboardFilterStateRestApi(TemporaryCacheRestApi):
resource_name = "dashboard"
openapi_spec_tag = "Dashboard Filter State"
def get_create_command(self) -> Type[CreateFilterStateCommand]:
def get_create_command(self) -> type[CreateFilterStateCommand]:
return CreateFilterStateCommand
def get_update_command(self) -> Type[UpdateFilterStateCommand]:
def get_update_command(self) -> type[UpdateFilterStateCommand]:
return UpdateFilterStateCommand
def get_get_command(self) -> Type[GetFilterStateCommand]:
def get_get_command(self) -> type[GetFilterStateCommand]:
return GetFilterStateCommand
def get_delete_command(self) -> Type[DeleteFilterStateCommand]:
def get_delete_command(self) -> type[DeleteFilterStateCommand]:
return DeleteFilterStateCommand
@expose("/<int:pk>/filter_state", methods=("POST",))

View File

@ -14,14 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, List, Optional, Tuple, TypedDict
from typing import Any, Optional, TypedDict
class DashboardPermalinkState(TypedDict):
dataMask: Optional[Dict[str, Any]]
activeTabs: Optional[List[str]]
dataMask: Optional[dict[str, Any]]
activeTabs: Optional[list[str]]
anchor: Optional[str]
urlParams: Optional[List[Tuple[str, str]]]
urlParams: Optional[list[tuple[str, str]]]
class DashboardPermalinkValue(TypedDict):

View File

@ -16,7 +16,7 @@
# under the License.
import json
import re
from typing import Any, Dict, Union
from typing import Any, Union
from marshmallow import fields, post_load, pre_load, Schema
from marshmallow.validate import Length, ValidationError
@ -144,9 +144,9 @@ class DashboardJSONMetadataSchema(Schema):
@pre_load
def remove_show_native_filters( # pylint: disable=unused-argument, no-self-use
self,
data: Dict[str, Any],
data: dict[str, Any],
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Remove ``show_native_filters`` from the JSON metadata.
@ -254,7 +254,7 @@ class DashboardDatasetSchema(Schema):
class BaseDashboardSchema(Schema):
# pylint: disable=no-self-use,unused-argument
@post_load
def post_load(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
def post_load(self, data: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
if data.get("slug"):
data["slug"] = data["slug"].strip()
data["slug"] = data["slug"].replace(" ", "-")

View File

@ -19,7 +19,7 @@ import json
import logging
from datetime import datetime
from io import BytesIO
from typing import Any, cast, Dict, List, Optional
from typing import Any, cast, Optional
from zipfile import is_zipfile, ZipFile
from flask import request, Response, send_file
@ -1328,13 +1328,13 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
500:
$ref: '#/components/responses/500'
"""
preferred_databases: List[str] = app.config.get("PREFERRED_DATABASES", [])
preferred_databases: list[str] = app.config.get("PREFERRED_DATABASES", [])
available_databases = []
for engine_spec, drivers in get_available_engine_specs().items():
if not drivers:
continue
payload: Dict[str, Any] = {
payload: dict[str, Any] = {
"name": engine_spec.engine_name,
"engine": engine_spec.engine,
"available_drivers": sorted(drivers),

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask import current_app
from flask_appbuilder.models.sqla import Model
@ -47,7 +47,7 @@ stats_logger = current_app.config["STATS_LOGGER"]
class CreateDatabaseCommand(BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
@ -128,7 +128,7 @@ class CreateDatabaseCommand(BaseCommand):
return database
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
sqlalchemy_uri: Optional[str] = self._properties.get("sqlalchemy_uri")
database_name: Optional[str] = self._properties.get("database_name")
if not sqlalchemy_uri:

View File

@ -18,7 +18,8 @@
import json
import logging
from typing import Any, Dict, Iterator, Tuple
from typing import Any
from collections.abc import Iterator
import yaml
@ -33,7 +34,7 @@ from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
def parse_extra(extra_payload: str) -> Dict[str, Any]:
def parse_extra(extra_payload: str) -> dict[str, Any]:
try:
extra = json.loads(extra_payload)
except json.decoder.JSONDecodeError:
@ -57,7 +58,7 @@ class ExportDatabasesCommand(ExportModelsCommand):
@staticmethod
def _export(
model: Database, export_related: bool = True
) -> Iterator[Tuple[str, str]]:
) -> Iterator[tuple[str, str]]:
db_file_name = get_filename(model.database_name, model.id, skip_id=True)
file_path = f"databases/{db_file_name}.yaml"

View File

@ -16,7 +16,7 @@
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from marshmallow.exceptions import ValidationError
@ -38,7 +38,7 @@ class ImportDatabasesCommand(BaseCommand):
until it finds one that matches.
"""
def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any):
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
@ -36,7 +36,7 @@ class ImportDatabasesCommand(ImportModelsCommand):
dao = DatabaseDAO
model_name = "database"
prefix = "databases/"
schemas: Dict[str, Schema] = {
schemas: dict[str, Schema] = {
"databases/": ImportV1DatabaseSchema(),
"datasets/": ImportV1DatasetSchema(),
}
@ -44,10 +44,10 @@ class ImportDatabasesCommand(ImportModelsCommand):
@staticmethod
def _import(
session: Session, configs: Dict[str, Any], overwrite: bool = False
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# first import databases
database_ids: Dict[str, int] = {}
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)

View File

@ -16,7 +16,7 @@
# under the License.
import json
from typing import Any, Dict
from typing import Any
from sqlalchemy.orm import Session
@ -28,7 +28,7 @@ from superset.models.core import Database
def import_database(
session: Session,
config: Dict[str, Any],
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Database:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, cast, Dict
from typing import Any, cast
from superset.commands.base import BaseCommand
from superset.connectors.sqla.models import SqlaTable
@ -40,7 +40,7 @@ class TablesDatabaseCommand(BaseCommand):
self._schema_name = schema_name
self._force = force
def run(self) -> Dict[str, Any]:
def run(self) -> dict[str, Any]:
self.validate()
try:
tables = security_manager.get_datasources_accessible_by_user(

View File

@ -17,7 +17,7 @@
import logging
import sqlite3
from contextlib import closing
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask import current_app as app
from flask_babel import gettext as _
@ -64,7 +64,7 @@ def get_log_connection_action(
class TestConnectionDatabaseCommand(BaseCommand):
def __init__(self, data: Dict[str, Any]):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
self._model: Optional[Database] = None

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -47,7 +47,7 @@ logger = logging.getLogger(__name__)
class UpdateDatabaseCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand):
raise DatabaseConnectionFailedError() from ex
# Update database schema permissions
new_schemas: List[str] = []
new_schemas: list[str] = []
for schema in schemas:
old_view_menu_name = security_manager.get_schema_perm(
@ -164,7 +164,7 @@ class UpdateDatabaseCommand(BaseCommand):
chart.schema_perm = new_view_menu_name
def validate(self) -> None:
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
# Validate/populate model exists
self._model = DatabaseDAO.find_by_id(self._model_id)
if not self._model:

View File

@ -16,7 +16,7 @@
# under the License.
import json
from contextlib import closing
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask_babel import gettext as __
@ -38,7 +38,7 @@ BYPASS_VALIDATION_ENGINES = {"bigquery"}
class ValidateDatabaseParametersCommand(BaseCommand):
def __init__(self, properties: Dict[str, Any]):
def __init__(self, properties: dict[str, Any]):
self._properties = properties.copy()
self._model: Optional[Database] = None

View File

@ -16,7 +16,7 @@
# under the License.
import logging
import re
from typing import Any, Dict, List, Optional, Type
from typing import Any, Optional
from flask import current_app
from flask_babel import gettext as __
@ -41,13 +41,13 @@ logger = logging.getLogger(__name__)
class ValidateSQLCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[Database] = None
self._validator: Optional[Type[BaseSQLValidator]] = None
self._validator: Optional[type[BaseSQLValidator]] = None
def run(self) -> List[Dict[str, Any]]:
def run(self) -> list[dict[str, Any]]:
"""
Validates a SQL statement
@ -97,9 +97,7 @@ class ValidateSQLCommand(BaseCommand):
if not validators_by_engine or spec.engine not in validators_by_engine:
raise NoValidatorConfigFoundError(
SupersetError(
message=__(
"no SQL validator is configured for {}".format(spec.engine)
),
message=__(f"no SQL validator is configured for {spec.engine}"),
error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR,
level=ErrorLevel.ERROR,
),

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from superset.dao.base import BaseDAO
from superset.databases.filters import DatabaseFilter
@ -38,7 +38,7 @@ class DatabaseDAO(BaseDAO):
def update(
cls,
model: Database,
properties: Dict[str, Any],
properties: dict[str, Any],
commit: bool = True,
) -> Database:
"""
@ -93,7 +93,7 @@ class DatabaseDAO(BaseDAO):
)
@classmethod
def get_related_objects(cls, database_id: int) -> Dict[str, Any]:
def get_related_objects(cls, database_id: int) -> dict[str, Any]:
database: Any = cls.find_by_id(database_id)
datasets = database.tables
dataset_ids = [dataset.id for dataset in datasets]

View File

@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Set
from typing import Any
from flask import g
from flask_babel import lazy_gettext as _
@ -30,7 +30,7 @@ from superset.views.base import BaseFilter
def can_access_databases(
view_menu_name: str,
) -> Set[str]:
) -> set[str]:
return {
security_manager.unpack_database_and_schema(vm).database
for vm in security_manager.user_view_menu_names(view_menu_name)

View File

@ -19,7 +19,7 @@
import inspect
import json
from typing import Any, Dict, List
from typing import Any
from flask import current_app
from flask_babel import lazy_gettext as _
@ -263,8 +263,8 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
@pre_load
def build_sqlalchemy_uri(
self, data: Dict[str, Any], **kwargs: Any
) -> Dict[str, Any]:
self, data: dict[str, Any], **kwargs: Any
) -> dict[str, Any]:
"""
Build SQLAlchemy URI from separate parameters.
@ -325,9 +325,9 @@ class DatabaseParametersSchemaMixin: # pylint: disable=too-few-public-methods
def rename_encrypted_extra(
self: Schema,
data: Dict[str, Any],
data: dict[str, Any],
**kwargs: Any,
) -> Dict[str, Any]:
) -> dict[str, Any]:
"""
Rename ``encrypted_extra`` to ``masked_encrypted_extra``.
@ -707,8 +707,8 @@ class DatabaseFunctionNamesResponse(Schema):
class ImportV1DatabaseExtraSchema(Schema):
@pre_load
def fix_schemas_allowed_for_csv_upload( # pylint: disable=invalid-name
self, data: Dict[str, Any], **kwargs: Any
) -> Dict[str, Any]:
self, data: dict[str, Any], **kwargs: Any
) -> dict[str, Any]:
"""
Fixes for ``schemas_allowed_for_csv_upload``.
"""
@ -744,8 +744,8 @@ class ImportV1DatabaseExtraSchema(Schema):
class ImportV1DatabaseSchema(Schema):
@pre_load
def fix_allow_csv_upload(
self, data: Dict[str, Any], **kwargs: Any
) -> Dict[str, Any]:
self, data: dict[str, Any], **kwargs: Any
) -> dict[str, Any]:
"""
Fix for ``allow_csv_upload`` .
"""
@ -775,7 +775,7 @@ class ImportV1DatabaseSchema(Schema):
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
def validate_password(self, data: dict[str, Any], **kwargs: Any) -> None:
"""If sqlalchemy_uri has a masked password, password is required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
@ -789,7 +789,7 @@ class ImportV1DatabaseSchema(Schema):
@validates_schema
def validate_ssh_tunnel_credentials(
self, data: Dict[str, Any], **kwargs: Any
self, data: dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
@ -829,7 +829,7 @@ class ImportV1DatabaseSchema(Schema):
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
exception_messages: List[str] = []
exception_messages: list[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
@ -864,7 +864,7 @@ class EncryptedDict(EncryptedField, fields.Dict):
pass
def encrypted_field_properties(self, field: Any, **_) -> Dict[str, Any]: # type: ignore
def encrypted_field_properties(self, field: Any, **_) -> dict[str, Any]: # type: ignore
ret = {}
if isinstance(field, EncryptedField):
if self.openapi_version.major > 2:

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: Dict[str, Any]):
def __init__(self, database_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
@ -61,7 +61,7 @@ class CreateSSHTunnelCommand(BaseCommand):
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
exceptions: List[ValidationError] = []
exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, Optional
from typing import Any, Optional
from flask_appbuilder.models.sqla import Model
@ -34,7 +34,7 @@ logger = logging.getLogger(__name__)
class UpdateSSHTunnelCommand(BaseCommand):
def __init__(self, model_id: int, data: Dict[str, Any]):
def __init__(self, model_id: int, data: dict[str, Any]):
self._properties = data.copy()
self._model_id = model_id
self._model: Optional[SSHTunnel] = None

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict
from typing import Any
from superset.dao.base import BaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
@ -31,7 +31,7 @@ class SSHTunnelDAO(BaseDAO):
def update(
cls,
model: SSHTunnel,
properties: Dict[str, Any],
properties: dict[str, Any],
commit: bool = True,
) -> SSHTunnel:
"""

View File

@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict
from typing import Any
import sqlalchemy as sa
from flask import current_app
@ -82,7 +82,7 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
]
@property
def data(self) -> Dict[str, Any]:
def data(self) -> dict[str, Any]:
output = {
"id": self.id,
"server_address": self.server_address,

Some files were not shown because too many files have changed in this diff Show More